00001 #include <vocabulary_tree/vocabulary_tree.h>
00002 #include <vocabulary_tree/database.h>
00003 #include <boost/lexical_cast.hpp>
00004 #include <map>
00005 #include <cstdio>
00006 #include <fstream>
00007 
00008 static const unsigned int DIMENSION = 176;
00009 
00010 int main(int argc, char** argv)
00011 {
00012   if (argc < 5) {
00013     printf("Usage: %s vocab.tree signatures.dat objects.dat output.weights [NUM_SIGS]\n", argv[0]);
00014     return 0;
00015   }
00016 
00017   
00018   typedef Eigen::Matrix<float, 1, DIMENSION> Feature;
00019   vt::VocabularyTree<Feature> tree(argv[1]);
00020 
00021   
00022   std::ifstream sig_is(argv[2], std::ios::binary);
00023   int length, num_sigs;
00024   if (argc == 5) {
00025     sig_is.seekg(0, std::ios::end);
00026     length = sig_is.tellg();
00027     num_sigs = length / (DIMENSION * sizeof(float));
00028     sig_is.seekg(0, std::ios::beg);
00029   }
00030   else {
00031     num_sigs = boost::lexical_cast<int>(argv[5]);
00032     length = num_sigs * DIMENSION * sizeof(float);
00033   }
00034   printf("Training from %d descriptors\n", num_sigs);
00035   printf("Data length = %d\n", length);
00036 
00037   
00038   typedef std::vector<Feature, Eigen::aligned_allocator<Feature> > FeatureVector;
00039   FeatureVector features(num_sigs);
00040   sig_is.read((char*)&features[0], length);
00041   printf("Done reading in descriptors\n");
00042 
00043   
00044   std::ifstream obj_is(argv[3], std::ios::binary);
00045   std::vector<unsigned int> object_ids(num_sigs);
00046   obj_is.read((char*)&object_ids[0], num_sigs * sizeof(unsigned int));
00047 
00048   
00049   typedef std::map<unsigned int, vt::Document> DocumentMap;
00050   DocumentMap documents;
00051   for (int i = 0; i < num_sigs; ++i) {
00052     documents[ object_ids[i] ].push_back( tree.quantize(features[i]) );
00053   }
00054   
00055   
00056   vt::Database db(tree.words());
00057   for (DocumentMap::const_iterator i = documents.begin(), ie = documents.end(); i != ie; ++i) {
00058     db.insert(i->second);
00059   }
00060 
00061   
00062   db.computeTfIdfWeights();
00063   db.saveWeights(argv[4]);
00064   
00065   return 0;
00066 }