00001 #ifdef MATLAB_MEX_FILE
00002 #include <mex.h>
00003 #endif
00004 #include "GCoptimization.h"
00005 #include "LinkedBlockList.h"
00006 #include <stdio.h>
00007 #include <stdlib.h>
00008 #include <vector>
00009 #include <algorithm>
00011 // will leave this one just for the laughs :)
00012 //#define olga_assert(expr) assert(!(expr))
00014 // Choose reasonably high-precision timer (sub-millisec resolution if possible).
00015 #ifdef _WIN32
00016 #define WIN32_LEAN_AND_MEAN
00017 #define VC_EXTRALEAN
00018 #define NOMINMAX
00019 #include <windows.h>
00020 extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = 0;
00021 extern "C" inline gcoclock_t gcoclock() // TODO: not thread safe; separate begin/end so that end doesn't have to check for query frequency
00022 {
00023         gcoclock_t result = 0;
00024         if (GCO_CLOCKS_PER_SEC == 0)
00025                 QueryPerformanceFrequency((LARGE_INTEGER*)&GCO_CLOCKS_PER_SEC);
00026         QueryPerformanceCounter((LARGE_INTEGER*)&result);
00027         return result;
00028 }
00030 #else
00031 extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = CLOCKS_PER_SEC;
00032 extern "C" gcoclock_t gcoclock() { return clock(); }
00033 #endif
#ifdef MATLAB_MEX_FILE
extern "C" bool utIsInterruptPending();

// Gives MATLAB a chance to refresh its window while the optimizer runs.
// Calls are throttled: "drawnow" is evaluated only when more than 1/5 second
// (in gcoclock ticks) has passed since the previous flush, because flushing
// too frequently would slow the overall optimization down.
static void flushnow()
{
	static gcoclock_t lastFlush = 0;
	const gcoclock_t current = gcoclock();
	if ( current - lastFlush <= GCO_CLOCKS_PER_SEC/5 )
		return;
	lastFlush = current;
	mexEvalString("drawnow;");
}
#define INDEX0 1  // print 1-based label and site indices for MATLAB
#else
// Outside MATLAB there is no interrupt source and nothing to flush.
static inline bool utIsInterruptPending() { return false; }
static void flushnow() {}
#define INDEX0 0  // print 0-based label and site indices
#endif
// Singly-linked list helpers; usable with any node type that has a 'next' pointer member.

// Deletes every node of the list that starts at 'head' and leaves 'head' null.
template <typename T>
void slist_clear(T*& head)
{
	T* node = head;
	while ( node ) {
		T* following = node->next;
		delete node;
		node = following;
	}
	head = 0;
}

