00001 #include "vocabulary_tree/database.h"
00002 #include <boost/accumulators/accumulators.hpp>
00003 #include <boost/accumulators/statistics/tail.hpp>
00004 #include <cmath>
00005 #include <fstream>
00006 #include <stdexcept>
00007 #include <boost/format.hpp>
00008
00009 namespace vt {
00010
00011 Database::Database(uint32_t num_words)
00012 : word_files_(num_words),
00013 word_weights_(num_words, 1.0f)
00014 {
00015 }
00016
00017 DocId Database::insert(const std::vector<Word>& document)
00018 {
00020 DocId doc_id = database_vectors_.size();
00021
00022
00023 for (std::vector<Word>::const_iterator it = document.begin(), end = document.end(); it != end; ++it) {
00024 Word word = *it;
00025 InvertedFile& file = word_files_[word];
00026 if (file.empty() || file.back().id != doc_id)
00027 file.push_back(WordFrequency(doc_id, 1));
00028 else
00029 file.back().count++;
00030 }
00031
00032
00033 database_vectors_.resize(doc_id + 1);
00034 computeVector(document, database_vectors_.back());
00035
00036 return doc_id;
00037 }
00038
00039 void Database::find(const std::vector<Word>& document, size_t N, std::vector<Match>& matches) const
00040 {
00041 DocumentVector query;
00042 computeVector(document, query);
00043
00044
00045 using namespace boost::accumulators;
00046 typedef tag::tail<left> bestN_tag;
00047 accumulator_set<Match, features<bestN_tag> > acc(bestN_tag::cache_size = N);
00048
00050 for (DocId i = 0; i < (DocId)database_vectors_.size(); ++i) {
00051 float distance = sparseDistance(query, database_vectors_[i]);
00052 acc( Match(i, distance) );
00053 }
00054
00055 extractor<bestN_tag> bestN;
00056 matches.resize( std::min(N, database_vectors_.size()) );
00057 std::copy(bestN(acc).begin(), bestN(acc).end(), matches.begin());
00058 }
00059
00060 DocId Database::findAndInsert(const std::vector<Word>& document, size_t N, std::vector<Match>& matches)
00061 {
00063 find(document, N, matches);
00064 return insert(document);
00065 }
00066
00067 void Database::computeTfIdfWeights(float default_weight)
00068 {
00069 float N = (float)database_vectors_.size();
00070 size_t num_words = word_files_.size();
00071 for (size_t i = 0; i < num_words; ++i) {
00072 size_t Ni = word_files_[i].size();
00073 if (Ni != 0)
00074 word_weights_[i] = std::log(N / Ni);
00075 else
00076 word_weights_[i] = default_weight;
00077 }
00078 }
00079
00080 void Database::saveWeights(const std::string& file) const
00081 {
00082 std::ofstream out(file.c_str(), std::ios_base::binary);
00083 uint32_t num_words = word_weights_.size();
00084 out.write((char*)(&num_words), sizeof(uint32_t));
00085 out.write((char*)(&word_weights_[0]), num_words * sizeof(float));
00086 }
00087
00088 void Database::loadWeights(const std::string& file)
00089 {
00090 std::ifstream in;
00091 in.exceptions(std::ifstream::eofbit | std::ifstream::failbit | std::ifstream::badbit);
00092
00093 try {
00094 in.open(file.c_str(), std::ios_base::binary);
00095 uint32_t num_words = 0;
00096 in.read((char*)(&num_words), sizeof(uint32_t));
00097 word_files_.resize(num_words);
00098 word_weights_.resize(num_words);
00099 in.read((char*)(&word_weights_[0]), num_words * sizeof(float));
00100 }
00101 catch (std::ifstream::failure& e) {
00102 throw std::runtime_error( (boost::format("Failed to load vocabulary weights file '%s'") % file).str() );
00103 }
00104 }
00105
00106 void Database::computeVector(const std::vector<Word>& document, DocumentVector& v) const
00107 {
00108 for (std::vector<Word>::const_iterator it = document.begin(), end = document.end(); it != end; ++it) {
00109 Word word = *it;
00110 v[word] += word_weights_[word];
00111 }
00112 normalize(v);
00113 }
00114
00115 void Database::normalize(DocumentVector& v)
00116 {
00117 float sum = 0.0f;
00118 for (DocumentVector::iterator i = v.begin(), ie = v.end(); i != ie; ++i)
00119 sum += i->second;
00120 float inv_sum = 1.0f / sum;
00121 for (DocumentVector::iterator i = v.begin(), ie = v.end(); i != ie; ++i)
00122 i->second *= inv_sum;
00123 }
00124
00125 float Database::sparseDistance(const DocumentVector& v1, const DocumentVector& v2)
00126 {
00127 float distance = 0.0f;
00128 DocumentVector::const_iterator i1 = v1.begin(), i1e = v1.end();
00129 DocumentVector::const_iterator i2 = v2.begin(), i2e = v2.end();
00130
00131 while (i1 != i1e && i2 != i2e) {
00132 if (i2->first < i1->first) {
00133 distance += i2->second;
00134 ++i2;
00135 }
00136 else if (i1->first < i2->first) {
00137 distance += i1->second;
00138 ++i1;
00139 }
00140 else {
00141 distance += fabs(i1->second - i2->second);
00142 ++i1; ++i2;
00143 }
00144 }
00145
00146 while (i1 != i1e) {
00147 distance += i1->second;
00148 ++i1;
00149 }
00150
00151 while (i2 != i2e) {
00152 distance += i2->second;
00153 ++i2;
00154 }
00155
00156 return distance;
00157 }
00158
00159 }