00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "FernImageDetector.h"
00025
00026 namespace alvar
00027 {
00028
// Tuning parameters of the fern-based image detector. They are typed
// constants with internal linkage rather than macros, so they obey normal
// scoping rules and cannot clobber identically named identifiers.

// Patch edge length handed to the LDetector constructor and to classifier training.
static const int PATCH_SIZE = 31;
// Value passed as the third argument of the LDetector constructor.
static const int PYR_LEVELS = 1;
// Number of views handed to the LDetector constructor and to classifier training.
static const int N_VIEWS = 5000;
// Maximum number of keypoints requested from the detector in findFeatures().
static const int N_PTS_TO_FIND = 400;
// Maximum number of stable keypoints requested when training on a model image.
static const int N_PTS_TO_TEACH = 200;
// Gaussian kernel edge length used to blur both the model and the query image.
static const int SIZE_BLUR = 13;

// Classifier layout parameters forwarded to trainFromSingleView().
static const int N_STRUCTS = 50;
static const int STRUCT_SIZE = 11;
static const int SIGNATURE_SIZE = 400;
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060 FernClassifierWrapper::FernClassifierWrapper()
00061 : FernClassifier()
00062 {
00063 }
00064
00065 FernClassifierWrapper::FernClassifierWrapper(const FileNode &fileNode)
00066 : FernClassifier(fileNode)
00067 {
00068 }
00069
00070 FernClassifierWrapper::FernClassifierWrapper(const vector<vector<Point2f> > &points,
00071 const vector<Mat> &referenceImages,
00072 const vector<vector<int> > &labels,
00073 int _nclasses, int _patchSize,
00074 int _signatureSize, int _nstructs,
00075 int _structSize, int _nviews,
00076 int _compressionMethod,
00077 const PatchGenerator &patchGenerator)
00078 : FernClassifier(points, referenceImages, labels, _nclasses, _patchSize, _signatureSize,
00079 _nstructs, _structSize, _nviews, _compressionMethod, patchGenerator)
00080 {
00081 }
00082
00083 FernClassifierWrapper::~FernClassifierWrapper()
00084 {
00085 }
00086
00087 void FernClassifierWrapper::readBinary(std::fstream &stream)
00088 {
00089 clear();
00090
00091 stream.read((char *)&verbose, sizeof(verbose));
00092 stream.read((char *)&nstructs, sizeof(nstructs));
00093 stream.read((char *)&structSize, sizeof(structSize));
00094 stream.read((char *)&nclasses, sizeof(nclasses));
00095 stream.read((char *)&signatureSize, sizeof(signatureSize));
00096 stream.read((char *)&compressionMethod, sizeof(compressionMethod));
00097 stream.read((char *)&leavesPerStruct, sizeof(leavesPerStruct));
00098 stream.read((char *)&patchSize.width, sizeof(patchSize.width));
00099 stream.read((char *)&patchSize.height, sizeof(patchSize.height));
00100
00101 std::vector<Feature>::size_type featuresSize;
00102 stream.read((char *)&featuresSize, sizeof(featuresSize));
00103 features.reserve(featuresSize);
00104 unsigned int featuresValue;
00105 Feature value;
00106 for (std::vector<Feature>::size_type i = 0; i < featuresSize; ++i) {
00107 stream.read((char *)&featuresValue, sizeof(featuresValue));
00108 value.x1 = (uchar)(featuresValue % patchSize.width);
00109 value.y1 = (uchar)(featuresValue / patchSize.width);
00110 stream.read((char *)&featuresValue, sizeof(featuresValue));
00111 value.x2 = (uchar)(featuresValue % patchSize.width);
00112 value.y2 = (uchar)(featuresValue / patchSize.width);
00113 features.push_back(value);
00114 }
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128 std::vector<float>::size_type posteriorsSize;
00129 stream.read((char *)&posteriorsSize, sizeof(posteriorsSize));
00130 posteriors.reserve(posteriorsSize);
00131 float posteriorsValue;
00132 for (std::vector<float>::size_type i = 0; i < posteriorsSize; ++i) {
00133 stream.read((char *)&posteriorsValue, sizeof(posteriorsValue));
00134 posteriors.push_back(posteriorsValue);
00135 }
00136 }
00137
00138 void FernClassifierWrapper::writeBinary(std::fstream &stream) const
00139 {
00140 stream.write((char *)&verbose, sizeof(verbose));
00141 stream.write((char *)&nstructs, sizeof(nstructs));
00142 stream.write((char *)&structSize, sizeof(structSize));
00143 stream.write((char *)&nclasses, sizeof(nclasses));
00144 stream.write((char *)&signatureSize, sizeof(signatureSize));
00145 stream.write((char *)&compressionMethod, sizeof(compressionMethod));
00146 stream.write((char *)&leavesPerStruct, sizeof(leavesPerStruct));
00147 stream.write((char *)&patchSize.width, sizeof(patchSize.width));
00148 stream.write((char *)&patchSize.height, sizeof(patchSize.height));
00149
00150 std::vector<Feature>::size_type featuresSize = features.size();
00151 stream.write((char *)&featuresSize, sizeof(featuresSize));
00152 unsigned int featuresValue;
00153 for (std::vector<Feature>::const_iterator itr = features.begin(); itr != features.end(); ++itr) {
00154 featuresValue = itr->y1 * patchSize.width + itr->x1;
00155 stream.write((char *)&featuresValue, sizeof(featuresValue));
00156 featuresValue = itr->y2 * patchSize.width + itr->x2;
00157 stream.write((char *)&featuresValue, sizeof(featuresValue));
00158 }
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169 std::vector<float>::size_type posteriorsSize = posteriors.size();
00170 stream.write((char *)&posteriorsSize, sizeof(posteriorsSize));
00171 for (std::vector<float>::const_iterator itr = posteriors.begin(); itr != posteriors.end(); ++itr) {
00172 stream.write((char *)&*itr, sizeof(*itr));
00173 }
00174 }
00175
00176 FernImageDetector::FernImageDetector(const bool visualize)
00177 : mPatchGenerator(0, 256, 13, true, 0.10, 1.0, -CV_PI*1.0, CV_PI*1.0, -CV_PI*0.0, CV_PI*0.0)
00178 , mLDetector(3, 20, PYR_LEVELS, N_VIEWS, PATCH_SIZE, 2)
00179 , mClassifier()
00180 , mKeyPoints()
00181 , mImagePoints()
00182 , mModelPoints()
00183 , mVisualize(visualize)
00184 , mObjects()
00185 , mSize()
00186 , mCorrespondences()
00187 , mHomography()
00188 , mInlierRatio(0)
00189 {
00190
00191 mClassifier.resize(1);
00192 }
00193
00194 FernImageDetector::~FernImageDetector()
00195 {
00196 }
00197
00198 void FernImageDetector::imagePoints(vector<CvPoint2D64f> &points)
00199 {
00200 points.clear();
00201 for(size_t i = 0; i < mImagePoints.size(); ++i) {
00202 points.push_back(cvPoint2D64f(mImagePoints[i].x, mImagePoints[i].y));
00203 }
00204 }
00205
00206 void FernImageDetector::modelPoints(vector<CvPoint3D64f> &points, bool normalize)
00207 {
00208 points.clear();
00209
00210
00211 for(size_t i = 0; i < mModelPoints.size(); ++i) {
00212 CvPoint3D64f pt = cvPoint3D64f(mModelPoints[i].x, mModelPoints[i].y, 0.0);
00213 if(normalize) {
00214
00215
00216
00217
00218 pt.x -= mSize.width/2;
00219 pt.y -= mSize.height/2;
00220 pt.x /= mSize.width*0.10;
00221 pt.y /= mSize.width*0.10;
00222 }
00223 points.push_back(pt);
00224 }
00225 }
00226
00227 cv::Size FernImageDetector::size()
00228 {
00229 return mSize;
00230 }
00231
00232 cv::Mat FernImageDetector::homography()
00233 {
00234 return mHomography;
00235 }
00236
00237 double FernImageDetector::inlierRatio()
00238 {
00239 return mInlierRatio;
00240 }
00241
00242 void FernImageDetector::train(const std::string &filename)
00243 {
00244 Mat object = imread(filename.c_str(), CV_LOAD_IMAGE_GRAYSCALE);
00245 train(object);
00246 }
00247
00248 void FernImageDetector::train(Mat &object)
00249 {
00250 mObjects.push_back(object.clone());
00251
00252 Mat blurredObject;
00253 GaussianBlur(mObjects[0], blurredObject, Size(SIZE_BLUR, SIZE_BLUR), 0, 0);
00254
00255 if(mVisualize) {
00256 cvNamedWindow("Object", 1);
00257 imshow("Object", blurredObject);
00258 cv::waitKey(2000);
00259 }
00260
00261
00262
00263 mLDetector.getMostStable2D(blurredObject, mKeyPoints, N_PTS_TO_TEACH, mPatchGenerator);
00264
00265 if(mVisualize) {
00266 for(int i = 0; i < (int)mKeyPoints.size(); ++i)
00267 circle(blurredObject, mKeyPoints[i].pt, int(mKeyPoints[i].size/10), CV_RGB(64,64,64));
00268
00269 imshow("Object", blurredObject);
00270 cv::waitKey(2000);
00271 }
00272
00273 mClassifier[0].trainFromSingleView(blurredObject,
00274 mKeyPoints,
00275 PATCH_SIZE,
00276 SIGNATURE_SIZE,
00277 N_STRUCTS,
00278 STRUCT_SIZE,
00279 N_VIEWS,
00280 FernClassifier::COMPRESSION_NONE,
00281 mPatchGenerator);
00282
00283 mSize = cv::Size(object.cols, object.rows);
00284 }
00285
00286 void FernImageDetector::findFeatures(Mat &object, bool planeAssumption)
00287 {
00288
00289
00290 vector<KeyPoint> keypoints;
00291 vector<Mat> objpyr;
00292
00293 GaussianBlur(object, object, Size(SIZE_BLUR, SIZE_BLUR), 0, 0);
00294
00295 mLDetector.nOctaves = 1;
00296 mLDetector(object, keypoints, N_PTS_TO_FIND);
00297
00298 int m = mKeyPoints.size();
00299 int n = keypoints.size();
00300 vector<int> bestMatches(m, -1);
00301 vector<float> maxLogProb(m, -FLT_MAX);
00302 vector<float> signature;
00303 vector<int> pairs;
00304
00305 for(size_t i = 0; i < keypoints.size(); ++i) {
00306 Point2f pt = keypoints[i].pt;
00307
00308 int k = mClassifier[0](object , pt, signature);
00309 if(k >= 0 && (bestMatches[k] < 0 || signature[k] > maxLogProb[k])) {
00310 maxLogProb[k] = signature[k];
00311 bestMatches[k] = i;
00312 }
00313 }
00314
00315 for(int i = 0; i < m; i++ )
00316 if(bestMatches[i] >= 0) {
00317 pairs.push_back(i);
00318 pairs.push_back(bestMatches[i]);
00319 }
00320
00321 mCorrespondences = Mat(mObjects[0].rows + object.rows, std::max( mObjects[0].cols, object.cols), CV_8UC3);
00322 mCorrespondences = Scalar(0.);
00323 Mat part(mCorrespondences, Rect(0, 0, mObjects[0].cols, mObjects[0].rows));
00324 cvtColor(mObjects[0], part, CV_GRAY2BGR);
00325 part = Mat(mCorrespondences, Rect(0, mObjects[0].rows, object.cols, object.rows));
00326 cvtColor(object, part, CV_GRAY2BGR);
00327
00328 for(int i = 0; i < (int)keypoints.size(); ++i)
00329 circle(object, keypoints[i].pt, int(keypoints[i].size/5), CV_RGB(64,64,64));
00330
00331 vector<Point2f> fromPt, toPt;
00332 vector<uchar> mask;
00333 for(int i = 0; i < m; ++i)
00334 if( bestMatches[i] >= 0 ){
00335 fromPt.push_back(mKeyPoints[i].pt);
00336 toPt.push_back(keypoints[bestMatches[i]].pt);
00337 }
00338
00339 static double valmin = 1.0;
00340 static double valmax = 0.0;
00341 mModelPoints.clear();
00342 mImagePoints.clear();
00343 int n_inliers = 0;
00344
00345 if(planeAssumption && fromPt.size() > 8) {
00346 cv::Mat H = cv::findHomography(Mat(fromPt), Mat(toPt), mask, RANSAC, 20);
00347 mHomography = H;
00348
00349
00350 for(size_t i = 0, j = 0; i < (int)pairs.size(); i += 2, ++j) {
00351 if(mask[j]) {
00352 cv::Point2f pi(keypoints[pairs[i+1]].pt);
00353 cv::Point2f pw(mKeyPoints[pairs[i]].pt);
00354 mModelPoints.push_back(pw);
00355 mImagePoints.push_back(pi);
00356 line(mCorrespondences, mKeyPoints[pairs[i]].pt,
00357 keypoints[pairs[i+1]].pt + Point2f(0.0,(float)mObjects[0].rows),
00358 Scalar(i*i%244,100-i*100%30,i*i-50*i));
00359 n_inliers++;
00360 }
00361 }
00362 } else {
00363 for(size_t i = 0, j = 0; i < (int)pairs.size(); i += 2, ++j) {
00364 cv::Point2f pi(keypoints[pairs[i+1]].pt);
00365 cv::Point2f pw(mKeyPoints[pairs[i]].pt);
00366 mModelPoints.push_back(pw);
00367 mImagePoints.push_back(pi);
00368 line(mCorrespondences, mKeyPoints[pairs[i]].pt,
00369 keypoints[pairs[i+1]].pt + Point2f(0.0,(float)mObjects[0].rows),
00370 Scalar(i*i%244,100-i*100%30,i*i-50*i));
00371 }
00372 }
00373
00374
00375 double val = 0.0;
00376 if(fromPt.size()>0) val = 1.*n_inliers/fromPt.size();
00377 if(val > valmax) valmax = val;
00378 if(val < valmin) valmin = val;
00379
00380 mInlierRatio = val;
00381
00382 if (mVisualize) {
00383 cvNamedWindow("Matches", 1);
00384 imshow("Matches", mCorrespondences);
00385 cv::waitKey(1);
00386 }
00387 }
00388
00389 bool FernImageDetector::read(const std::string &filename, const bool binary)
00390 {
00391 if (binary) {
00392 std::fstream bs(filename.c_str(), std::fstream::in | std::fstream::binary);
00393
00394 if (!bs.is_open()) {
00395 return false;
00396 }
00397
00398 bs.read((char *)&mLDetector.radius, sizeof(mLDetector.radius));
00399 bs.read((char *)&mLDetector.threshold, sizeof(mLDetector.threshold));
00400 bs.read((char *)&mLDetector.nOctaves, sizeof(mLDetector.nOctaves));
00401 bs.read((char *)&mLDetector.nViews, sizeof(mLDetector.nViews));
00402 bs.read((char *)&mLDetector.verbose, sizeof(mLDetector.verbose));
00403 bs.read((char *)&mLDetector.baseFeatureSize, sizeof(mLDetector.baseFeatureSize));
00404 bs.read((char *)&mLDetector.clusteringDistance, sizeof(mLDetector.clusteringDistance));
00405
00406 mClassifier[0].readBinary(bs);
00407
00408 std::vector<float>::size_type size;
00409 bs.read((char *)&size, sizeof(size));
00410 mKeyPoints.reserve(size);
00411 KeyPoint value;
00412 for (std::vector<float>::size_type i = 0; i < size; ++i) {
00413 bs.read((char *)&value.pt.x, sizeof(value.pt.x));
00414 bs.read((char *)&value.pt.y, sizeof(value.pt.y));
00415 bs.read((char *)&value.size, sizeof(value.size));
00416 bs.read((char *)&value.angle, sizeof(value.angle));
00417 bs.read((char *)&value.response, sizeof(value.response));
00418 bs.read((char *)&value.octave, sizeof(value.octave));
00419 bs.read((char *)&value.class_id, sizeof(value.class_id));
00420 mKeyPoints.push_back(value);
00421 }
00422
00423 bs.read((char *)&mSize.width, sizeof(mSize.width));
00424 bs.read((char *)&mSize.height, sizeof(mSize.height));
00425
00426 std::vector<Mat>::size_type objectsSize;
00427 bs.read((char *)&objectsSize, sizeof(objectsSize));
00428 mObjects.reserve(objectsSize);
00429 int rows;
00430 int cols;
00431 int type;
00432 for (std::vector<Mat>::size_type i = 0; i < objectsSize; ++i) {
00433 bs.read((char *)&rows, sizeof(rows));
00434 bs.read((char *)&cols, sizeof(cols));
00435 bs.read((char *)&type, sizeof(type));
00436 Mat objectsValue(rows, cols, type);
00437 bs.read((char *)objectsValue.data, objectsValue.elemSize() * objectsValue.total());
00438 mObjects.push_back(objectsValue);
00439 }
00440
00441 bs.close();
00442 }
00443 else {
00444 FileStorage fs(filename, FileStorage::READ);
00445
00446 if (!fs.isOpened()) {
00447 return false;
00448 }
00449
00450 FileNode node = fs.getFirstTopLevelNode();
00451 std::cout << "loaded file" << std::endl;
00452 cv::read(node["model_points"], mKeyPoints);
00453 std::cout << "loaded model points" << std::endl;
00454 mClassifier[0].read(node["fern_classifier"]);
00455 std::cout << "loaded classifier" << std::endl;
00456 }
00457
00458 return true;
00459 }
00460
00461 bool FernImageDetector::write(const std::string &filename, const bool binary)
00462 {
00463 if (binary) {
00464 std::fstream bs(filename.c_str(), std::fstream::out | std::fstream::binary);
00465
00466 if (!bs.is_open()) {
00467 return false;
00468 }
00469
00470 bs.write((char *)&mLDetector.radius, sizeof(mLDetector.radius));
00471 bs.write((char *)&mLDetector.threshold, sizeof(mLDetector.threshold));
00472 bs.write((char *)&mLDetector.nOctaves, sizeof(mLDetector.nOctaves));
00473 bs.write((char *)&mLDetector.nViews, sizeof(mLDetector.nViews));
00474 bs.write((char *)&mLDetector.verbose, sizeof(mLDetector.verbose));
00475 bs.write((char *)&mLDetector.baseFeatureSize, sizeof(mLDetector.baseFeatureSize));
00476 bs.write((char *)&mLDetector.clusteringDistance, sizeof(mLDetector.clusteringDistance));
00477
00478 mClassifier[0].writeBinary(bs);
00479
00480 std::vector<float>::size_type size = mKeyPoints.size();
00481 bs.write((char *)&size, sizeof(size));
00482 for (std::vector<KeyPoint>::const_iterator itr = mKeyPoints.begin(); itr != mKeyPoints.end(); ++itr) {
00483 bs.write((char *)&itr->pt.x, sizeof(itr->pt.x));
00484 bs.write((char *)&itr->pt.y, sizeof(itr->pt.y));
00485 bs.write((char *)&itr->size, sizeof(itr->size));
00486 bs.write((char *)&itr->angle, sizeof(itr->angle));
00487 bs.write((char *)&itr->response, sizeof(itr->response));
00488 bs.write((char *)&itr->octave, sizeof(itr->octave));
00489 bs.write((char *)&itr->class_id, sizeof(itr->class_id));
00490 }
00491
00492 bs.write((char *)&mSize.width, sizeof(mSize.width));
00493 bs.write((char *)&mSize.height, sizeof(mSize.height));
00494
00495 std::vector<Mat>::size_type objectsSize = mObjects.size();
00496 bs.write((char *)&objectsSize, sizeof(objectsSize));
00497 for (std::vector<Mat>::const_iterator itr = mObjects.begin(); itr != mObjects.end(); ++itr) {
00498 bs.write((char *)&itr->rows, sizeof(itr->rows));
00499 bs.write((char *)&itr->cols, sizeof(itr->cols));
00500 int type = itr->type();
00501 bs.write((char *)&type, sizeof(type));
00502 bs.write((char *)itr->data, itr->elemSize() * itr->total());
00503 }
00504
00505 bs.close();
00506 }
00507 else {
00508 FileStorage fs(filename, FileStorage::WRITE);
00509
00510 if (!fs.isOpened()) {
00511 return false;
00512 }
00513
00514 WriteStructContext ws(fs, "fern_image_detector", CV_NODE_MAP);
00515 cv::write(fs, "model_points", mKeyPoints);
00516 mClassifier[0].write(fs, "fern_classifier");
00517 }
00518
00519 return true;
00520 }
00521
00522 }