// Pushes 'val' onto the front of the list: val->next becomes the old head and
// 'head' is updated to point at 'val'.
template <typename T>
void slist_prepend(T*& head, T* val)
{
	T* const oldHead = head;
	val->next = oldHead;
	head = val;
}
00073 void GCException::Report() 
00074 {
00075         printf("\n%s\n",message);
00076         exit(0);
00077 }
00082 //   First we have functions for the base class
00084 // Constructor for base class                                                       
00085 GCoptimization::GCoptimization(SiteID nSites, LabelID nLabels) 
00086 : m_num_labels(nLabels)
00087 , m_num_sites(nSites)
00088 , m_datacostIndividual(0)
00089 , m_smoothcostIndividual(0)
00090 , m_labelcostsAll(0)
00091 , m_labelcostsByLabel(0)
00092 , m_labelcostCount(0)
00093 , m_smoothcostFn(0)
00094 , m_datacostFn(0)
00095 , m_numNeighborsTotal(0)
00096 , m_queryActiveSitesExpansion(&GCoptimization::queryActiveSitesExpansion<DataCostFnFromArray>)
00097 , m_setupDataCostsSwap(0)
00098 , m_setupDataCostsExpansion(0)
00099 , m_setupSmoothCostsSwap(0)
00100 , m_setupSmoothCostsExpansion(0)
00101 , m_applyNewLabeling(0)
00102 , m_updateLabelingDataCosts(0)
00103 , m_giveSmoothEnergyInternal(0)
00104 , m_solveSpecialCases(&GCoptimization::solveSpecialCases<DataCostFnFromArray>)
00105 , m_datacostFnDelete(0)
00106 , m_smoothcostFnDelete(0)
00107 , m_random_label_order(false)
00108 , m_verbosity(0)
00109 , m_labelingInfoDirty(true)
00110 , m_lookupSiteVar(new SiteID[nSites])
00111 , m_labeling(new LabelID[nSites])
00112 , m_labelTable(new LabelID[nLabels])
00113 , m_labelingDataCosts(new EnergyTermType[nSites])
00114 , m_labelCounts(new SiteID[nLabels])
00115 , m_activeLabelCounts(new SiteID[m_num_labels])
00116 , m_stepsThisCycle(0)
00117 , m_stepsThisCycleTotal(0)
00118 {
00119         if ( nLabels <= 1 ) handleError("Number of labels must be >= 2");
00120         if ( nSites <= 0 )  handleError("Number of sites must be >= 1");
00122         if ( !m_lookupSiteVar || !m_labelTable || !m_labeling ){
00123                 if (m_lookupSiteVar) delete [] m_lookupSiteVar;
00124                 if (m_labelTable) delete [] m_labelTable;
00125                 if (m_labeling) delete [] m_labeling;
00126                 if (m_labelingDataCosts) delete [] m_labelingDataCosts;
00127                 if (m_labelCounts) delete [] m_labelCounts;
00128                 handleError("Not enough memory.");
00129         }
00131         memset(m_labeling, 0, m_num_sites*sizeof(LabelID));
00132         memset(m_lookupSiteVar,-1,m_num_sites*sizeof(SiteID));
00133         setLabelOrder(false);
00134         specializeSmoothCostFunctor(SmoothCostFnPotts());
00135 }
00137 //-------------------------------------------------------------------
00139 GCoptimization::~GCoptimization()
00140 {
00141         delete [] m_labelTable;
00142         delete [] m_lookupSiteVar;
00143         delete [] m_labeling;
00144         delete [] m_labelingDataCosts;
00145         delete [] m_labelCounts;
00146         delete [] m_activeLabelCounts;
00148         if (m_datacostFnDelete) m_datacostFnDelete(m_datacostFn);
00149         if (m_smoothcostFnDelete) m_smoothcostFnDelete(m_smoothcostFn);
00151         if (m_datacostIndividual) delete [] m_datacostIndividual;
00152         if (m_smoothcostIndividual) delete [] m_smoothcostIndividual;
00154         // Delete label cost bookkeeping structures
00155         //
00156         slist_clear(m_labelcostsAll);
00157         if (m_labelcostsByLabel) {
00158                 for ( LabelID i = 0; i < m_num_labels; ++i )
00159                         slist_clear(m_labelcostsByLabel[i]);
00160                 delete [] m_labelcostsByLabel;
00161         }
00162 }
00164 //-------------------------------------------------------------------
00166 template <>
00167 GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion<GCoptimization::DataCostFnSparse>(LabelID alpha_label,SiteID *activeSites)
00168 {
00169         return ((DataCostFnSparse*)m_datacostFn)->queryActiveSitesExpansion(alpha_label,m_labeling,activeSites);
00170 }
00172 //-------------------------------------------------------------------
00174 template <>
00175 void GCoptimization::setupDataCostsExpansion<GCoptimization::DataCostFnSparse>(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00176 {
00177         DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00178         DataCostFnSparse::iterator dciter = dc->begin(alpha_label);
00179         for ( SiteID i = 0; i < size; ++i )
00180         {
00181                 SiteID site = activeSites[i];
00182                 while ( dciter.site() != site )
00183                         ++dciter;
00184                 addterm1_checked(e,i,dciter.cost(),m_labelingDataCosts[site]);
00185         }
00186 }
00188 //-------------------------------------------------------------------
00190 template <>
00191 void GCoptimization::applyNewLabeling<GCoptimization::DataCostFnSparse>(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label)
00192 {
00193         DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00194         DataCostFnSparse::iterator dciter = dc->begin(alpha_label);
00195         for ( SiteID i = 0; i < size; i++ )
00196         {
00197                 if ( e->get_var(i) == 0 )
00198                 {
00199                         SiteID site = activeSites[i];
00200                         LabelID prev = m_labeling[site];
00201                         m_labeling[site] = alpha_label;
00202                         m_labelCounts[alpha_label]++;
00203                         m_labelCounts[prev]--;
00204                         while ( dciter.site() != site )
00205                                 ++dciter;
00206                         m_labelingDataCosts[site] = dciter.cost();
00207                 }
00208         }
00209         m_labelingInfoDirty = true;
00210         updateLabelingInfo(false,true,false); // labels have changed, so update necessary labeling info
00211 }
00213 //-------------------------------------------------------------------
00215 template <typename UserFunctor>
00216 void GCoptimization::specializeDataCostFunctor(const UserFunctor f) {
00217         if ( m_datacostFnDelete )
00218                 m_datacostFnDelete(m_datacostFn);
00219         if ( m_datacostIndividual )
00220         {
00221                 delete [] m_datacostIndividual;
00222                 m_datacostIndividual = 0;
00223         }
00224         m_datacostFn = new UserFunctor(f);
00225         m_datacostFnDelete          = &GCoptimization::deleteFunctor<UserFunctor>;
00226         m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion<UserFunctor>;
00227         m_setupDataCostsExpansion   = &GCoptimization::setupDataCostsExpansion<UserFunctor>;
00228         m_setupDataCostsSwap        = &GCoptimization::setupDataCostsSwap<UserFunctor>;
00229         m_applyNewLabeling          = &GCoptimization::applyNewLabeling<UserFunctor>;
00230         m_updateLabelingDataCosts   = &GCoptimization::updateLabelingDataCosts<UserFunctor>;
00231         m_solveSpecialCases         = &GCoptimization::solveSpecialCases<UserFunctor>;
00232 }
00234 template <typename UserFunctor>
00235 void GCoptimization::specializeSmoothCostFunctor(const UserFunctor f) {
00236         if ( m_smoothcostFnDelete )
00237                 m_smoothcostFnDelete(m_smoothcostFn);
00238         if ( m_smoothcostIndividual )
00239         {
00240                 delete [] m_smoothcostIndividual;
00241                 m_smoothcostIndividual = 0;
00242         }
00243         m_smoothcostFn = new UserFunctor(f);
00244         m_smoothcostFnDelete        = &GCoptimization::deleteFunctor<UserFunctor>;
00245         m_giveSmoothEnergyInternal  = &GCoptimization::giveSmoothEnergyInternal<UserFunctor>;
00246         m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion<UserFunctor>;
00247         m_setupSmoothCostsSwap      = &GCoptimization::setupSmoothCostsSwap<UserFunctor>;
00248 }
00250 //-------------------------------------------------------------------
00252 template <typename SmoothCostT>
00253 GCoptimization::EnergyType GCoptimization::giveSmoothEnergyInternal()
00254 {
00255         EnergyType eng = (EnergyType) 0;
00256         SiteID i,numN,*nPointer,nSite,n;
00257         EnergyTermType *weights;
00258         SmoothCostT* sc = (SmoothCostT*) m_smoothcostFn;
00259         for ( i = 0; i < m_num_sites; i++ )
00260         {
00261                 giveNeighborInfo(i,&numN,&nPointer,&weights);
00262                 for ( n = 0; n < numN; n++ )
00263                 {
00264                         nSite = nPointer[n];
00265                         if ( nSite < i ) 
00266                                 eng += weights[n]*(sc->compute(i,nSite,m_labeling[i],m_labeling[nSite]));
00267                 }
00268         }
00270         return eng;
00271 }
00273 //-------------------------------------------------------------------
00275 OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1)
00276 {
00277         if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM )
00278                 handleError("Data cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00279         m_beforeExpansionEnergy += e1;
00280         e->add_term1(i,e0,e1);
00281 }
00283 OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1, EnergyTermType w)
00284 {
00285         if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM )
00286                 handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00287         if ( w > GCO_MAX_ENERGYTERM )
00288                 handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00289         m_beforeExpansionEnergy += e1*w;
00290         e->add_term1(i,e0*w,e1*w);
00291 }
00293 OLGA_INLINE void GCoptimization::addterm2_checked(EnergyT* e, VarID i, VarID j, EnergyTermType e00, EnergyTermType e01, EnergyTermType e10, EnergyTermType e11, EnergyTermType w)
00294 {
00295         if ( e00 > GCO_MAX_ENERGYTERM || e11 > GCO_MAX_ENERGYTERM || e01 > GCO_MAX_ENERGYTERM || e10 > GCO_MAX_ENERGYTERM )
00296                 handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00297         if ( w > GCO_MAX_ENERGYTERM )
00298                 handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00299         // Inside energy/maxflow code the submodularity check is performed as an assertion,
00300         // but is optimized out. We check it in release builds as well.
00301         if ( e00+e11 > e01+e10 )
00302                 handleError("Non-submodular expansion term detected; smooth costs must be a metric for expansion");
00303         m_beforeExpansionEnergy += e11*w;
00304         e->add_term2(i,j,e00*w,e01*w,e10*w,e11*w);
00305 }
00307 //------------------------------------------------------------------
00309 template <typename DataCostT>
00310 GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion(LabelID alpha_label,SiteID *activeSites)
00311 {
00312         SiteID size = 0;
00313         for ( SiteID i = 0; i < m_num_sites; i++ )
00314                 if ( m_labeling[i] != alpha_label )
00315                         activeSites[size++] = i;
00316         return size;
00317 }
00319 //-------------------------------------------------------------------
00321 template <typename DataCostT>
00322 void GCoptimization::setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00323 {
00324         DataCostT* dc = (DataCostT*)m_datacostFn;
00325         for ( SiteID i = 0; i < size; ++i )
00326                 addterm1_checked(e,i,dc->compute(activeSites[i],alpha_label),m_labelingDataCosts[activeSites[i]]);
00327 }
00329 //-------------------------------------------------------------------
00331 template <typename SmoothCostT>
00332 void GCoptimization::setupSmoothCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00333 {
00334         SiteID i,nSite,site,n,nNum,*nPointer;
00335         EnergyTermType *weights;
00336         SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn;
00338         for ( i = size - 1; i >= 0; i-- )
00339         {
00340                 site = activeSites[i];
00341                 giveNeighborInfo(site,&nNum,&nPointer,&weights);
00342                 for ( n = 0; n < nNum; n++ )
00343                 {
00344                         nSite = nPointer[n];
00345                         if ( m_lookupSiteVar[nSite] == -1 ) 
00346                                 addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00347                                                      sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]);
00348                         else if ( nSite < site ) 
00349                         {
00350                                 addterm2_checked(e,i,m_lookupSiteVar[nSite],
00351                                                  sc->compute(site,nSite,alpha_label,alpha_label),
00352                                                  sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00353                                                  sc->compute(site,nSite,m_labeling[site],alpha_label),
00354                                                  sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]);
00355                         }
00356                 }
00357         }
00358 }
00360 //-----------------------------------------------------------------------------------
00362 template <typename DataCostT>
00363 void GCoptimization::setupDataCostsSwap(SiteID size, LabelID alpha_label, LabelID beta_label,
00364                                                                                  EnergyT *e,SiteID *activeSites )
00365 {
00366         DataCostT* dc = (DataCostT*)m_datacostFn;
00367         for ( SiteID i = 0; i < size; i++ )
00368         {
00369                 e->add_term1(i,dc->compute(activeSites[i],alpha_label),
00370                                dc->compute(activeSites[i],beta_label) );
00371         }
00372 }
00374 //-------------------------------------------------------------------
00376 template <typename SmoothCostT>
00377 void GCoptimization::setupSmoothCostsSwap(SiteID size, LabelID alpha_label,LabelID beta_label,
00378                                                                                  EnergyT *e,SiteID *activeSites )
00379 {
00380         SiteID i,nSite,site,n,nNum,*nPointer;
00381         EnergyTermType *weights;
00382         SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn;
00384         for ( i = size - 1; i >= 0; i-- )
00385         {
00386                 site = activeSites[i];
00387                 giveNeighborInfo(site,&nNum,&nPointer,&weights);
00388                 for ( n = 0; n < nNum; n++ )
00389                 {
00390                         nSite = nPointer[n];
00391                         if ( m_lookupSiteVar[nSite] == -1 )
00392                                 addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00393                                                      sc->compute(site,nSite,beta_label, m_labeling[nSite]),weights[n]);
00394                         else if ( nSite < site )
00395                         {
00396                                 addterm2_checked(e,i,m_lookupSiteVar[nSite],
00397                                                  sc->compute(site,nSite,alpha_label,alpha_label),
00398                                                  sc->compute(site,nSite,alpha_label,beta_label),
00399                                                  sc->compute(site,nSite,beta_label,alpha_label),
00400                                                  sc->compute(site,nSite,beta_label,beta_label),weights[n]);
00401                         }
00402                 }
00403         }
00404 }
00406 //-----------------------------------------------------------------------------------
00408 template <typename DataCostT>
00409 void GCoptimization::applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label)
00410 {
00411         DataCostT* dc = (DataCostT*)m_datacostFn;
00412         for ( SiteID i = 0; i < size; i++ )
00413         {
00414                 if ( e->get_var(i) == 0 )
00415                 {
00416                         SiteID site = activeSites[i];
00417                         LabelID prev = m_labeling[site];
00418                         m_labeling[site] = alpha_label;
00419                         m_labelCounts[alpha_label]++;
00420                         m_labelCounts[prev]--;
00421                         m_labelingDataCosts[site] = dc->compute(site,alpha_label);
00422                 }
00423         }
00424         m_labelingInfoDirty = true;
00425         updateLabelingInfo(false,true,false); // labels have changed, so update necessary labeling info
00426 }
00428 //-----------------------------------------------------------------------------------
00430 template <typename DataCostT>
00431 void GCoptimization::updateLabelingDataCosts()
00432 {
00433         DataCostT* dc = (DataCostT*)m_datacostFn;
00434         for (int i = 0; i < m_num_sites; ++i)
00435                 m_labelingDataCosts[i] = dc->compute(i,m_labeling[i]);
00436 }
00438 //-----------------------------------------------------------------------------------
00440 template <typename DataCostT>
00441 bool GCoptimization::solveSpecialCases(EnergyType& energy)
00442 {
00443         finalizeNeighbors();
00445         DataCostT* dc = (DataCostT*)m_datacostFn;
00446         bool sc = m_numNeighborsTotal != 0;
00447         bool lc = m_labelcostsAll != 0;
00449         if ( !dc && !sc && !lc )
00450         {
00451                 energy = 0;
00452                 return true;
00453         }
00455         if ( dc && !sc && !lc ) {
00456                 // Special case: No label costs, so return trivial solution
00457                 energy = 0;
00458                 for ( SiteID i = 0; i < m_num_sites; ++i ) {
00459                         LabelID minCostLabel = 0;
00460                         EnergyTermType minCost = dc->compute(i, 0);
00461                         for ( LabelID l = 1; l < m_num_labels; ++l ) {
00462                                 EnergyTermType lcost = dc->compute(i, l);
00463                                 if ( lcost < minCost ) {
00464                                         minCostLabel = l;
00465                                         minCost = lcost;
00466                                 }
00467                         }
00468                         if ( minCostLabel > GCO_MAX_ENERGYTERM )
00469                                 handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00470                         m_labeling[i] = minCostLabel;
00471                         energy += minCost;
00472                 }
00473                 m_labelingInfoDirty = true;
00474                 updateLabelingInfo();
00475                 return true;
00476         }
00478         if ( !dc && !sc && lc ) {
00479                 // Special case: No data costs, so return trivial solution
00480                 LabelID minLabel = 0;
00481                 EnergyType minLabelCost = GCO_MAX_ENERGYTERM*(EnergyType)m_num_labels;
00482                 for ( LabelID l = 0; l < m_num_labels; ++l ) {
00483                         EnergyType lcsum = 0;
00484                         for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00485                                 lcsum += lci->node->cost;
00486                         if ( lcsum < minLabelCost ) {
00487                                 minLabel = l;
00488                                 minLabelCost = lcsum;
00489                         }
00490                 }
00491                 for ( SiteID i = 0; i < m_num_sites; ++i )
00492                         m_labeling[i] = minLabel;
00493                 energy = minLabelCost;
00494                 m_labelingInfoDirty = true;
00495                 updateLabelingInfo();
00496                 return true;
00497         }
00499         if ( dc && !sc && lc ) {
00500                 LabelCost* lc;
00501                 for ( lc = m_labelcostsAll; lc; lc = lc->next )
00502                         if ( lc->numLabels > 1)
00503                                 break;
00504                 if ( !lc ) {
00505                         // Special case: Data costs and per-label costs 
00506                         energy = solveGreedy<DataCostT>();
00507                         return true;
00508                 }
00509         }
00511         // Otherwise, use full-blown expansion/swap
00512         return false;
00513 }
00515 template <>
00516 class GCoptimization::GreedyIter<GCoptimization::DataCostFnSparse> {
00517 public:
00518         GreedyIter(DataCostFnSparse& dc, SiteID)
00519         : m_dc(dc), m_label(0), m_labelend(0)
00520         { }
00522         OLGA_INLINE void start(const LabelID* labels, LabelID labelCount=1)
00523         {
00524                 m_label = labels;
00525                 m_labelend = labels + labelCount;
00526                 if (labelCount > 0) {
00527                         m_site = m_dc.begin(*labels);
00528                         m_siteend = m_dc.end(*labels);
00529                         while (m_site == m_siteend) {
00530                                 if (++m_label == m_labelend)
00531                                         break;
00532                                 m_site     = m_dc.begin(*m_label);
00533                                 m_siteend  = m_dc.end(*m_label);
00534                         }
00535                 }
00536         }
00537         OLGA_INLINE SiteID site()  const { return m_site.site(); }
00538         OLGA_INLINE SiteID label() const { return *m_label; }
00539         OLGA_INLINE bool   done()  const { return m_label >= m_labelend; }
00540         OLGA_INLINE GreedyIter& operator++() 
00541         {
00542                 // The inner loop is over sites, not labels, because sparse data costs 
00543                 // are stored as consecutive [sparse] SiteIDs with respect to each label.
00544                 if (++m_site == m_siteend) {
00545                         while (++m_label < m_labelend) {
00546                                 m_site     = m_dc.begin(*m_label);
00547                                 m_siteend  = m_dc.end(*m_label);
00548                                 if (m_site != m_siteend)
00549                                         break;
00550                         }
00551                 }
00552                 return *this;
00553         }
00554         OLGA_INLINE EnergyTermType compute() const { return m_site.cost(); }
00555         OLGA_INLINE SiteID feasibleSites() const { return (SiteID)(m_siteend - m_site); }
00557 private:
00558         DataCostFnSparse::iterator m_site;
00559         DataCostFnSparse::iterator m_siteend;
00560         DataCostFnSparse& m_dc;
00561         const LabelID* m_label;
00562         const LabelID* m_labelend;
00563 };
00565 template <typename DataCostT>
00566 GCoptimization::EnergyType GCoptimization::solveGreedy()
00567 {
00568         printStatus1("starting greedy algorithm (1 cycle only)");
00569         m_stepsThisCycle = m_stepsThisCycleTotal = 0;
00571         EnergyType estart = compute_energy();
00572         EnergyType efinal = 0;
00573         LabelID* oldLabeling = m_labeling;
00574         m_labeling = new LabelID[m_num_sites];
00575         EnergyType* e = new EnergyType[m_num_labels];
00576         LabelID* order = new LabelID[m_num_labels];  // order[0..activeCount-1] contains the activated labels so far
00578         try {
00579                 gcoclock_t ticks0all = gcoclock();
00580                 gcoclock_t ticks0 = gcoclock();
00582                 // clear active flags
00583                 for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next)
00584                         lc->active = false;
00586                 DataCostT* dc = (DataCostT*)m_datacostFn;
00587                 GreedyIter<DataCostT> iter(*dc,m_num_sites);
00588                 LabelID alpha = 0;
00590                 // Treat first iteration as special case. 
00591                 // Ignore current labeling and just find the greedy initial label.
00592                 for ( LabelID l = 0; l < m_num_labels; ++l ) {
00593                         e[l] = 0;
00594                         for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00595                                 e[l] += lci->node->cost;
00596                         iter.start(&l);
00597                         e[l] += (EnergyType)(m_num_sites - iter.feasibleSites()) * GCO_MAX_ENERGYTERM; // pre-add GCO_MAX_ENERGYTERM for all infeasible sites
00598                         for (; !iter.done(); ++iter) {
00599                                 EnergyTermType dataCost = iter.compute();
00600                                 if ( dataCost > GCO_MAX_ENERGYTERM )
00601                                         handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00602                                 e[l] += dataCost;
00603                                 if ( e[l] > e[alpha] ) // break out early if this will definitely 
00604                                         break;             // not be a good label to start from
00605                         }
00606                         if ( e[l] < e[alpha] ) // choose alpha with minimum energy e[alpha]
00607                                 alpha = l;
00608                 }
00609                 for ( SiteID i = 0; i < m_num_sites; ++i ) {
00610                         m_labeling[i] = alpha;
00611                         m_labelingDataCosts[i] = dc->compute(i,alpha);
00612                 }
00613                 for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next )
00614                         lci->node->active = true;
00616                 // List of labels in the order that they were expanded upon (order[0] first, order[1] second, ...)
00617                 for ( LabelID l = 0; l < m_num_labels; ++l )
00618                         order[l] = l;
00619                 order[alpha] = 0;
00620                 order[0] = alpha;
00622                 printStatus2(alpha,-1,m_num_sites,ticks0);
00624                 // Greedily expand remaining labels
00625                 for ( LabelID alpha_count = 1; alpha_count <= m_num_labels; ++alpha_count) {
00626                         checkInterrupt();
00627                         ticks0 = gcoclock();
00629                         // Energy e[l] for expanding on label 'l' starts at e[alpha] + new labelcosts for introducing l
00630                         LabelID alpha_prev = alpha;
00631                         for ( LabelID li = alpha_count; li < m_num_labels; ++li ) {
00632                                 LabelID l = order[li];
00633                                 e[l] = e[alpha_prev];
00634                                 for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00635                                         if ( !lci->node->active )
00636                                                 e[l] += lci->node->cost;
00637                         }
00639                         // Loop over all sites and all remaining labels to calculate energy drop.
00640                         for ( iter.start(&order[alpha_count],m_num_labels-alpha_count); !iter.done(); ++iter ) {
00641                                 EnergyTermType dc_l = iter.compute();
00642                                 EnergyTermType dc_i = m_labelingDataCosts[iter.site()];
00643                                 EnergyTermType delta_i = dc_l - dc_i;
00644                                 if ( delta_i < 0 )
00645                                         e[iter.label()] += delta_i;
00646                         }
00648                         // Choose the next alpha based on lowest resulting energy
00649                         LabelID alpha_index = alpha_count-1;
00650                         for ( LabelID li = alpha_count; li < m_num_labels; ++li ) {
00651                                 LabelID l = order[li];
00652                                 if ( e[l] < e[alpha] ) {
00653                                         alpha = l;
00654                                         alpha_index = li;
00655                                 }
00656                         }
00658                         if ( alpha == alpha_prev )
00659                                 break;
00661                         // Append alpha to the list of activated labels
00662                         LabelID temp = order[alpha_count];
00663                         order[alpha_count] = order[alpha_index];
00664                         order[alpha_index] = temp;
00666                         // Apply the new labeling, updating m_labelingDataCosts and active labelcosts as necessary
00667                         iter.start(&alpha);
00668                         SiteID size = iter.feasibleSites();
00669                         for ( ; !iter.done(); ++iter ) {
00670                                 EnergyTermType dc_l = iter.compute();
00671                                 EnergyTermType dc_i = m_labelingDataCosts[iter.site()];
00672                                 EnergyTermType delta_i = dc_l - dc_i;
00673                                 if ( delta_i < 0 ) {
00674                                         m_labeling[iter.site()] = alpha;
00675                                         m_labelingDataCosts[iter.site()] = dc_l;
00676                                 }
00677                         }
00678                         for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next )
00679                                 lci->node->active = true;
00680                         printStatus2(alpha,-1,size,ticks0);
00681                 }
00683                 efinal = e[alpha];
00684                 if ( efinal < estart ) {
00685                         // Greedy succeeded in lowering energy compared to initial labeling
00686                         delete [] oldLabeling;
00687                         m_labelingInfoDirty = true;
00688                         updateLabelingInfo(true,false,false); // update m_labelCounts only; m_labelingDataCosts and active labelcosts should be up to date
00689                         printStatus1(1,false,ticks0all);
00690                 } else {
00691                         // Greedy failed to find a lower energy, so revert everything
00692                         efinal = estart;
00693                         delete [] m_labeling;
00694                         m_labeling = oldLabeling;
00695                         m_labelingInfoDirty = true;
00696                         updateLabelingInfo(); // put all labeling info back the way it was
00697                         printStatus1(1,false,ticks0all);
00698                 }
00700                 delete [] order;
00701                 delete [] e;
00702         } catch (...) {
00703                 delete [] order;
00704                 delete [] e;
00705                 throw;
00706         }
00707         return efinal;
00708 }
00710 //------------------------------------------------------------------
00712 void GCoptimization::setDataCost(DataCostFn fn) { 
00713         specializeDataCostFunctor(DataCostFnFromFunction(fn));
00714         m_labelingInfoDirty = true;
00715 }
00718 //------------------------------------------------------------------
00720 void GCoptimization::setDataCost(DataCostFnExtra fn, void *extraData) { 
00721         specializeDataCostFunctor(DataCostFnFromFunctionExtra(fn, extraData));
00722         m_labelingInfoDirty = true;
00723 }
00725 //-------------------------------------------------------------------
00727 void GCoptimization::setDataCost(EnergyTermType *dataArray) {
00728         specializeDataCostFunctor(DataCostFnFromArray(dataArray, m_num_labels));
00729         m_labelingInfoDirty = true;
00730 }
00732 //-------------------------------------------------------------------
00734 void GCoptimization::setDataCost(SiteID s, LabelID l, EnergyTermType e) {
00735         if ( !m_datacostIndividual )
00736         {
00737                 EnergyTermType* table = new EnergyTermType[m_num_sites*m_num_labels];
00738                 memset(table, 0, m_num_sites*m_num_labels*sizeof(EnergyTermType));
00739                 specializeDataCostFunctor(DataCostFnFromArray(table, m_num_labels));
00740                 m_datacostIndividual = table;
00741                 m_labelingInfoDirty = true;
00742         }
00743         m_datacostIndividual[s*m_num_labels + l] = e;
00744         if ( m_labeling[s] == l )
00745                 m_labelingInfoDirty = true; // m_labelingDataCosts is dirty
00746 }
00748 //-------------------------------------------------------------------
// Installs a caller-supplied DataCostFunctor as the data cost.
// Any previously installed data cost is released first: a functor that was registered
// with a deleter (m_datacostFnDelete), and/or the dense table allocated by
// setDataCost(SiteID,LabelID,EnergyTermType).
void GCoptimization::setDataCostFunctor(DataCostFunctor* f) {
	if ( m_datacostFnDelete )
		m_datacostFnDelete(m_datacostFn);
	if ( m_datacostIndividual )
	{
		delete [] m_datacostIndividual;
		m_datacostIndividual = 0;
	}
	m_datacostFn = f;
	// No deleter is recorded for f, so this object never frees it.
	m_datacostFnDelete          = 0;
	// Route every data-cost hook through the DataCostFunctor instantiations.
	m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion<DataCostFunctor>;
	m_setupDataCostsExpansion   = &GCoptimization::setupDataCostsExpansion<DataCostFunctor>;
	m_setupDataCostsSwap        = &GCoptimization::setupDataCostsSwap<DataCostFunctor>;
	m_applyNewLabeling          = &GCoptimization::applyNewLabeling<DataCostFunctor>;
	m_updateLabelingDataCosts   = &GCoptimization::updateLabelingDataCosts<DataCostFunctor>;
	m_solveSpecialCases         = &GCoptimization::solveSpecialCases<DataCostFunctor>;
	// Cached per-site data costs are no longer valid.
	m_labelingInfoDirty = true;
}
00769 //-------------------------------------------------------------------
00771 void GCoptimization::setDataCost(LabelID l, SparseDataCost *costs, SiteID count)
00772 {
00773         if ( !m_datacostFn )
00774                 specializeDataCostFunctor(DataCostFnSparse(numSites(),numLabels()));
00775         else if ( m_queryActiveSitesExpansion != (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion<DataCostFnSparse> )
00776                 handleError("Cannot apply sparse data costs after dense data costs have been used.");
00777         m_labelingInfoDirty = true;
00778         DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00779         dc->set(l,costs,count);
00780 }
00782 //-------------------------------------------------------------------
00784 void GCoptimization::setSmoothCost(SmoothCostFn fn) {
00785         specializeSmoothCostFunctor(SmoothCostFnFromFunction(fn));
00786 }
00788 //-------------------------------------------------------------------
00790 void GCoptimization::setSmoothCost(SmoothCostFnExtra fn, void* extraData) {
00791         specializeSmoothCostFunctor(SmoothCostFnFromFunctionExtra(fn, extraData));
00792 }
00794 //-------------------------------------------------------------------
00796 void GCoptimization::setSmoothCost(EnergyTermType *smoothArray) {
00797         specializeSmoothCostFunctor(SmoothCostFnFromArray(smoothArray, m_num_labels));
00798 }
00800 //-------------------------------------------------------------------
00802 void GCoptimization::setSmoothCost(LabelID l1, LabelID l2, EnergyTermType e){
00803         if ( !m_smoothcostIndividual )
00804         {
00805                 EnergyTermType* table = new EnergyTermType[m_num_labels*m_num_labels];
00806                 memset(table, 0, m_num_labels*m_num_labels*sizeof(EnergyTermType));
00807                 specializeSmoothCostFunctor(SmoothCostFnFromArray(table, m_num_labels));
00808                 m_smoothcostIndividual = table;
00809         } 
00810         m_smoothcostIndividual[l1*m_num_labels + l2] = e;
00811 }
00813 //-------------------------------------------------------------------
// Installs a caller-supplied SmoothCostFunctor as the smooth cost.
// Any previously installed smooth cost is released first: a functor that was registered
// with a deleter (m_smoothcostFnDelete), and/or the dense table allocated by
// setSmoothCost(LabelID,LabelID,EnergyTermType).
void GCoptimization::setSmoothCostFunctor(SmoothCostFunctor* f) {
	if ( m_smoothcostFnDelete )
		m_smoothcostFnDelete(m_smoothcostFn);
	if ( m_smoothcostIndividual )
	{
		delete [] m_smoothcostIndividual;
		m_smoothcostIndividual = 0;
	}
	m_smoothcostFn = f;
	// No deleter is recorded for f, so this object never frees it.
	m_smoothcostFnDelete        = 0;
	// Route every smooth-cost hook through the SmoothCostFunctor instantiations.
	m_giveSmoothEnergyInternal  = &GCoptimization::giveSmoothEnergyInternal<SmoothCostFunctor>;
	m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion<SmoothCostFunctor>;
	m_setupSmoothCostsSwap      = &GCoptimization::setupSmoothCostsSwap<SmoothCostFunctor>;
}
00830 //-------------------------------------------------------------------
00832 void GCoptimization::setLabelCost(EnergyTermType cost) 
00833 {
00834         EnergyTermType* lc = new EnergyTermType[m_num_labels];
00835         for ( LabelID i = 0; i < m_num_labels; ++i )
00836                 lc[i] = cost;
00837         setLabelCost(lc);
00838         delete [] lc;
00839 }
00841 //-------------------------------------------------------------------
00843 void GCoptimization::setLabelCost(EnergyTermType *costArray) 
00844 {
00845         for ( LabelID i = 0; i < m_num_labels; ++i )
00846                 setLabelSubsetCost(&i, 1, costArray[i]);
00847 }
00849 //-------------------------------------------------------------------
00851 void GCoptimization::setLabelSubsetCost(LabelID* labels, LabelID numLabels, EnergyTermType cost)
00852 {
00853         if ( cost < 0 )
00854                 handleError("Label costs must be non-negative.");
00855         if ( cost > GCO_MAX_ENERGYTERM )
00856                 handleError("Label cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00857         for ( LabelID i = 0; i < numLabels; ++i)
00858                 if ( labels[i] < 0 || labels[i] >= m_num_labels )
00859                         handleError("Invalid label id was found in label subset list.");
00861         if ( !m_labelcostsByLabel ) {
00862                 m_labelcostsByLabel = new LabelCostIter*[m_num_labels];
00863                 memset(m_labelcostsByLabel, 0, m_num_labels*sizeof(void*));
00864         }
00866         // If this particular subset already has a cost, simply replace it.
00867         for ( LabelCostIter* lci = m_labelcostsByLabel[labels[0]]; lci; lci = lci->next ) {
00868                 if ( numLabels == lci->node->numLabels ) {
00869                         if ( !memcmp(labels, lci->node->labels, numLabels*sizeof(LabelID)) ) {
00870                                 // This label subset already exists, so just update the cost and return
00871                                 lci->node->cost = cost;
00872                                 return;
00873                         }
00874                 }
00875         }
00877         if (cost == 0)
00878                 return;
00880         // Create a new LabelCost entry and add it to the appropriate lists
00881         m_labelcostCount++;
00882         LabelCost* lc = new LabelCost;
00883         lc->cost = cost; 
00884         lc->active = false;
00885         lc->aux = -1;
00886         lc->numLabels = numLabels;
00887         lc->labels = new LabelID[numLabels];
00888         memcpy(lc->labels, labels, numLabels*sizeof(LabelID));
00889         slist_prepend(m_labelcostsAll, lc);
00890         for ( LabelID i = 0; i < numLabels; ++i ) {
00891                 LabelCostIter* lci = new LabelCostIter;
00892                 lci->node = lc; 
00893                 slist_prepend(m_labelcostsByLabel[labels[i]], lci);
00894         }
00895 }
00897 //-------------------------------------------------------------------
00899 void GCoptimization::whatLabel(SiteID start, SiteID count, LabelID* labeling)
00900 {
00901         assert(start >= 0 && start+count <= m_num_sites);
00902         memcpy(labeling, m_labeling+start, count*sizeof(LabelID));
00903 }
00905 //-------------------------------------------------------------------
00907 GCoptimization::EnergyType GCoptimization::giveSmoothEnergy()
00908 {
00909         finalizeNeighbors();
00910         if ( m_giveSmoothEnergyInternal ) 
00911                 return( (this->*m_giveSmoothEnergyInternal)());
00912         return 0;
00913 }
00915 //-------------------------------------------------------------------
00917 GCoptimization::EnergyType GCoptimization::giveDataEnergy()
00918 {
00919         updateLabelingInfo();
00920         EnergyType energy = 0;
00921         for ( SiteID i = 0; i < m_num_sites; i++ )
00922                 energy += m_labelingDataCosts[i];
00923         return energy;
00924 }
00926 GCoptimization::EnergyType GCoptimization::giveLabelEnergy()
00927 {
00928         updateLabelingInfo();
00929         EnergyType energy = 0;
00930         for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next)
00931                 if ( lc->active )
00932                         energy += lc->cost;
00933         return energy;
00934 }
00936 //-------------------------------------------------------------------
00938 GCoptimization::EnergyType GCoptimization::compute_energy()
00939 {
00940         return giveDataEnergy() + giveSmoothEnergy() + giveLabelEnergy();
00941 }
00943 //-------------------------------------------------------------------
00945 void GCoptimization::permuteLabelTable()
00946 {
00947         if ( !m_random_label_order )
00948                 return;
00949         for ( LabelID i = 0; i < m_num_labels; i++ )
00950         {
00951                 LabelID j = i + (rand() % (m_num_labels-i));
00952                 LabelID temp    = m_labelTable[i];
00953                 m_labelTable[i] = m_labelTable[j];
00954                 m_labelTable[j] = temp;
00955         }
00956 }
00958 //-------------------------------------------------------------------
// Runs alpha-expansion and returns the resulting total energy.
//
// max_num_iterations == -1 selects the adaptive schedule. Labels whose expansion fails are
// moved to the end of the current queue, and the queue may shrink to focus on the labels
// that succeeded (see the queueSizes stack below).
// Any other value runs up to max_num_iterations standard cycles (one expansion attempt per
// label), stopping early when a cycle leaves the energy unchanged.
//
// If m_solveSpecialCases handles the problem directly, its energy is returned without
// running any cycles. m_stepsThisCycle/m_stepsThisCycleTotal are reset to 0 on exit,
// whether the function returns normally or by exception.
GCoptimization::EnergyType GCoptimization::expansion(int max_num_iterations)
{
	EnergyType new_energy, old_energy;
	if ( (this->*m_solveSpecialCases)(new_energy) )
		return new_energy;

	permuteLabelTable();
	updateLabelingInfo();

	try 
	{
		if ( max_num_iterations == -1 )
		{
			// Strategic expansion loop focuses on labels that successfully reduced the energy.
			// queueSizes is a stack of queue lengths; the current queue is the prefix
			// m_labelTable[0..queueSizes.back()-1].
			printStatus1("starting alpha-expansion w/ adaptive cycles");
			std::vector<LabelID> queueSizes;
			queueSizes.push_back(m_num_labels);

			int cycle = 1;
			LabelID next = 0;
			do
			{
				gcoclock_t ticks0 = gcoclock();
				m_stepsThisCycle = 0; 

				// Make a pass over the unchecked labels in the current queue, i.e. m_labelTable[next..queueSize-1]
				LabelID queueSize = queueSizes.back();
				LabelID start = next;
				m_stepsThisCycleTotal = queueSize - start;
				do 
				{
					// Successful labels accumulate at m_labelTable[start..next-1]; failed ones are
					// swapped behind queueSize. The pass ends when the two meet (next == queueSize).
					if ( !alpha_expansion(m_labelTable[next]) )
						std::swap(m_labelTable[next],m_labelTable[--queueSize]); // don't put this label in a new queue
					else
						++next; // keep this label for the next (smaller) queue
					m_stepsThisCycle++;
				} while ( next < queueSize );

				if ( next == start )  // No expansion was successful, so try more labels from the previous queue
				{
					// Discard the current queue and resume the enclosing queue at the first label
					// that was not part of the discarded one. Popping the outermost queue ends the loop.
					next = queueSizes.back();
					queueSizes.pop_back();
				}
				else if ( queueSize < queueSizes.back()/2 ) // Some expansions were successful, so focus on them in a new queue
				{
					next = 0;
					queueSizes.push_back(queueSize);
				}
				else
					next = 0;  // Enough expansions succeeded (queueSize >= queueSizes.back()/2), so sweep the whole current queue again

				printStatus1(cycle++,false,ticks0);
			} while ( !queueSizes.empty() );
			new_energy = compute_energy();
		}
		else
		{
			// Standard expansion loop sweeps over all labels each cycle
			printStatus1("starting alpha-expansion w/ standard cycles");
			new_energy = compute_energy();
			old_energy = new_energy+1;
			for ( int cycle = 1; cycle <= max_num_iterations; cycle++ )
			{
				gcoclock_t ticks0 = gcoclock();
				old_energy = new_energy;
				new_energy = oneExpansionIteration();
				printStatus1(cycle,false,ticks0);
				if ( new_energy == old_energy )
					break;  // converged: a full cycle did not change the energy
				permuteLabelTable();
			}
		}
	} 
	catch (...)
	{
		m_stepsThisCycle = m_stepsThisCycleTotal = 0;
		throw;
	}
	m_stepsThisCycle = m_stepsThisCycleTotal = 0; // set so that alpha_expansion() knows it's not inside expansion() if called externally
	return new_energy;
}
01042 //-------------------------------------------------------------------
01044 void GCoptimization::setLabelOrder(bool isRandom)
01045 {
01046         m_random_label_order = isRandom;
01047         for ( LabelID i = 0; i < m_num_labels; i++ )
01048                 m_labelTable[i] = i;
01049 }
01051 //-------------------------------------------------------------------
01053 void GCoptimization::setLabelOrder(const LabelID* order, LabelID size)
01054 {
01055         if ( size > m_num_labels )
01056                 handleError("setLabelOrder receieved too many labels");
01057         for ( LabelID i = 0; i < size; ++i )
01058                 if ( order[i] < 0 || order[i] >= m_num_labels )
01059                         handleError("Invalid label id in setLabelOrder");
01060         m_random_label_order = false;
01061         memcpy(m_labelTable,order,size*sizeof(LabelID));
01062         memset(m_labelTable+size,-1,(m_num_labels-size)*sizeof(LabelID));
01063 }
01065 //------------------------------------------------------------------
// Reports an error by throwing a GCException constructed from `message`; it never returns
// normally. It is also passed, cast to void(*)(char*), as the error callback of the EnergyT
// objects built in alpha_expansion() and alpha_beta_swap().
void GCoptimization::handleError(const char *message)
{
	throw GCException(message);
}
01072 //------------------------------------------------------------------
01074 void GCoptimization::checkInterrupt()
01075 {
01076         if ( utIsInterruptPending() )
01077                 throw GCException("Interrupted.");
01078 }
01081 //-------------------------------------------------------------------//
01082 //                  METHODS for EXPANSION MOVES                      //  
01083 //-------------------------------------------------------------------//
// Adds the label-cost (label subset cost) terms for expanding alpha_label over the `size`
// sites in activeSites to the binary energy `e`. Variable i of `e` corresponds to site
// activeSites[i].
//
// During this call, each LabelCost node uses its `aux` field as scratch state:
//   UNINIT  (-1) : not yet encountered,
//   DISABLE (-2) : must not be encoded in the graph,
//   >= 0         : index of the auxiliary variable created for it in `e`.
// Subsets containing alpha_label are always disabled. With sparse data costs, subsets
// containing any label that also occurs at a non-active site are disabled too.
// Creating an auxiliary variable adds that subset's cost to m_beforeExpansionEnergy.
//
// Returns the total cost of currently inactive subsets that contain alpha_label, but only
// when alpha_label is currently unused (m_labelCounts[alpha_label] == 0); otherwise 0.
// The caller adds this correction to the minimized energy before deciding whether to
// accept the move.
GCoptimization::EnergyType GCoptimization::setupLabelCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
{
	EnergyType alphaCostCorrection = 0;
	if ( !m_labelcostsAll )
		return alphaCostCorrection;

	const SiteID DISABLE = -2;
	const SiteID UNINIT  = -1;
	for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next )
		lc->aux = UNINIT;

	// Skip higher-order costs that include alpha_label or any label used
	// outside the activeSites, since they cannot be eliminated by the expansion.
	if ( m_queryActiveSitesExpansion == (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion<DataCostFnSparse> )
	{
		// For sparse data costs, things are more complicated, because we must ensure that
		// no label cost for a fixed (non-active) non-alpha label is encoded in the graph.
		// Count how many active sites currently carry each label...
		memset(m_activeLabelCounts,0,m_num_labels*sizeof(SiteID));
		for ( SiteID i = 0; i < size; ++i )
			m_activeLabelCounts[m_labeling[activeSites[i]]]++;

		// ...and disable every subset containing a label that also appears at some
		// non-active site (fewer active occurrences than total occurrences).
		for ( LabelID l = 0; l < m_num_labels; ++l )
		{
			if ( m_activeLabelCounts[l] != m_labelCounts[l] )
			{
				for ( LabelCostIter* lcj = m_labelcostsByLabel[l]; lcj; lcj = lcj->next )
					lcj->node->aux = DISABLE;
			}
		}
	}
	for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next )
		lci->node->aux = DISABLE;

	// Since we're explicitly omitting the alpha_label label costs from the binary energy, 
	// calculate what it would have been, so that we can potentially reject the expansion afterwards.
	if ( !m_labelCounts[alpha_label] )
	{
		for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next )
			if ( !lci->node->active )
				alphaCostCorrection += lci->node->cost;
	}

	// Add edges to the graph, including auxiliary vertices as needed.
	// NOTE(review): assuming Energy's add_term1(x,E0,E1) / add_term2(x,y,E00,E01,E10,E11)
	// argument convention, and that x_i = 1 means site i keeps its current label, these
	// terms charge lc->cost exactly when some active site keeps a label of the subset.
	// Confirm against energy.h and applyNewLabeling.
	for ( SiteID i = 0; i < size; i++ )
	{
		LabelID label_i = m_labeling[activeSites[i]];
		for ( LabelCostIter* lci = m_labelcostsByLabel[label_i]; lci; lci = lci->next ) 
		{
			LabelCost* lc = lci->node;
			if ( lc->aux == DISABLE )
				continue;

			// Add auxiliary variable if necessary, and add pairwise potential
			if ( lc->aux == UNINIT ) 
			{
				lc->aux = e->add_variable();
				e->add_term1(lc->aux,0,lc->cost);
				m_beforeExpansionEnergy += lc->cost;  // this subset is in use before the move
			}
			e->add_term2(i,lc->aux,0,0,lc->cost,0);
		}
	}

	return alphaCostCorrection;
}
01151 //-------------------------------------------------------------------
01152 void GCoptimization::updateLabelingInfo(bool updateCounts, bool updateActive, bool updateCosts)
01153 {
01154         if ( !m_labelingInfoDirty )
01155                 return;
01157         m_labelingInfoDirty = false;
01159         if ( m_labelcostsAll )
01160         {
01161                 if ( updateCounts )
01162                 {
01163                         memset(m_labelCounts,0,m_num_labels*sizeof(SiteID));
01164                         for ( SiteID i = 0; i < m_num_sites; ++i )
01165                                 m_labelCounts[m_labeling[i]]++;
01166                 }
01168                 if ( updateActive )
01169                 {
01170                         for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next )
01171                                 lc->active = false;
01173                         EnergyType energy = 0;
01174                         for ( LabelID l = 0; l < m_num_labels; ++l ) 
01175                                 if ( m_labelCounts[l] )
01176                                         for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next ) 
01177                                                 lci->node->active = true;
01178                 }
01179         }
01181         if ( updateCosts )
01182         {
01183                 if (m_updateLabelingDataCosts)
01184                         (this->*m_updateLabelingDataCosts)();
01185                 else
01186                         memset(m_labelingDataCosts,0,m_num_sites*sizeof(EnergyTermType));
01187         }
01188 }
01190 //-------------------------------------------------------------------
01191 // Sets up the binary expansion energy, optimizes it, and updates the current labeling.
01192 //
01193 bool GCoptimization::alpha_expansion(LabelID alpha_label)
01194 {
01195         if (alpha_label < 0)
01196                 return false; // label was disabled due to setLabelOrder on subset of labels
01198         finalizeNeighbors();
01199         gcoclock_t ticks0 = gcoclock();
01201         if ( m_stepsThisCycleTotal == 0 )
01202                 m_labelingInfoDirty = true; // if not inside expansion(), assume data cost function could have changed since last expansion
01203         updateLabelingInfo();
01205         // Determine list of active sites for this expansion move
01206         SiteID size = 0;
01207         SiteID *activeSites = new SiteID[m_num_sites];
01208         EnergyType afterExpansionEnergy = 0;
01209         try 
01210         {
01211                 // Get list of active sites based on alpha and current labeling
01212                 if ( m_queryActiveSitesExpansion )
01213                         size = (this->*m_queryActiveSitesExpansion)(alpha_label,activeSites);
01214                 if ( size == 0 )  // Nothing to do
01215                 {
01216                         delete [] activeSites;
01217                         printStatus2(alpha_label,-1,size,ticks0);
01218                         return false;
01219                 }
01221                 // Initialise reverse-lookup so that non-active neighbours can be identified
01222                 // while constructing the graph
01223                 for ( SiteID i = 0; i < size; i++ )
01224                         m_lookupSiteVar[activeSites[i]] = i;
01226                 // Create binary variables for each remaining site, add the data costs,
01227                 // and compute the smooth costs between variables.
01228                 EnergyT e(size+m_labelcostCount, // poor guess at number of pairwise terms needed :(
01229                                  m_numNeighborsTotal+(m_labelcostCount?size+m_labelcostCount : 0),
01230                                  (void(*)(char*))handleError);
01231                 e.add_variable(size);
01232                 m_beforeExpansionEnergy = 0;
01233                 if ( m_setupDataCostsExpansion   ) (this->*m_setupDataCostsExpansion  )(size,alpha_label,&e,activeSites);
01234                 if ( m_setupSmoothCostsExpansion ) (this->*m_setupSmoothCostsExpansion)(size,alpha_label,&e,activeSites);
01235                 EnergyType alphaCorrection = setupLabelCostsExpansion(size,alpha_label,&e,activeSites);
01236                 checkInterrupt();
01237                 afterExpansionEnergy = e.minimize() + alphaCorrection;
01238                 checkInterrupt();
01240                 if ( afterExpansionEnergy < m_beforeExpansionEnergy )
01241                         (this->*m_applyNewLabeling)(&e,activeSites,size,alpha_label);
01243                 for ( SiteID i = 0; i < size; i++ )
01244                         m_lookupSiteVar[activeSites[i]] = -1; // restore m_lookupSite to all -1s
01246                 printStatus2(alpha_label,-1,size,ticks0);
01247         } 
01248         catch (...)
01249         {
01250                 delete [] activeSites;
01251                 throw;
01252         }
01253         delete [] activeSites;
01254         return afterExpansionEnergy < m_beforeExpansionEnergy;
01255 }
01257 //-------------------------------------------------------------------
01259 GCoptimization::EnergyType GCoptimization::oneExpansionIteration()
01260 {
01261         permuteLabelTable();
01262         m_stepsThisCycle = 0;
01263         m_stepsThisCycleTotal = m_num_labels;
01265         // Each cycle is exactly one pass over the labels
01266         for (LabelID next = 0; next < m_num_labels; next++, m_stepsThisCycle++ )
01267                 alpha_expansion(m_labelTable[next]);
01269         return compute_energy();
01270 }
01272 //-------------------------------------------------------------------//
01273 //                  METHODS for SWAP MOVES                           //  
01274 //-------------------------------------------------------------------//
01276 GCoptimization::EnergyType GCoptimization::swap(int max_num_iterations)
01277 {
01278         EnergyType new_energy,old_energy;
01279         if ( (this->*m_solveSpecialCases)(new_energy) )
01280                 return new_energy;
01282         new_energy = compute_energy();
01283         old_energy = new_energy+1;
01284         printStatus1("starting alpha/beta-swap");
01286         if ( max_num_iterations == -1 )
01287                 max_num_iterations = 10000000;
01288         int curr_cycle = 1;
01289         m_stepsThisCycleTotal = (m_num_labels*(m_num_labels-1))/2;
01290         try
01291         {
01292                 while ( old_energy > new_energy && curr_cycle <= max_num_iterations)
01293                 {
01294                         gcoclock_t ticks0 = gcoclock();
01295                         old_energy = new_energy;
01296                         new_energy = oneSwapIteration();
01297                         printStatus1(curr_cycle,true,ticks0);
01298                         curr_cycle++;
01299                 }
01300         } 
01301         catch (...)
01302         {
01303                 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01304                 throw;
01305         }
01306         m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01308         return(new_energy);
01309 }
01311 //--------------------------------------------------------------------------------
01313 GCoptimization::EnergyType GCoptimization::oneSwapIteration()
01314 {
01315         LabelID next,next1;
01316         permuteLabelTable();
01317         m_stepsThisCycle = 0;
01319         for (next = 0;  next < m_num_labels;  next++ )
01320                 for (next1 = m_num_labels - 1;  next1 >= 0;  next1-- )
01321                         if ( m_labelTable[next] < m_labelTable[next1] )
01322                         {
01323                                 alpha_beta_swap(m_labelTable[next],m_labelTable[next1]); 
01324                                 m_stepsThisCycle++;
01325                         }
01327         return(compute_energy());
01328 }
01330 //---------------------------------------------------------------------------------
01332 void GCoptimization::alpha_beta_swap(LabelID alpha_label, LabelID beta_label)
01333 {
01334         assert( alpha_label >= 0 && alpha_label < m_num_labels && beta_label >= 0 && beta_label < m_num_labels);
01335         if ( m_labelcostsAll )
01336                 handleError("Label costs only implemented for alpha-expansion.");
01338         finalizeNeighbors();
01339         gcoclock_t ticks0 = gcoclock();
01341         // Determine the list of active sites for this swap move
01342         SiteID size = 0;
01343         SiteID *activeSites = new SiteID[m_num_sites];
01344         try
01345         {
01346                 for ( SiteID i = 0; i < m_num_sites; i++ )
01347                 {
01348                         if ( m_labeling[i] == alpha_label || m_labeling[i] == beta_label )
01349                         {
01350                                 activeSites[size] = i;
01351                                 m_lookupSiteVar[i] = size;
01352                                 size++;
01353                         }
01354                 }
01355                 if ( size == 0 )
01356                 {
01357                         delete [] activeSites;
01358                         printStatus2(alpha_label,beta_label,size,ticks0);
01359                         return;
01360                 }
01362                 // Create binary variables for each remaining site, add the data costs,
01363                 // and compute the smooth costs between variables.
01364                 EnergyT e(size,m_numNeighborsTotal,(void(*)(char*))handleError);
01365                 e.add_variable(size);
01366                 if ( m_setupDataCostsSwap   ) (this->*m_setupDataCostsSwap  )(size,alpha_label,beta_label,&e,activeSites);
01367                 if ( m_setupSmoothCostsSwap ) (this->*m_setupSmoothCostsSwap)(size,alpha_label,beta_label,&e,activeSites);
01368                 checkInterrupt();
01369                 e.minimize();
01370                 checkInterrupt();
01372                 // Apply the new labeling
01373                 for ( SiteID i = 0; i < size; i++ )
01374                 {
01375                         m_labeling[activeSites[i]] = (e.get_var(i) == 0) ? alpha_label : beta_label;
01376                         m_lookupSiteVar[activeSites[i]] = -1; // restore lookupSiteVar to all -1s
01377                 }
01378                 m_labelingInfoDirty = true;
01379         } 
01380         catch (...)
01381         {
01382                 delete [] activeSites;
01383                 throw;
01384         }
01385         delete [] activeSites;
01387         printStatus2(alpha_label,beta_label,size,ticks0);
01388 }
01392 // Functions for the GCoptimizationGridGraph, derived from GCoptimization
01395 GCoptimizationGridGraph::GCoptimizationGridGraph(SiteID width, SiteID height,LabelID num_labels)
01396                                                 :GCoptimization(width*height,num_labels)
01397 {
01398         assert( (width > 1) && (height > 1) && (num_labels > 1 ));
01400         m_weightedGraph = 0;
01401         for (int  i = 0; i < 4; i ++ )  m_unityWeights[i] = 1;
01403         m_width  = width;
01404         m_height = height;
01406         m_numNeighbors = new SiteID[m_num_sites];
01407         m_neighbors = new SiteID[4*m_num_sites];
01409         SiteID indexes[4] = {-1,1,-m_width,m_width};
01411         SiteID indexesL[3] = {1,-m_width,m_width};
01412         SiteID indexesR[3] = {-1,-m_width,m_width};
01413         SiteID indexesU[3] = {1,-1,m_width};
01414         SiteID indexesD[3] = {1,-1,-m_width};
01416         SiteID indexesUL[2] = {1,m_width};
01417         SiteID indexesUR[2] = {-1,m_width};
01418         SiteID indexesDL[2] = {1,-m_width};
01419         SiteID indexesDR[2] = {-1,-m_width};
01421         setupNeighbData(1,m_height-1,1,m_width-1,4,indexes);
01423         setupNeighbData(1,m_height-1,0,1,3,indexesL);
01424         setupNeighbData(1,m_height-1,m_width-1,m_width,3,indexesR);
01425         setupNeighbData(0,1,1,width-1,3,indexesU);
01426         setupNeighbData(m_height-1,m_height,1,m_width-1,3,indexesD);
01428         setupNeighbData(0,1,0,1,2,indexesUL);
01429         setupNeighbData(0,1,m_width-1,m_width,2,indexesUR);
01430         setupNeighbData(m_height-1,m_height,0,1,2,indexesDL);
01431         setupNeighbData(m_height-1,m_height,m_width-1,m_width,2,indexesDR);
01432 }
01434 //-------------------------------------------------------------------
01436 GCoptimizationGridGraph::~GCoptimizationGridGraph()
01437 {
01438         delete [] m_numNeighbors;
01439         if ( m_neighbors )
01440                 delete [] m_neighbors;
01441         if (m_weightedGraph) delete [] m_neighborsWeights;
01442 }
01445 //-------------------------------------------------------------------
01447 void GCoptimizationGridGraph::setupNeighbData(SiteID startY,SiteID endY,SiteID startX,
01448                                                                                           SiteID endX,SiteID maxInd,SiteID *indexes)
01449 {
01450         SiteID x,y,pix;
01451         SiteID n;
01453         for ( y = startY; y < endY; y++ )
01454                 for ( x = startX; x < endX; x++ )
01455                 {
01456                         pix = x+y*m_width;
01457                         m_numNeighbors[pix] = maxInd;
01458                         m_numNeighborsTotal += maxInd;
01460                         for (n = 0; n < maxInd; n++ )
01461                                 m_neighbors[pix*4+n] = pix+indexes[n];
01462                 }
01463 }
01465 //-------------------------------------------------------------------
01467 void GCoptimizationGridGraph::finalizeNeighbors()
01468 {
01469 }
01471 //-------------------------------------------------------------------
01473 void GCoptimizationGridGraph::setSmoothCostVH(EnergyTermType *smoothArray, EnergyTermType *vCosts, EnergyTermType *hCosts)
01474 {
01475         setSmoothCost(smoothArray);
01476         m_weightedGraph = 1;
01477         computeNeighborWeights(vCosts,hCosts);
01478 }
01480 //-------------------------------------------------------------------
01482 void GCoptimizationGridGraph::giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights)
01483 {
01484         *numSites  = m_numNeighbors[site];
01485         *neighbors = &m_neighbors[site*4];
01487         if (m_weightedGraph) *weights  = &m_neighborsWeights[site*4];
01488         else *weights = m_unityWeights;
01489 }
01491 //-------------------------------------------------------------------
01493 void GCoptimizationGridGraph::computeNeighborWeights(EnergyTermType *vCosts,EnergyTermType *hCosts)
01494 {
01495         SiteID i,n,nSite;
01496         GCoptimization::EnergyTermType weight;
01498         m_neighborsWeights = new EnergyTermType[m_num_sites*4];
01500         for ( i = 0; i < m_num_sites; i++ )
01501         {
01502                 for ( n = 0; n < m_numNeighbors[i]; n++ )
01503                 {
01504                         nSite = m_neighbors[4*i+n];
01505                         if ( i-nSite == 1 )            weight = hCosts[nSite];
01506                         else if (i-nSite == -1 )       weight = hCosts[i];
01507                         else if ( i-nSite == m_width ) weight = vCosts[nSite];
01508                         else if (i-nSite == -m_width ) weight = vCosts[i];
01510                         m_neighborsWeights[i*4+n] = weight;
01511                 }
01512         }
01514 }
01516 // Functions for the GCoptimizationGeneralGraph, derived from GCoptimization
01519 GCoptimizationGeneralGraph::GCoptimizationGeneralGraph(SiteID num_sites,LabelID num_labels):GCoptimization(num_sites,num_labels)
01520 {
01521         assert( num_sites > 1 && num_labels > 1 );
01523         m_neighborsIndexes = 0;
01524         m_neighborsWeights = 0;
01525         m_numNeighbors     = 0;
01526         m_neighbors        = 0;
01528         m_needTodeleteNeighbors        = true;
01529         m_needToFinishSettingNeighbors = true;
01530 }
01532 //------------------------------------------------------------------
01534 GCoptimizationGeneralGraph::~GCoptimizationGeneralGraph()
01535 {
01537         if ( m_neighbors )
01538                 delete [] m_neighbors;
01540         if ( m_numNeighbors && m_needTodeleteNeighbors )
01541         {
01542                 for ( SiteID i = 0; i < m_num_sites; i++ )
01543                 {
01544                         if (m_numNeighbors[i] != 0 )
01545                         {
01546                                 delete [] m_neighborsIndexes[i];
01547                                 delete [] m_neighborsWeights[i];
01548                         }
01549                 }
01551                 delete [] m_numNeighbors;
01552                 delete [] m_neighborsIndexes;
01553                 delete [] m_neighborsWeights;
01554         }
01555 }
01557 //------------------------------------------------------------------
01559 void GCoptimizationGeneralGraph::finalizeNeighbors()
01560 {
01561         if ( !m_needToFinishSettingNeighbors )
01562                 return;
01563         m_needToFinishSettingNeighbors = false;
01565         Neighbor *tmp;
01566         SiteID i,site,count;
01568         EnergyTermType *tempWeights = new EnergyTermType[m_num_sites];
01569         SiteID *tempIndexes         = new SiteID[m_num_sites];
01571         if ( !tempWeights || !tempIndexes ) handleError("Not enough memory");
01573         m_numNeighbors     = new SiteID[m_num_sites];
01574         m_neighborsIndexes = new SiteID*[m_num_sites];
01575         m_neighborsWeights = new EnergyTermType*[m_num_sites];
01577         if ( !m_numNeighbors || !m_neighborsIndexes || !m_neighborsWeights ) handleError("Not enough memory.");
01579         for ( site = 0; site < m_num_sites; site++ )
01580         {
01581                 if ( m_neighbors && !m_neighbors[site].isEmpty() )
01582                 {
01583                         m_neighbors[site].setCursorFront();
01584                         count = 0;
01586                         while ( m_neighbors[site].hasNext() )
01587                         {
01588                                 tmp = (Neighbor *) (m_neighbors[site].next());
01589                                 tempIndexes[count] =  tmp->to_node;
01590                                 tempWeights[count] =  tmp->weight;
01591                                 delete tmp;
01592                                 count++;
01593                         }
01594                         m_numNeighbors[site]     = count;
01595                         m_numNeighborsTotal     += count;
01596                         m_neighborsIndexes[site] = new SiteID[count];
01597                         m_neighborsWeights[site] = new EnergyTermType[count];
01599                         if ( !m_neighborsIndexes[site] || !m_neighborsWeights[site] ) handleError("Not enough memory.");
01601                         for ( i = 0; i < count; i++ )
01602                         {
01603                                 m_neighborsIndexes[site][i] = tempIndexes[i];
01604                                 m_neighborsWeights[site][i] = tempWeights[i];
01605                         }
01606                 }
01607                 else m_numNeighbors[site] = 0;
01609         }
01611         delete [] tempIndexes;
01612         delete [] tempWeights;
01613         if (m_neighbors) {
01614                 delete [] m_neighbors;
01615                 m_neighbors = 0;
01616         }
01617 }
01618 //------------------------------------------------------------------------------
01620 void GCoptimizationGeneralGraph::giveNeighborInfo(SiteID site, SiteID *numSites, 
01621                                                                                                   SiteID **neighbors, EnergyTermType **weights)
01622 {
01623         if (m_numNeighbors) {
01624                 (*numSites)  =  m_numNeighbors[site];
01625                 (*neighbors) = m_neighborsIndexes[site];
01626                 (*weights)   = m_neighborsWeights[site];
01627         } else {
01628                 *numSites = 0;
01629                 *neighbors = 0;
01630                 *weights = 0;
01631         }
01632 }
01635 //------------------------------------------------------------------
01637 void GCoptimizationGeneralGraph::setNeighbors(SiteID site1, SiteID site2, EnergyTermType weight)
01638 {
01640         assert( site1 < m_num_sites && site1 >= 0 && site2 < m_num_sites && site2 >= 0);
01641         if ( m_needToFinishSettingNeighbors == false )
01642                 handleError("Already set up neighborhood system.");
01644         if ( !m_neighbors )
01645         {
01646                 m_neighbors = (LinkedBlockList *) new LinkedBlockList[m_num_sites];
01647                 if ( !m_neighbors ) handleError("Not enough memory.");
01648         }
01650         Neighbor *temp1 = (Neighbor *) new Neighbor;
01651         Neighbor *temp2 = (Neighbor *) new Neighbor;
01653         temp1->weight  = weight;
01654         temp1->to_node = site2;
01656         temp2->weight  = weight;
01657         temp2->to_node = site1;
01659         m_neighbors[site1].addFront(temp1);
01660         m_neighbors[site2].addFront(temp2);
01662 }
01663 //------------------------------------------------------------------
01665 void GCoptimizationGeneralGraph::setAllNeighbors(SiteID *numNeighbors,SiteID **neighborsIndexes,
01666                                                                                                  EnergyTermType **neighborsWeights)
01667 {
01668         m_needTodeleteNeighbors = false;
01669         m_needToFinishSettingNeighbors = false;
01670         if ( m_numNeighborsTotal > 0 )
01671                 handleError("Already set up neighborhood system.");
01672         m_numNeighbors     = numNeighbors;
01673         m_numNeighborsTotal = 0;
01674         for (int site = 0; site < m_num_sites; site++ ) m_numNeighborsTotal += m_numNeighbors[site];
01675         m_neighborsIndexes = neighborsIndexes;
01676         m_neighborsWeights = neighborsWeights;
01677 }
01681 //------------------------------------------------------------------
01682 // boring status messages
01684 void GCoptimization::printStatus1(const char* extraMsg)
01685 {
01686         if ( m_verbosity < 1 )
01687                 return;
01688         if ( extraMsg )
01689                 printf("gco>> %s\n",extraMsg);
01690         printf("gco>> initial energy: \tE=%lld (E=%lld+%lld+%lld)\n",(long long)compute_energy(),
01691                 (long long)giveDataEnergy(), (long long)giveSmoothEnergy(), (long long)giveLabelEnergy()); 
01692         flushnow(); 
01693 }
01695 void GCoptimization::printStatus1(int cycle, bool isSwap, gcoclock_t ticks0)
01696 {
01697         if ( m_verbosity < 1 )
01698                 return;
01699         gcoclock_t ticks1 = gcoclock();
01700         printf("gco>> after cycle %2d: \tE=%lld (E=%lld+%lld+%lld);",cycle,(long long)compute_energy(),
01701                 (long long)giveDataEnergy(),(long long)giveSmoothEnergy(),(long long)giveLabelEnergy());
01702         if ( m_stepsThisCycleTotal > 0 )
01703                 printf(isSwap ? " \t%d swaps(s);" : " \t%d expansions(s);",m_stepsThisCycleTotal);
01704         if ( m_verbosity == 1 )
01705         {
01706                 // Don't print time if time is already printed at finer scale, since printing
01707                 // itself takes time (esp in MATLAB) and makes time useless at this level
01708                 int ms = (int)(1000*(ticks1 - ticks0) / GCO_CLOCKS_PER_SEC);
01709                 printf(" \t%d ms",ms);
01710         }
01711         printf("\n");
01712         flushnow(); 
01713 }
01715 void GCoptimization::printStatus2(int alpha, int beta, int numVars, gcoclock_t ticks0)
01716 {
01717         if ( m_verbosity < 2 )
01718                 return;
01719         int microsec = (int)(1000000*(gcoclock() - ticks0) / GCO_CLOCKS_PER_SEC);
01720         if ( beta >= 0 )
01721                 printf("gco>>   after swap(%d,%d):",alpha+INDEX0,beta+INDEX0);
01722         else
01723                 printf("gco>>   after expansion(%d):",alpha+INDEX0);
01724         printf(" \tE=%lld (E=%lld+%lld+%lld);\t %lld vars;",
01725                 (long long)compute_energy(),(long long)giveDataEnergy(),
01726                 (long long)giveSmoothEnergy(),(long long)giveLabelEnergy(),(long long)numVars);
01727         if ( m_stepsThisCycleTotal > 0 )
01728                 printf(" \t(%d of %d);",m_stepsThisCycle+1,m_stepsThisCycleTotal);
01730         printf(microsec > 100 ? "\t %.2f ms\n" : "\t %.3f ms\n",(double)microsec/1000.0);
01731         flushnow();
01732 }
01737 //-------------------------------------------------------------------
01738 // DataCostFnSparse methods
01739 //-------------------------------------------------------------------
01742 GCoptimization::DataCostFnSparse::DataCostFnSparse(SiteID num_sites, LabelID num_labels)
01743 : m_num_sites(num_sites)
01744 , m_num_labels(num_labels)
01745 , m_buckets_per_label((m_num_sites + cSitesPerBucket-1)/cSitesPerBucket)
01746 , m_buckets(0)
01747 {
01748 }
01750 GCoptimization::DataCostFnSparse::DataCostFnSparse(const DataCostFnSparse& src)
01751 : m_num_sites(src.m_num_sites)
01752 , m_num_labels(src.m_num_labels)
01753 , m_buckets_per_label(src.m_buckets_per_label)
01754 , m_buckets(0)
01755 {
01756         assert(!src.m_buckets); // not implemented
01757 }
01759 GCoptimization::DataCostFnSparse::~DataCostFnSparse()
01760 {
01761         if (m_buckets) {
01762                 for (LabelID l = 0; l < m_num_labels; ++l)
01763                         if (m_buckets[l*m_buckets_per_label].begin)
01764                                 delete [] m_buckets[l*m_buckets_per_label].begin;
01765                 delete [] m_buckets;
01766         }
01767 }
01769 void GCoptimization::DataCostFnSparse::set(LabelID l, const SparseDataCost* costs, SiteID count)
01770 {
01771         // Create the bucket if necessary, and copy all the costs
01772         //
01773         if (!m_buckets) {
01774                 m_buckets = new DataCostBucket[m_num_labels*m_buckets_per_label];
01775                 memset(m_buckets, 0, m_num_labels*m_buckets_per_label*sizeof(DataCostBucket));
01776         }
01778         DataCostBucket* b = &m_buckets[l*m_buckets_per_label];
01779         if (b->begin)
01780                 delete [] b->begin;
01781         SparseDataCost* next = new SparseDataCost[count];
01782         memcpy(next,costs,count*sizeof(SparseDataCost));
01784         //
01785         // Scan the list of costs and remember pointers to delimit the 'buckets', i.e. where 
01786         // ranges of SiteIDs lie along the array. Buckets can be empty (begin == end).
01787         //
01788         const SparseDataCost* end  = next+count;
01789         SiteID prev_site = -1;
01790         for (int i = 0; i < m_buckets_per_label; ++i) {
01791                 b[i].begin = b[i].predict = next;
01792                 SiteID end_site = (i+1)*cSitesPerBucket;
01793                 while (next < end && next->site < end_site) {
01794                         if (next->site < 0 || next->site >= m_num_sites)
01795                                 throw GCException("Invalid site id given for sparse data cost; must be within range.");
01796                         if (next->site <= prev_site)
01797                                 throw GCException("Sparse data costs must be sorted in increasing order of SiteID");
01798                         prev_site = next->site;
01799                         ++next;
01800                 }
01801                 b[i].end = next;
01802         }
01803 }
01805 GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::search(DataCostBucket& b, SiteID s)
01806 {
01807         // Perform binary search for requested SiteID
01808         //
01809         const SparseDataCost* L = b.begin;
01810         const SparseDataCost* R = b.end-1;
01811         if ( R - L == m_num_sites )
01812                 return b.begin[s].cost; // special case: this particular label is actually dense
01813         do {
01814                 const SparseDataCost* mid = (const SparseDataCost*)((((size_t)L+(size_t)R) >> 1) & cDataCostPtrMask);
01815                 if (s < mid->site)
01816                         R = mid-1;         // eliminate upper range
01817                 else if (mid->site < s)
01818                         L = mid+1;         // eliminate lower range
01819                 else {
01820                         b.predict = mid+1;
01821                         return mid->cost;  // found it!
01822                 }
01823         } while (R - L > cLinearSearchSize);
01825         // Finish off with linear search over the remaining elements
01826         //
01827         do {
01828                 if (L->site >= s) {
01829                         if (L->site == s) {
01830                                 b.predict = L+1;
01831                                 return L->cost;
01832                         }
01833                         break;
01834                 }
01835         } while (++L <= R);
01836         b.predict = L;
01838         return GCO_MAX_ENERGYTERM; // the site belongs to this bucket but with no cost specified
01839 }
01841 OLGA_INLINE GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::compute(SiteID s, LabelID l)
01842 {
01843         DataCostBucket& b = m_buckets[l*m_buckets_per_label + (s >> cLogSitesPerBucket)];
01844         if (b.begin == b.end)
01845                 return GCO_MAX_ENERGYTERM;
01846         if (b.predict < b.end) {
01847                 // Check for correct prediction
01848                 if (b.predict->site == s)
01849                         return (b.predict++)->cost; // predict++ for next time
01851                 // If the requested site is missing from the site ids near 'predict'
01852                 // then we know it doesn't exist in the bucket, so return INF
01853                 if (b.predict->site > s && b.predict > b.begin && (b.predict-1)->site < s)
01854                         return GCO_MAX_ENERGYTERM;
01855         }
01856         if ( (size_t)b.end - (size_t)b.begin == cSitesPerBucket*sizeof(SparseDataCost) )
01857                 return b.begin[s-b.begin->site].cost; // special case: this particular bucket is actually dense!
01859         return search(b,s);
01860 }
01862 GCoptimization::SiteID GCoptimization::DataCostFnSparse::queryActiveSitesExpansion(LabelID alpha_label, const LabelID* labeling, SiteID* activeSites)
01863 {
01864         const SparseDataCost* next = m_buckets[alpha_label*m_buckets_per_label].begin;
01865         const SparseDataCost* end  = m_buckets[alpha_label*m_buckets_per_label + m_buckets_per_label-1].end;
01866         SiteID count = 0;
01867         for (; next < end; ++next) {
01868                 if ( labeling[next->site] != alpha_label )
01869                         activeSites[count++] = next->site;
01870         }
01871         return count;
01872 }

