database.cpp
Go to the documentation of this file.
00001 #include "vocabulary_tree/database.h"
00002 #include <boost/accumulators/accumulators.hpp>
00003 #include <boost/accumulators/statistics/tail.hpp>
00004 #include <cmath>
00005 #include <fstream>
00006 #include <stdexcept>
00007 #include <boost/format.hpp>
00008 
00009 namespace vt {
00010 
/// Construct a database over a vocabulary of num_words visual words.
/// Every word starts with an empty inverted file and a uniform weight of 1.0
/// (weights can later be replaced via computeTfIdfWeights() or loadWeights()).
Database::Database(uint32_t num_words)
  : word_files_(num_words),
    word_weights_(num_words, 1.0f)
{
}
00016 
00017 DocId Database::insert(const std::vector<Word>& document)
00018 {
00020   DocId doc_id = database_vectors_.size();
00021 
00022   // For each word, retrieve its inverted file and increment the count for doc_id.
00023   for (std::vector<Word>::const_iterator it = document.begin(), end = document.end(); it != end; ++it) {
00024     Word word = *it;
00025     InvertedFile& file = word_files_[word];
00026     if (file.empty() || file.back().id != doc_id)
00027       file.push_back(WordFrequency(doc_id, 1));
00028     else
00029       file.back().count++;
00030   }
00031 
00032   // Precompute the document vector to compare queries against.
00033   database_vectors_.resize(doc_id + 1);
00034   computeVector(document, database_vectors_.back());
00035   
00036   return doc_id;
00037 }
00038 
/// Find the N stored documents closest to the query document under the
/// sparse L1 distance, writing them (best first) into matches.
/// If the database holds fewer than N documents, matches is shortened.
void Database::find(const std::vector<Word>& document, size_t N, std::vector<Match>& matches) const
{
  // Turn the bag of words into a weighted, normalized sparse vector.
  DocumentVector query;
  computeVector(document, query);

  // Accumulate the best N matches
  using namespace boost::accumulators;
  // The "left" tail statistic retains the N smallest samples seen, i.e. the
  // lowest-distance Matches (Match ordering is defined in the header).
  typedef tag::tail<left> bestN_tag;
  accumulator_set<Match, features<bestN_tag> > acc(bestN_tag::cache_size = N);

  // Score the query against every cached document vector.
  for (DocId i = 0; i < (DocId)database_vectors_.size(); ++i) {
    float distance = sparseDistance(query, database_vectors_[i]);
    acc( Match(i, distance) );
  }

  // Copy out the retained best matches; the tail holds at most
  // min(N, number of documents) samples, matching the resize below.
  extractor<bestN_tag> bestN;
  matches.resize( std::min(N, database_vectors_.size()) );
  std::copy(bestN(acc).begin(), bestN(acc).end(), matches.begin());
}
00059 
00060 DocId Database::findAndInsert(const std::vector<Word>& document, size_t N, std::vector<Match>& matches)
00061 {
00063   find(document, N, matches);
00064   return insert(document);
00065 }
00066 
00067 void Database::computeTfIdfWeights(float default_weight)
00068 {
00069   float N = (float)database_vectors_.size();
00070   size_t num_words = word_files_.size();
00071   for (size_t i = 0; i < num_words; ++i) {
00072     size_t Ni = word_files_[i].size();
00073     if (Ni != 0)
00074       word_weights_[i] = std::log(N / Ni);
00075     else
00076       word_weights_[i] = default_weight;
00077   }
00078 }
00079 
00080 void Database::saveWeights(const std::string& file) const
00081 {
00082   std::ofstream out(file.c_str(), std::ios_base::binary);
00083   uint32_t num_words = word_weights_.size();
00084   out.write((char*)(&num_words), sizeof(uint32_t));
00085   out.write((char*)(&word_weights_[0]), num_words * sizeof(float));
00086 }
00087 
00088 void Database::loadWeights(const std::string& file)
00089 {
00090   std::ifstream in;
00091   in.exceptions(std::ifstream::eofbit | std::ifstream::failbit | std::ifstream::badbit);
00092 
00093   try {
00094     in.open(file.c_str(), std::ios_base::binary);
00095     uint32_t num_words = 0;
00096     in.read((char*)(&num_words), sizeof(uint32_t));
00097     word_files_.resize(num_words); // Inverted files start out empty
00098     word_weights_.resize(num_words);
00099     in.read((char*)(&word_weights_[0]), num_words * sizeof(float));
00100   }
00101   catch (std::ifstream::failure& e) {
00102     throw std::runtime_error( (boost::format("Failed to load vocabulary weights file '%s'") % file).str() );
00103   }
00104 }
00105 
00106 void Database::computeVector(const std::vector<Word>& document, DocumentVector& v) const
00107 {
00108   for (std::vector<Word>::const_iterator it = document.begin(), end = document.end(); it != end; ++it) {
00109     Word word = *it;
00110     v[word] += word_weights_[word];
00111   }
00112   normalize(v);
00113 }
00114 
00115 void Database::normalize(DocumentVector& v)
00116 {
00117   float sum = 0.0f;
00118   for (DocumentVector::iterator i = v.begin(), ie = v.end(); i != ie; ++i)
00119     sum += i->second;
00120   float inv_sum = 1.0f / sum;
00121   for (DocumentVector::iterator i = v.begin(), ie = v.end(); i != ie; ++i)
00122     i->second *= inv_sum;
00123 }
00124 
00125 float Database::sparseDistance(const DocumentVector& v1, const DocumentVector& v2)
00126 {
00127   float distance = 0.0f;
00128   DocumentVector::const_iterator i1 = v1.begin(), i1e = v1.end();
00129   DocumentVector::const_iterator i2 = v2.begin(), i2e = v2.end();
00130 
00131   while (i1 != i1e && i2 != i2e) {
00132     if (i2->first < i1->first) {
00133       distance += i2->second;
00134       ++i2;
00135     }
00136     else if (i1->first < i2->first) {
00137       distance += i1->second;
00138       ++i1;
00139     }
00140     else {
00141       distance += fabs(i1->second - i2->second);
00142       ++i1; ++i2;
00143     }
00144   }
00145 
00146   while (i1 != i1e) {
00147     distance += i1->second;
00148     ++i1;
00149   }
00150 
00151   while (i2 != i2e) {
00152     distance += i2->second;
00153     ++i2;
00154   }
00155   
00156   return distance;
00157 }
00158 
00159 } //namespace vt


vocabulary_tree
Author(s): Patrick Mihelich
autogenerated on Thu Jan 2 2014 12:12:26