00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 
00053 
00054 
00055 
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 
00069 
00070 
00071 
00072 
00073 
00074 
00075 
00076 
00077 
00078 
00079 
00080 
00081 
00082 
00083 
00084 
00085 
00086 
00087 
00088 
00089 
00090 
00091 
00092 
00093 
00094 
00095 
00096 
00097 
00098 
00099 
00100 #ifndef __GCOPTIMIZATION_H__
00101 #define __GCOPTIMIZATION_H__
00102 
00103 
00104 #if defined(_MSC_VER) && _MSC_VER < 1400
00105 #error Requires Visual C++ 2005 (VC8) compiler or later.
00106 #endif
00107 
00108 #include "energy.h"
00109 #include "graph.cpp"
00110 #include "maxflow.cpp"
00111 #include <cstddef>
00113 
00115 
// Exception type used by the optimizer. It carries a C-string message.
// Report() is declared here and defined in another translation unit.
class GCException {
public:
	GCException(const char* m) : message(m) {}
	void Report();

	// Stored as-is; no copy of the string is made here, so the pointee must outlive the exception.
	const char* message;
};
00122 
// Tick type for the optimizer's timing/status output (see printStatus1/printStatus2).
// Windows uses a 64-bit integer; other platforms use the standard clock_t.
#if !defined(_WIN32)
#include <ctime>
typedef clock_t gcoclock_t;
#else
typedef __int64 gcoclock_t;
#endif

