00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100 #ifndef __GCOPTIMIZATION_H__
00101 #define __GCOPTIMIZATION_H__
00102
00103
00104 #if defined(_MSC_VER) && _MSC_VER < 1400
00105 #error Requires Visual C++ 2005 (VC8) compiler or later.
00106 #endif
00107
00108 #include "energy.h"
00109 #include "graph.cpp"
00110 #include "maxflow.cpp"
00111 #include <cstddef>
00113
00115
// Error object thrown by the optimizer. It stores the caller's message
// pointer as-is (no copy), so the pointed-to string must outlive the
// exception object.
class GCException {
public:
	// Wraps `m`; only the pointer is kept.
	GCException( const char* m ) : message(m) {}

	// Emits the message; defined out of line in another translation unit.
	void Report();

	// Borrowed, NUL-terminated description of the failure.
	const char* message;
};
00122
// Tick type used for the optimizer's timing/status output.
#ifdef _WIN32
typedef __int64 gcoclock_t;
#else
#include <ctime>
// <ctime> is only required to declare std::clock_t; whether the unqualified
// global ::clock_t is also declared is unspecified, so name the std:: one.
typedef std::clock_t gcoclock_t;
#endif
// Both are defined in another translation unit with C linkage.
// gcoclock() presumably returns the current tick count and GCO_CLOCKS_PER_SEC
// the number of ticks per second -- NOTE(review): confirm against the definitions.
extern "C" gcoclock_t gcoclock();
extern "C" gcoclock_t GCO_CLOCKS_PER_SEC;
00131
// Inlining hint for small hot accessors and cost adapters: forced inlining
// when Microsoft language extensions are enabled, the ordinary `inline`
// keyword everywhere else.
#if defined(_MSC_EXTENSIONS)
#  define OLGA_INLINE __forceinline
#else
#  define OLGA_INLINE inline
#endif
00137
// Largest energy coefficient considered safe; a value defined before this
// header is included takes precedence over the default below.
#if !defined(GCO_MAX_ENERGYTERM)
#  define GCO_MAX_ENERGYTERM 10000000 // maximum safe coefficient to avoid integer overflow
#endif
00143
00144
00146
00148 class LinkedBlockList;
00149
00150 class GCoptimization
00151 {
00152 public:
00153 #ifdef GCO_ENERGYTYPE32
00154 typedef int EnergyType;
00155 #else
00156 typedef long long EnergyType;
00157 #endif
00158 typedef int EnergyTermType;
00159 typedef Energy<EnergyTermType,EnergyTermType,EnergyType> EnergyT;
00160 typedef EnergyT::Var VarID;
00161 typedef int LabelID;
00162 typedef VarID SiteID;
00163 typedef EnergyTermType (*SmoothCostFn)(SiteID s1, SiteID s2, LabelID l1, LabelID l2);
00164 typedef EnergyTermType (*DataCostFn)(SiteID s, LabelID l);
00165 typedef EnergyTermType (*SmoothCostFnExtra)(SiteID s1, SiteID s2, LabelID l1, LabelID l2,void *);
00166 typedef EnergyTermType (*DataCostFnExtra)(SiteID s, LabelID l,void *);
00167
00168 GCoptimization(SiteID num_sites, LabelID num_labels);
00169 virtual ~GCoptimization();
00170
00171
00172
00173 EnergyType expansion(int max_num_iterations=-1);
00174
00175
00176 bool alpha_expansion(LabelID alpha_label);
00177
00178
00179
00180 EnergyType swap(int max_num_iterations=-1);
00181
00182
00183 void alpha_beta_swap(LabelID alpha_label, LabelID beta_label);
00184
00185
00186
00187
00188 void alpha_beta_swap(LabelID alpha_label, LabelID beta_label, SiteID *alphaSites,
00189 SiteID alpha_size, SiteID *betaSites, SiteID beta_size);
00190
00191 struct DataCostFunctor;
00192 struct SmoothCostFunctor;
00193
00194
00195 void setDataCost(DataCostFn fn);
00196 void setDataCost(DataCostFnExtra fn, void *extraData);
00197 void setDataCost(EnergyTermType *dataArray);
00198 void setDataCost(SiteID s, LabelID l, EnergyTermType e);
00199 void setDataCostFunctor(DataCostFunctor* f);
00200 struct DataCostFunctor {
00201 virtual EnergyTermType compute(SiteID s, LabelID l) = 0;
00202 };
00203
00204
00205 struct SparseDataCost {
00206 SiteID site;
00207 EnergyTermType cost;
00208 };
00209 void setDataCost(LabelID l, SparseDataCost *costs, SiteID count);
00210
00211
00212
00213 void setSmoothCost(SmoothCostFn fn);
00214 void setSmoothCost(SmoothCostFnExtra fn, void *extraData);
00215 void setSmoothCost(LabelID l1, LabelID l2, EnergyTermType e);
00216 void setSmoothCost(EnergyTermType *smoothArray);
00217 void setSmoothCostFunctor(SmoothCostFunctor* f);
00218 struct SmoothCostFunctor {
00219 virtual EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2) = 0;
00220 };
00221
00222
00223
00224 void setLabelCost(EnergyTermType cost);
00225 void setLabelCost(EnergyTermType* costArray);
00226 void setLabelSubsetCost(LabelID* labels, LabelID numLabels, EnergyTermType cost);
00227
00228
00229 LabelID whatLabel(SiteID site);
00230 void whatLabel(SiteID start, SiteID count, LabelID* labeling);
00231
00232
00233 void setLabel(SiteID site, LabelID label);
00234
00235
00236
00237
00238
00239
00240 void setLabelOrder(bool isRandom);
00241 void setLabelOrder(const LabelID* order, LabelID size);
00242
00243
00244 EnergyType compute_energy();
00245
00246
00247 EnergyType giveDataEnergy();
00248 EnergyType giveSmoothEnergy();
00249 EnergyType giveLabelEnergy();
00250
00251
00252 SiteID numSites() const;
00253 LabelID numLabels() const;
00254
00255
00256
00257
00258
00259 void setVerbosity(int level) { m_verbosity = level; }
00260
00261 protected:
00262 struct LabelCost {
00263 ~LabelCost() { delete [] labels; }
00264 EnergyTermType cost;
00265 bool active;
00266 VarID aux;
00267 LabelCost* next;
00268 LabelID numLabels;
00269 LabelID* labels;
00270 };
00271
00272 struct LabelCostIter {
00273 LabelCost* node;
00274 LabelCostIter* next;
00275 };
00276
00277 LabelID m_num_labels;
00278 SiteID m_num_sites;
00279 LabelID *m_labeling;
00280 SiteID *m_lookupSiteVar;
00281
00282 LabelID *m_labelTable;
00283 int m_stepsThisCycle;
00284 int m_stepsThisCycleTotal;
00285 int m_random_label_order;
00286 EnergyTermType* m_datacostIndividual;
00287 EnergyTermType* m_smoothcostIndividual;
00288 EnergyTermType* m_labelingDataCosts;
00289 SiteID* m_labelCounts;
00290 SiteID* m_activeLabelCounts;
00291 LabelCost* m_labelcostsAll;
00292 LabelCostIter** m_labelcostsByLabel;
00293 int m_labelcostCount;
00294 bool m_labelingInfoDirty;
00295 int m_verbosity;
00296
00297 void* m_datacostFn;
00298 void* m_smoothcostFn;
00299 EnergyType m_beforeExpansionEnergy;
00300
00301 SiteID *m_numNeighbors;
00302 SiteID m_numNeighborsTotal;
00303
00304 EnergyType (GCoptimization::*m_giveSmoothEnergyInternal)();
00305 SiteID (GCoptimization::*m_queryActiveSitesExpansion)(LabelID, SiteID*);
00306 void (GCoptimization::*m_setupDataCostsExpansion)(SiteID,LabelID,EnergyT*,SiteID*);
00307 void (GCoptimization::*m_setupSmoothCostsExpansion)(SiteID,LabelID,EnergyT*,SiteID*);
00308 void (GCoptimization::*m_setupDataCostsSwap)(SiteID,LabelID,LabelID,EnergyT*,SiteID*);
00309 void (GCoptimization::*m_setupSmoothCostsSwap)(SiteID,LabelID,LabelID,EnergyT*,SiteID*);
00310 void (GCoptimization::*m_applyNewLabeling)(EnergyT*,SiteID*,SiteID,LabelID);
00311 void (GCoptimization::*m_updateLabelingDataCosts)();
00312
00313 void (*m_datacostFnDelete)(void* f);
00314 void (*m_smoothcostFnDelete)(void* f);
00315 bool (GCoptimization::*m_solveSpecialCases)(EnergyType&);
00316
00317
00318 virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights)=0;
00319 virtual void finalizeNeighbors() = 0;
00320
00321 struct DataCostFnFromArray {
00322 DataCostFnFromArray(EnergyTermType* theArray, LabelID num_labels)
00323 : m_array(theArray), m_num_labels(num_labels){}
00324 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_array[s*m_num_labels+l];}
00325 private:
00326 const EnergyTermType* const m_array;
00327 const LabelID m_num_labels;
00328 };
00329
00330 struct DataCostFnFromFunction {
00331 DataCostFnFromFunction(DataCostFn fn): m_fn(fn){}
00332 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_fn(s,l);}
00333 private:
00334 const DataCostFn m_fn;
00335 };
00336
00337 struct DataCostFnFromFunctionExtra {
00338 DataCostFnFromFunctionExtra(DataCostFnExtra fn,void *extraData): m_fn(fn),m_extraData(extraData){}
00339 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_fn(s,l,m_extraData);}
00340 private:
00341 const DataCostFnExtra m_fn;
00342 void *m_extraData;
00343 };
00344
00345 struct SmoothCostFnFromArray {
00346 SmoothCostFnFromArray(EnergyTermType* theArray, LabelID num_labels)
00347 : m_array(theArray), m_num_labels(num_labels){}
00348 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_array[l1*m_num_labels+l2];}
00349 private:
00350 const EnergyTermType* const m_array;
00351 const LabelID m_num_labels;
00352 };
00353
00354 struct SmoothCostFnFromFunction {
00355 SmoothCostFnFromFunction(SmoothCostFn fn)
00356 : m_fn(fn){}
00357 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_fn(s1,s2,l1,l2);}
00358 private:
00359 const SmoothCostFn m_fn;
00360 };
00361
00362 struct SmoothCostFnFromFunctionExtra {
00363 SmoothCostFnFromFunctionExtra(SmoothCostFnExtra fn,void *extraData)
00364 : m_fn(fn),m_extraData(extraData){}
00365 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_fn(s1,s2,l1,l2,m_extraData);}
00366 private:
00367 const SmoothCostFnExtra m_fn;
00368 void *m_extraData;
00369 };
00370
00371 struct SmoothCostFnPotts {
00372 OLGA_INLINE EnergyTermType compute(SiteID, SiteID, LabelID l1, LabelID l2){return l1 != l2 ? 1 : 0;}
00373 };
00374
00376
00377
00378
00380 class DataCostFnSparse {
00381
00382
00383
00384
00385
00386 static const int cLogSitesPerBucket = 9;
00387 static const int cSitesPerBucket = (1 << cLogSitesPerBucket);
00388 static const size_t cDataCostPtrMask = ~(sizeof(SparseDataCost)-1);
00389 static const ptrdiff_t cLinearSearchSize = 64/sizeof(SparseDataCost);
00390
00391 struct DataCostBucket {
00392 const SparseDataCost* begin;
00393 const SparseDataCost* end;
00394 const SparseDataCost* predict;
00395 };
00396
00397 public:
00398 DataCostFnSparse(SiteID num_sites, LabelID num_labels);
00399 DataCostFnSparse(const DataCostFnSparse& src);
00400 ~DataCostFnSparse();
00401
00402 void set(LabelID l, const SparseDataCost* costs, SiteID count);
00403 EnergyTermType compute(SiteID s, LabelID l);
00404 SiteID queryActiveSitesExpansion(LabelID alpha_label, const LabelID* labeling, SiteID* activeSites);
00405
00406 class iterator {
00407 public:
00408 OLGA_INLINE iterator(): m_ptr(0) { }
00409 OLGA_INLINE iterator& operator++() { m_ptr++; return *this; }
00410 OLGA_INLINE SiteID site() const { return m_ptr->site; }
00411 OLGA_INLINE EnergyTermType cost() const { return m_ptr->cost; }
00412 OLGA_INLINE bool operator==(const iterator& b) const { return m_ptr == b.m_ptr; }
00413 OLGA_INLINE bool operator!=(const iterator& b) const { return m_ptr != b.m_ptr; }
00414 OLGA_INLINE ptrdiff_t operator- (const iterator& b) const { return m_ptr - b.m_ptr; }
00415 private:
00416 OLGA_INLINE iterator(const SparseDataCost* ptr): m_ptr(ptr) { }
00417 const SparseDataCost* m_ptr;
00418 friend class DataCostFnSparse;
00419 };
00420
00421 OLGA_INLINE iterator begin(LabelID label) const { return m_buckets[label*m_buckets_per_label].begin; }
00422 OLGA_INLINE iterator end(LabelID label) const { return m_buckets[label*m_buckets_per_label + m_buckets_per_label-1].end; }
00423
00424 private:
00425 EnergyTermType search(DataCostBucket& b, SiteID s);
00426 const SiteID m_num_sites;
00427 const LabelID m_num_labels;
00428 const int m_buckets_per_label;
00429 mutable DataCostBucket* m_buckets;
00430 };
00431
00432 template <typename DataCostT> SiteID queryActiveSitesExpansion(LabelID alpha_label, SiteID* activeSites);
00433 template <typename DataCostT> void setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00434 template <typename DataCostT> void setupDataCostsSwap(SiteID size,LabelID alpha_label,LabelID beta_label,EnergyT *e,SiteID *activeSites);
00435 template <typename SmoothCostT> void setupSmoothCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00436 template <typename SmoothCostT> void setupSmoothCostsSwap(SiteID size,LabelID alpha_label,LabelID beta_label,EnergyT *e,SiteID *activeSites);
00437 template <typename DataCostT> void applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label);
00438 template <typename DataCostT> void updateLabelingDataCosts();
00439 template <typename UserFunctor> void specializeDataCostFunctor(const UserFunctor f);
00440 template <typename UserFunctor> void specializeSmoothCostFunctor(const UserFunctor f);
00441
00442 EnergyType setupLabelCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00443 void updateLabelingInfo(bool updateCounts=true,bool updateActive=true,bool updateCosts=true);
00444
00445
00446 void addterm1_checked(EnergyT *e,VarID i,EnergyTermType e0,EnergyTermType e1);
00447 void addterm1_checked(EnergyT *e,VarID i,EnergyTermType e0,EnergyTermType e1,EnergyTermType w);
00448 void addterm2_checked(EnergyT *e,VarID i,VarID j,EnergyTermType e00,EnergyTermType e01,EnergyTermType e10,EnergyTermType e11,EnergyTermType w);
00449
00450
00451 template <typename SmoothCostT> EnergyType giveSmoothEnergyInternal();
00452 template <typename Functor> static void deleteFunctor(void* f) { delete reinterpret_cast<Functor*>(f); }
00453
00454 static void handleError(const char *message);
00455 static void checkInterrupt();
00456
00457 private:
00458
00459 EnergyType oneExpansionIteration();
00460 EnergyType oneSwapIteration();
00461 void printStatus1(const char* extraMsg=0);
00462 void printStatus1(int cycle, bool isSwap, gcoclock_t ticks0);
00463 void printStatus2(int alpha, int beta, int numVars, gcoclock_t ticks0);
00464
00465 void permuteLabelTable();
00466
00467 template <typename DataCostT> bool solveSpecialCases(EnergyType& energy);
00468 template <typename DataCostT> EnergyType solveGreedy();
00469
00471
00472
00473
00475 template <typename DataCostT>
00476 class GreedyIter {
00477 public:
00478 GreedyIter(DataCostT& dc, SiteID numSites)
00479 : m_dc(dc), m_site(0), m_numSites(numSites), m_label(0), m_lbegin(0), m_lend(0)
00480 { }
00481
00482 OLGA_INLINE void start(const LabelID* labels, LabelID labelCount=1)
00483 {
00484 m_site = labelCount ? 0 : m_numSites;
00485 m_label = m_lbegin = labels;
00486 m_lend = labels + labelCount;
00487 }
00488 OLGA_INLINE SiteID site() const { return m_site; }
00489 OLGA_INLINE SiteID label() const { return *m_label; }
00490 OLGA_INLINE bool done() const { return m_site == m_numSites; }
00491 OLGA_INLINE GreedyIter& operator++()
00492 {
00493
00494
00495
00496 if (++m_label >= m_lend) {
00497 m_label = m_lbegin;
00498 ++m_site;
00499 }
00500 return *this;
00501 }
00502 OLGA_INLINE EnergyTermType compute() const { return m_dc.compute(m_site,*m_label); }
00503 OLGA_INLINE SiteID feasibleSites() const { return m_numSites; }
00504
00505 private:
00506 SiteID m_site;
00507 DataCostT& m_dc;
00508 const SiteID m_numSites;
00509 const LabelID* m_label;
00510 const LabelID* m_lbegin;
00511 const LabelID* m_lend;
00512 };
00513 };
00514
00515
00517
00519
00520 class GCoptimizationGridGraph: public GCoptimization
00521 {
00522 public:
00523 GCoptimizationGridGraph(SiteID width,SiteID height,LabelID num_labels);
00524 virtual ~GCoptimizationGridGraph();
00525
00526 void setSmoothCostVH(EnergyTermType *smoothArray, EnergyTermType *vCosts, EnergyTermType *hCosts);
00527
00528 protected:
00529 virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights);
00530 virtual void finalizeNeighbors();
00531 EnergyTermType m_unityWeights[4];
00532 int m_weightedGraph;
00533
00534 private:
00535 SiteID m_width;
00536 SiteID m_height;
00537 SiteID *m_neighbors;
00538 EnergyTermType *m_neighborsWeights;
00539
00540 void setupNeighbData(SiteID startY,SiteID endY,SiteID startX,SiteID endX,SiteID maxInd,SiteID *indexes);
00541 void computeNeighborWeights(EnergyTermType *vCosts,EnergyTermType *hCosts);
00542 };
00543
00545
00546 class GCoptimizationGeneralGraph:public GCoptimization
00547 {
00548 public:
00549
00550
00551 GCoptimizationGeneralGraph(SiteID num_sites,LabelID num_labels);
00552 virtual ~GCoptimizationGeneralGraph();
00553
00554
00555
00556
00557
00558
00559 void setNeighbors(SiteID site1, SiteID site2, EnergyTermType weight=1);
00560
00561
00562
00563
00564
00565
00566 void setAllNeighbors(SiteID *numNeighbors,SiteID **neighborsIndexes,EnergyTermType **neighborsWeights);
00567
00568 protected:
00569 virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights);
00570 virtual void finalizeNeighbors();
00571
00572 private:
00573
00574 typedef struct NeighborStruct{
00575 SiteID to_node;
00576 EnergyTermType weight;
00577 } Neighbor;
00578
00579 LinkedBlockList *m_neighbors;
00580 bool m_needToFinishSettingNeighbors;
00581 SiteID **m_neighborsIndexes;
00582 EnergyTermType **m_neighborsWeights;
00583 bool m_needTodeleteNeighbors;
00584 };
00585
00586
00588
00590
00591
00592 OLGA_INLINE GCoptimization::SiteID GCoptimization::numSites() const
00593 {
00594 return m_num_sites;
00595 }
00596
00597 OLGA_INLINE GCoptimization::LabelID GCoptimization::numLabels() const
00598 {
00599 return m_num_labels;
00600 }
00601
00602 OLGA_INLINE void GCoptimization::setLabel(SiteID site, LabelID label)
00603 {
00604 assert(label >= 0 && label < m_num_labels && site >= 0 && site < m_num_sites);
00605 m_labeling[site] = label;
00606 m_labelingInfoDirty = true;
00607 }
00608
00609 OLGA_INLINE GCoptimization::LabelID GCoptimization::whatLabel(SiteID site)
00610 {
00611 assert(site >= 0 && site < m_num_sites);
00612 return m_labeling[site];
00613 }
00614
00615 #endif