00001 #ifndef VOCABULARY_TREE_VOCABULARY_TREE_H
00002 #define VOCABULARY_TREE_VOCABULARY_TREE_H
00003
00004 #include "vocabulary_tree/distance.h"
00005 #include "vocabulary_tree/feature_allocator.h"
00006 #include <stdint.h>
00007 #include <vector>
00008 #include <cassert>
00009 #include <limits>
00010 #include <fstream>
00011 #include <stdexcept>
00012 #include <boost/format.hpp>
00013
00014 namespace vt {
00015
/// Signed 32-bit identifier of a vocabulary word, as returned by VocabularyTree::quantize().
typedef int32_t Word;
00017
00030 template<class Feature, class Distance = distance::L2<Feature>,
00031 class FeatureAllocator = typename DefaultAllocator<Feature>::type>
00032 class VocabularyTree
00033 {
00034 public:
00042 VocabularyTree(Distance d = Distance());
00043
00050 VocabularyTree(const std::string& file, Distance d = Distance());
00051
00053 Word quantize(const Feature& f) const;
00054
00056 uint32_t levels() const;
00058 uint32_t splits() const;
00060 uint32_t words() const;
00061
00063 void clear();
00064
00066 void save(const std::string& file) const;
00068 void load(const std::string& file);
00069
00070 protected:
00071 typedef typename Distance::result_type distance_type;
00072
00073 std::vector<Feature, FeatureAllocator> centers_;
00074 std::vector<uint8_t> valid_centers_;
00075 Distance distance_;
00076
00077 uint32_t k_;
00078 uint32_t levels_;
00079 uint32_t num_words_;
00080 uint32_t word_start_;
00081
00082 bool initialized() const { return num_words_ != 0; }
00083
00084 void setNodeCounts();
00085 };
00086
00087
00088 template<class Feature, class Distance, class FeatureAllocator>
00089 VocabularyTree<Feature, Distance, FeatureAllocator>::VocabularyTree(Distance d)
00090 : distance_(d), k_(0), levels_(0), num_words_(0), word_start_(0)
00091 {
00092 }
00093
00094 template<class Feature, class Distance, class FeatureAllocator>
00095 VocabularyTree<Feature, Distance, FeatureAllocator>::VocabularyTree(const std::string& file, Distance d)
00096 : distance_(d), k_(0), levels_(0), num_words_(0), word_start_(0)
00097 {
00098 load(file);
00099 }
00100
00101 template<class Feature, class Distance, class FeatureAllocator>
00102 Word VocabularyTree<Feature, Distance, FeatureAllocator>::quantize(const Feature& f) const
00103 {
00104 assert( initialized() );
00105
00106 int32_t index = -1;
00107 for (unsigned level = 0; level < levels_; ++level) {
00108
00109 int32_t first_child = (index + 1) * splits();
00110
00111 int32_t best_child = first_child;
00112 distance_type best_distance = std::numeric_limits<distance_type>::max();
00113 for (int32_t child = first_child; child < first_child + (int32_t)splits(); ++child) {
00114 if (!valid_centers_[child])
00115 break;
00116 distance_type child_distance = distance_(f, centers_[child]);
00117 if (child_distance < best_distance) {
00118 best_child = child;
00119 best_distance = child_distance;
00120 }
00121 }
00122 index = best_child;
00123 }
00124
00125 return index - word_start_;
00126 };
00127
00128 template<class Feature, class Distance, class FeatureAllocator>
00129 uint32_t VocabularyTree<Feature, Distance, FeatureAllocator>::levels() const
00130 {
00131 return levels_;
00132 }
00133
00134 template<class Feature, class Distance, class FeatureAllocator>
00135 uint32_t VocabularyTree<Feature, Distance, FeatureAllocator>::splits() const
00136 {
00137 return k_;
00138 }
00139
00140 template<class Feature, class Distance, class FeatureAllocator>
00141 uint32_t VocabularyTree<Feature, Distance, FeatureAllocator>::words() const
00142 {
00143 return num_words_;
00144 }
00145
00146 template<class Feature, class Distance, class FeatureAllocator>
00147 void VocabularyTree<Feature, Distance, FeatureAllocator>::clear()
00148 {
00149 centers_.clear();
00150 valid_centers_.clear();
00151 k_ = levels_ = num_words_ = word_start_ = 0;
00152 }
00153
00154 template<class Feature, class Distance, class FeatureAllocator>
00155 void VocabularyTree<Feature, Distance, FeatureAllocator>::save(const std::string& file) const
00156 {
00159 assert( initialized() );
00160
00161 std::ofstream out(file.c_str(), std::ios_base::binary);
00162 out.write((char*)(&k_), sizeof(uint32_t));
00163 out.write((char*)(&levels_), sizeof(uint32_t));
00164 uint32_t size = centers_.size();
00165 out.write((char*)(&size), sizeof(uint32_t));
00166 out.write((char*)(¢ers_[0]), centers_.size() * sizeof(Feature));
00167 out.write((char*)(&valid_centers_[0]), valid_centers_.size());
00168 }
00169
00170 template<class Feature, class Distance, class FeatureAllocator>
00171 void VocabularyTree<Feature, Distance, FeatureAllocator>::load(const std::string& file)
00172 {
00173 clear();
00174
00175 std::ifstream in;
00176 in.exceptions(std::ifstream::eofbit | std::ifstream::failbit | std::ifstream::badbit);
00177
00178 uint32_t size;
00179 try {
00180 in.open(file.c_str(), std::ios_base::binary);
00181 in.read((char*)(&k_), sizeof(uint32_t));
00182 in.read((char*)(&levels_), sizeof(uint32_t));
00183 in.read((char*)(&size), sizeof(uint32_t));
00184 centers_.resize(size);
00185 valid_centers_.resize(size);
00186 in.read((char*)(¢ers_[0]), centers_.size() * sizeof(Feature));
00187 in.read((char*)(&valid_centers_[0]), valid_centers_.size());
00188 }
00189 catch (std::ifstream::failure& e) {
00190 throw std::runtime_error( (boost::format("Failed to load vocabulary tree file '%s'") % file).str() );
00191 }
00192
00193 setNodeCounts();
00194 assert(size == num_words_ + word_start_);
00195 }
00196
00197 template<class Feature, class Distance, class FeatureAllocator>
00198 void VocabularyTree<Feature, Distance, FeatureAllocator>::setNodeCounts()
00199 {
00200 num_words_ = k_;
00201 word_start_ = 0;
00202 for (uint32_t i = 0; i < levels_ - 1; ++i) {
00203 word_start_ += num_words_;
00204 num_words_ *= k_;
00205 }
00206 }
00207
00208 }
00209
00210 #endif