// Both are only declared here, with C linkage, and defined in another translation unit.
extern "C" {
	gcoclock_t gcoclock();
	extern gcoclock_t GCO_CLOCKS_PER_SEC;
}
00131 
// Inlining hint for small hot functions.
// MSVC with language extensions enabled gets __forceinline; every other compiler gets plain inline.
#if defined(_MSC_EXTENSIONS)
#define OLGA_INLINE __forceinline
#else
#define OLGA_INLINE inline
#endif
00137 
// Largest energy coefficient considered safe against integer overflow.
// A translation unit may override it by defining GCO_MAX_ENERGYTERM before including this header.
#if !defined(GCO_MAX_ENERGYTERM)
#define GCO_MAX_ENERGYTERM 10000000
#endif
00143 
00144 
00146 
00148 class LinkedBlockList;
00149 
00150 class GCoptimization
00151 {
00152 public: 
00153 #ifdef GCO_ENERGYTYPE32
00154         typedef int EnergyType;        
00155 #else
00156         typedef long long EnergyType;  
00157 #endif
00158         typedef int EnergyTermType;    
00159         typedef Energy<EnergyTermType,EnergyTermType,EnergyType> EnergyT;
00160         typedef EnergyT::Var VarID;
00161         typedef int LabelID;                     
00162         typedef VarID SiteID;                    
00163         typedef EnergyTermType (*SmoothCostFn)(SiteID s1, SiteID s2, LabelID l1, LabelID l2);
00164         typedef EnergyTermType (*DataCostFn)(SiteID s, LabelID l);
00165         typedef EnergyTermType (*SmoothCostFnExtra)(SiteID s1, SiteID s2, LabelID l1, LabelID l2,void *);
00166         typedef EnergyTermType (*DataCostFnExtra)(SiteID s, LabelID l,void *);
00167 
00168         GCoptimization(SiteID num_sites, LabelID num_labels);
00169         virtual ~GCoptimization();
00170 
00171         
00172         
00173         EnergyType expansion(int max_num_iterations=-1);
00174 
00175         
00176         bool alpha_expansion(LabelID alpha_label);
00177 
00178         
00179         
00180         EnergyType swap(int max_num_iterations=-1);
00181 
00182         
00183         void alpha_beta_swap(LabelID alpha_label, LabelID beta_label);
00184 
00185         
00186         
00187         
00188         void alpha_beta_swap(LabelID alpha_label, LabelID beta_label, SiteID *alphaSites, 
00189                                  SiteID alpha_size, SiteID *betaSites, SiteID beta_size);
00190 
00191         struct DataCostFunctor;      
00192         struct SmoothCostFunctor;    
00193 
00194         
00195         void setDataCost(DataCostFn fn);
00196         void setDataCost(DataCostFnExtra fn, void *extraData);
00197         void setDataCost(EnergyTermType *dataArray);
00198         void setDataCost(SiteID s, LabelID l, EnergyTermType e); 
00199         void setDataCostFunctor(DataCostFunctor* f);
00200         struct DataCostFunctor {
00201                 virtual EnergyTermType compute(SiteID s, LabelID l) = 0;
00202         };
00203         
00204         
00205         struct SparseDataCost {
00206                 SiteID site;
00207                 EnergyTermType cost;
00208         };
00209         void setDataCost(LabelID l, SparseDataCost *costs, SiteID count);
00210 
00211         
00212         
00213         void setSmoothCost(SmoothCostFn fn);
00214         void setSmoothCost(SmoothCostFnExtra fn, void *extraData);
00215         void setSmoothCost(LabelID l1, LabelID l2, EnergyTermType e); 
00216         void setSmoothCost(EnergyTermType *smoothArray);
00217         void setSmoothCostFunctor(SmoothCostFunctor* f);
00218         struct SmoothCostFunctor {
00219                 virtual EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2) = 0;
00220         };
00221 
00222         
00223         
00224         void setLabelCost(EnergyTermType cost);
00225         void setLabelCost(EnergyTermType* costArray);
00226         void setLabelSubsetCost(LabelID* labels, LabelID numLabels, EnergyTermType cost);
00227 
00228         
00229         LabelID whatLabel(SiteID site);
00230         void    whatLabel(SiteID start, SiteID count, LabelID* labeling);
00231 
00232         
00233         void setLabel(SiteID site, LabelID label);
00234 
00235         
00236         
00237         
00238         
00239         
00240         void setLabelOrder(bool isRandom);
00241         void setLabelOrder(const LabelID* order, LabelID size);
00242 
00243         
00244         EnergyType compute_energy();
00245 
00246         
00247         EnergyType giveDataEnergy();
00248         EnergyType giveSmoothEnergy();
00249         EnergyType giveLabelEnergy();
00250 
00251         
00252         SiteID  numSites() const;
00253         LabelID numLabels() const;
00254 
00255         
00256         
00257         
00258         
00259         void setVerbosity(int level) { m_verbosity = level; }
00260 
00261 protected:
00262         struct LabelCost {
00263                 ~LabelCost() { delete [] labels; }
00264                 EnergyTermType cost; 
00265                 bool active;     
00266                 VarID aux;
00267                 LabelCost* next; 
00268                 LabelID numLabels;
00269                 LabelID* labels;
00270         };
00271 
00272         struct LabelCostIter {
00273                 LabelCost* node;
00274                 LabelCostIter* next; 
00275         };
00276 
00277         LabelID m_num_labels;
00278         SiteID  m_num_sites;
00279         LabelID *m_labeling;
00280         SiteID  *m_lookupSiteVar; 
00281                                   
00282         LabelID *m_labelTable;    
00283         int      m_stepsThisCycle;
00284         int      m_stepsThisCycleTotal;
00285         int      m_random_label_order;
00286         EnergyTermType* m_datacostIndividual;
00287         EnergyTermType* m_smoothcostIndividual;
00288         EnergyTermType* m_labelingDataCosts;
00289         SiteID*         m_labelCounts;
00290         SiteID*         m_activeLabelCounts;
00291         LabelCost*      m_labelcostsAll;
00292         LabelCostIter** m_labelcostsByLabel;
00293         int             m_labelcostCount;
00294         bool            m_labelingInfoDirty;
00295         int             m_verbosity;
00296 
00297         void*   m_datacostFn;
00298         void*   m_smoothcostFn;
00299         EnergyType m_beforeExpansionEnergy;
00300 
00301         SiteID *m_numNeighbors;              
00302         SiteID  m_numNeighborsTotal;         
00303 
00304         EnergyType (GCoptimization::*m_giveSmoothEnergyInternal)();
00305         SiteID (GCoptimization::*m_queryActiveSitesExpansion)(LabelID, SiteID*);
00306         void (GCoptimization::*m_setupDataCostsExpansion)(SiteID,LabelID,EnergyT*,SiteID*);
00307         void (GCoptimization::*m_setupSmoothCostsExpansion)(SiteID,LabelID,EnergyT*,SiteID*);
00308         void (GCoptimization::*m_setupDataCostsSwap)(SiteID,LabelID,LabelID,EnergyT*,SiteID*);
00309         void (GCoptimization::*m_setupSmoothCostsSwap)(SiteID,LabelID,LabelID,EnergyT*,SiteID*);
00310         void (GCoptimization::*m_applyNewLabeling)(EnergyT*,SiteID*,SiteID,LabelID);
00311         void (GCoptimization::*m_updateLabelingDataCosts)();
00312 
00313         void (*m_datacostFnDelete)(void* f);
00314         void (*m_smoothcostFnDelete)(void* f);
00315         bool (GCoptimization::*m_solveSpecialCases)(EnergyType&);
00316 
00317         
00318         virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights)=0;
00319         virtual void finalizeNeighbors() = 0;
00320 
00321         struct DataCostFnFromArray {
00322                 DataCostFnFromArray(EnergyTermType* theArray, LabelID num_labels)
00323                         : m_array(theArray), m_num_labels(num_labels){}
00324                 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_array[s*m_num_labels+l];}
00325         private:
00326                 const EnergyTermType* const m_array;
00327                 const LabelID m_num_labels;
00328         };
00329 
00330         struct DataCostFnFromFunction {
00331                 DataCostFnFromFunction(DataCostFn fn): m_fn(fn){}
00332                 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_fn(s,l);}
00333         private:
00334                 const DataCostFn m_fn;
00335         };
00336 
00337         struct DataCostFnFromFunctionExtra {
00338                 DataCostFnFromFunctionExtra(DataCostFnExtra fn,void *extraData): m_fn(fn),m_extraData(extraData){}
00339                 OLGA_INLINE EnergyTermType compute(SiteID s, LabelID l){return m_fn(s,l,m_extraData);}
00340         private:
00341                 const DataCostFnExtra m_fn;
00342                 void *m_extraData;
00343         };
00344 
00345         struct SmoothCostFnFromArray {
00346                 SmoothCostFnFromArray(EnergyTermType* theArray, LabelID num_labels)
00347                         : m_array(theArray), m_num_labels(num_labels){}
00348                 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_array[l1*m_num_labels+l2];}
00349         private:
00350                 const EnergyTermType* const m_array;
00351                 const LabelID m_num_labels;
00352         };
00353 
00354         struct SmoothCostFnFromFunction {
00355                 SmoothCostFnFromFunction(SmoothCostFn fn)
00356                         : m_fn(fn){}
00357                 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_fn(s1,s2,l1,l2);}
00358         private:
00359                 const SmoothCostFn m_fn;
00360         };
00361 
00362         struct SmoothCostFnFromFunctionExtra {
00363                 SmoothCostFnFromFunctionExtra(SmoothCostFnExtra fn,void *extraData)
00364                         : m_fn(fn),m_extraData(extraData){}
00365                 OLGA_INLINE EnergyTermType compute(SiteID s1, SiteID s2, LabelID l1, LabelID l2){return m_fn(s1,s2,l1,l2,m_extraData);}
00366         private:
00367                 const SmoothCostFnExtra m_fn;
00368                 void *m_extraData;
00369         };
00370 
00371         struct SmoothCostFnPotts {
00372                 OLGA_INLINE EnergyTermType compute(SiteID, SiteID, LabelID l1, LabelID l2){return l1 != l2 ? 1 : 0;}
00373         };
00374 
00376         
00377         
00378         
00380         class DataCostFnSparse {
00381                 
00382                 
00383                 
00384                 
00385                 
00386                 static const int cLogSitesPerBucket = 9;
00387                 static const int cSitesPerBucket = (1 << cLogSitesPerBucket); 
00388                 static const size_t    cDataCostPtrMask = ~(sizeof(SparseDataCost)-1);
00389                 static const ptrdiff_t cLinearSearchSize = 64/sizeof(SparseDataCost);
00390 
00391                 struct DataCostBucket {
00392                         const SparseDataCost* begin;
00393                         const SparseDataCost* end;     
00394                         const SparseDataCost* predict; 
00395                 };
00396 
00397         public:
00398                 DataCostFnSparse(SiteID num_sites, LabelID num_labels);
00399                 DataCostFnSparse(const DataCostFnSparse& src);
00400                 ~DataCostFnSparse();
00401 
00402                 void           set(LabelID l, const SparseDataCost* costs, SiteID count);
00403                 EnergyTermType compute(SiteID s, LabelID l);
00404                 SiteID         queryActiveSitesExpansion(LabelID alpha_label, const LabelID* labeling, SiteID* activeSites);
00405 
00406                 class iterator {
00407                 public:
00408                         OLGA_INLINE iterator(): m_ptr(0) { }
00409                         OLGA_INLINE iterator& operator++() { m_ptr++; return *this; }
00410                         OLGA_INLINE SiteID         site() const { return m_ptr->site; }
00411                         OLGA_INLINE EnergyTermType cost() const { return m_ptr->cost; }
00412                         OLGA_INLINE bool      operator==(const iterator& b) const { return m_ptr == b.m_ptr; }
00413                         OLGA_INLINE bool      operator!=(const iterator& b) const { return m_ptr != b.m_ptr; }
00414                         OLGA_INLINE ptrdiff_t operator- (const iterator& b) const { return m_ptr  - b.m_ptr; }
00415                 private:
00416                         OLGA_INLINE iterator(const SparseDataCost* ptr): m_ptr(ptr) { }
00417                         const SparseDataCost* m_ptr;
00418                         friend class DataCostFnSparse;
00419                 };
00420 
00421                 OLGA_INLINE iterator begin(LabelID label) const { return m_buckets[label*m_buckets_per_label].begin; }
00422                 OLGA_INLINE iterator end(LabelID label)   const { return m_buckets[label*m_buckets_per_label + m_buckets_per_label-1].end; }
00423 
00424         private:
00425                 EnergyTermType search(DataCostBucket& b, SiteID s);
00426                 const SiteID  m_num_sites;
00427                 const LabelID m_num_labels;
00428                 const int m_buckets_per_label;
00429                 mutable DataCostBucket* m_buckets;
00430         };
00431 
00432         template <typename DataCostT> SiteID queryActiveSitesExpansion(LabelID alpha_label, SiteID* activeSites);
00433         template <typename DataCostT>   void setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00434         template <typename DataCostT>   void setupDataCostsSwap(SiteID size,LabelID alpha_label,LabelID beta_label,EnergyT *e,SiteID *activeSites);
00435         template <typename SmoothCostT> void setupSmoothCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00436         template <typename SmoothCostT> void setupSmoothCostsSwap(SiteID size,LabelID alpha_label,LabelID beta_label,EnergyT *e,SiteID *activeSites);
00437         template <typename DataCostT>   void applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label);
00438         template <typename DataCostT>   void updateLabelingDataCosts();
00439         template <typename UserFunctor> void specializeDataCostFunctor(const UserFunctor f);
00440         template <typename UserFunctor> void specializeSmoothCostFunctor(const UserFunctor f);
00441 
00442         EnergyType setupLabelCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites);
00443         void       updateLabelingInfo(bool updateCounts=true,bool updateActive=true,bool updateCosts=true);
00444         
00445         
00446         void addterm1_checked(EnergyT *e,VarID i,EnergyTermType e0,EnergyTermType e1);
00447         void addterm1_checked(EnergyT *e,VarID i,EnergyTermType e0,EnergyTermType e1,EnergyTermType w);
00448         void addterm2_checked(EnergyT *e,VarID i,VarID j,EnergyTermType e00,EnergyTermType e01,EnergyTermType e10,EnergyTermType e11,EnergyTermType w);
00449 
00450         
00451         template <typename SmoothCostT> EnergyType giveSmoothEnergyInternal();
00452         template <typename Functor> static void deleteFunctor(void* f) { delete reinterpret_cast<Functor*>(f); }
00453 
00454         static void handleError(const char *message);
00455         static void checkInterrupt();
00456 
00457 private:
00458         
00459         EnergyType oneExpansionIteration();
00460         EnergyType oneSwapIteration();
00461         void printStatus1(const char* extraMsg=0);
00462         void printStatus1(int cycle, bool isSwap, gcoclock_t ticks0);
00463         void printStatus2(int alpha, int beta, int numVars, gcoclock_t ticks0);
00464         
00465         void permuteLabelTable();
00466 
00467         template <typename DataCostT> bool       solveSpecialCases(EnergyType& energy);
00468         template <typename DataCostT> EnergyType solveGreedy();
00469 
00471         
00472         
00473         
00475         template <typename DataCostT>
00476         class GreedyIter {
00477         public:
00478                 GreedyIter(DataCostT& dc, SiteID numSites)
00479                 : m_dc(dc), m_site(0), m_numSites(numSites), m_label(0), m_lbegin(0), m_lend(0)
00480                 { }
00481 
00482                 OLGA_INLINE void start(const LabelID* labels, LabelID labelCount=1)
00483                 {
00484                         m_site = labelCount ? 0 : m_numSites;
00485                         m_label = m_lbegin = labels;
00486                         m_lend = labels + labelCount;
00487                 }
00488                 OLGA_INLINE SiteID site()  const { return m_site; }
00489                 OLGA_INLINE SiteID label() const { return *m_label; }
00490                 OLGA_INLINE bool   done()  const { return m_site == m_numSites; }
00491                 OLGA_INLINE GreedyIter& operator++() 
00492                 {
00493                         
00494                         
00495                         
00496                         if (++m_label >= m_lend) {
00497                                 m_label = m_lbegin;
00498                                 ++m_site;
00499                         }
00500                         return *this;
00501                 }
00502                 OLGA_INLINE EnergyTermType compute() const { return m_dc.compute(m_site,*m_label); }
00503                 OLGA_INLINE SiteID feasibleSites() const { return m_numSites; }
00504 
00505         private:
00506                 SiteID m_site;
00507                 DataCostT& m_dc;
00508                 const SiteID m_numSites;
00509                 const LabelID* m_label;
00510                 const LabelID* m_lbegin;
00511                 const LabelID* m_lend;
00512         };
00513 };
00514 
00515 
00517 
00519 
00520 class GCoptimizationGridGraph: public GCoptimization
00521 {
00522 public:
00523         GCoptimizationGridGraph(SiteID width,SiteID height,LabelID num_labels);
00524         virtual ~GCoptimizationGridGraph();
00525 
00526         void setSmoothCostVH(EnergyTermType *smoothArray, EnergyTermType *vCosts, EnergyTermType *hCosts);
00527 
00528 protected:
00529         virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights);
00530         virtual void finalizeNeighbors();
00531         EnergyTermType m_unityWeights[4];
00532         int m_weightedGraph;  
00533 
00534 private:
00535         SiteID m_width;
00536         SiteID m_height;
00537         SiteID *m_neighbors;                 
00538         EnergyTermType *m_neighborsWeights;    
00539         
00540         void setupNeighbData(SiteID startY,SiteID endY,SiteID startX,SiteID endX,SiteID maxInd,SiteID *indexes);
00541         void computeNeighborWeights(EnergyTermType *vCosts,EnergyTermType *hCosts);
00542 };
00543 
00545 
00546 class GCoptimizationGeneralGraph:public GCoptimization
00547 {
00548 public:
00549         
00550         
00551         GCoptimizationGeneralGraph(SiteID num_sites,LabelID num_labels);
00552         virtual ~GCoptimizationGeneralGraph();
00553 
00554         
00555         
00556         
00557         
00558         
00559         void setNeighbors(SiteID site1, SiteID site2, EnergyTermType weight=1);
00560 
00561         
00562         
00563         
00564         
00565         
00566         void setAllNeighbors(SiteID *numNeighbors,SiteID **neighborsIndexes,EnergyTermType **neighborsWeights);
00567 
00568 protected: 
00569         virtual void giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights);
00570         virtual void finalizeNeighbors();
00571 
00572 private:
00573 
00574         typedef struct NeighborStruct{
00575                 SiteID  to_node;
00576                 EnergyTermType weight;
00577         } Neighbor;
00578 
00579         LinkedBlockList *m_neighbors;
00580         bool m_needToFinishSettingNeighbors;
00581         SiteID **m_neighborsIndexes;
00582         EnergyTermType **m_neighborsWeights;
00583         bool m_needTodeleteNeighbors;
00584 };
00585 
00586 
00588 
00590 
00591 
00592 OLGA_INLINE GCoptimization::SiteID GCoptimization::numSites() const
00593 {
00594         return m_num_sites;
00595 }
00596 
00597 OLGA_INLINE GCoptimization::LabelID GCoptimization::numLabels() const
00598 {
00599         return m_num_labels;
00600 }
00601 
00602 OLGA_INLINE void GCoptimization::setLabel(SiteID site, LabelID label)
00603 {
00604         assert(label >= 0 && label < m_num_labels && site >= 0 && site < m_num_sites);
00605         m_labeling[site] = label;
00606         m_labelingInfoDirty = true;
00607 }
00608 
00609 OLGA_INLINE GCoptimization::LabelID GCoptimization::whatLabel(SiteID site)
00610 {
00611         assert(site >= 0 && site < m_num_sites);
00612         return m_labeling[site];
00613 }
00614 
00615 #endif