00001 #ifdef MATLAB_MEX_FILE
00002 #include <mex.h>
00003 #endif
00004 #include "GCoptimization.h"
00005 #include "LinkedBlockList.h"
00006 #include <stdio.h>
00007 #include <stdlib.h>
00008 #include <vector>
00009 #include <algorithm>
00010
00011
00012
00013
00014
00015 #ifdef _WIN32
00016 #define WIN32_LEAN_AND_MEAN
00017 #define VC_EXTRALEAN
00018 #define NOMINMAX
00019 #include <windows.h>
00020 extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = 0;
00021 extern "C" inline gcoclock_t gcoclock()
00022 {
00023 gcoclock_t result = 0;
00024 if (GCO_CLOCKS_PER_SEC == 0)
00025 QueryPerformanceFrequency((LARGE_INTEGER*)&GCO_CLOCKS_PER_SEC);
00026 QueryPerformanceCounter((LARGE_INTEGER*)&result);
00027 return result;
00028 }
00029
00030 #else
00031 extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = CLOCKS_PER_SEC;
00032 extern "C" gcoclock_t gcoclock() { return clock(); }
00033 #endif
00034
#ifdef MATLAB_MEX_FILE
// MATLAB build: interrupt polling comes from MATLAB's runtime; checkInterrupt()
// turns a pending interrupt into a GCException("Interrupted.").
extern "C" bool utIsInterruptPending();

// Evaluate "drawnow;" in MATLAB so printed status appears, but no more often
// than once every GCO_CLOCKS_PER_SEC/5 ticks (about five times per second).
static void flushnow()
{
	static gcoclock_t lastFlush = 0;
	const gcoclock_t current = gcoclock();
	if ( current - lastFlush <= GCO_CLOCKS_PER_SEC/5 )
		return;
	lastFlush = current;
	mexEvalString("drawnow;");
}
#define INDEX0 1 // print 1-based label and site indices for MATLAB
#else
// Stand-alone build: no interrupt source and nothing to flush.
static inline bool utIsInterruptPending() { return false; }
static void flushnow() { }
#define INDEX0 0 // print 0-based label and site indices
#endif
00053
00054
// Delete every node of a singly linked list (nodes linked through `next`)
// and leave `head` null.
template <typename T>
void slist_clear(T*& head)
{
	for ( T* node = head; node != 0; ) {
		T* following = node->next;
		delete node;
		node = following;
	}
	head = 0;
}
00064
// Push `val` onto the front of a singly linked list; `head` then points to it.
template <typename T>
void slist_prepend(T*& head, T* val)
{
	T* const previousHead = head;
	val->next = previousHead;
	head = val;
}
00071
00072
00073 void GCException::Report()
00074 {
00075 printf("\n%s\n",message);
00076 exit(0);
00077 }
00078
00079
00080
00082
00084
00085 GCoptimization::GCoptimization(SiteID nSites, LabelID nLabels)
00086 : m_num_labels(nLabels)
00087 , m_num_sites(nSites)
00088 , m_datacostIndividual(0)
00089 , m_smoothcostIndividual(0)
00090 , m_labelcostsAll(0)
00091 , m_labelcostsByLabel(0)
00092 , m_labelcostCount(0)
00093 , m_smoothcostFn(0)
00094 , m_datacostFn(0)
00095 , m_numNeighborsTotal(0)
00096 , m_queryActiveSitesExpansion(&GCoptimization::queryActiveSitesExpansion<DataCostFnFromArray>)
00097 , m_setupDataCostsSwap(0)
00098 , m_setupDataCostsExpansion(0)
00099 , m_setupSmoothCostsSwap(0)
00100 , m_setupSmoothCostsExpansion(0)
00101 , m_applyNewLabeling(0)
00102 , m_updateLabelingDataCosts(0)
00103 , m_giveSmoothEnergyInternal(0)
00104 , m_solveSpecialCases(&GCoptimization::solveSpecialCases<DataCostFnFromArray>)
00105 , m_datacostFnDelete(0)
00106 , m_smoothcostFnDelete(0)
00107 , m_random_label_order(false)
00108 , m_verbosity(0)
00109 , m_labelingInfoDirty(true)
00110 , m_lookupSiteVar(new SiteID[nSites])
00111 , m_labeling(new LabelID[nSites])
00112 , m_labelTable(new LabelID[nLabels])
00113 , m_labelingDataCosts(new EnergyTermType[nSites])
00114 , m_labelCounts(new SiteID[nLabels])
00115 , m_activeLabelCounts(new SiteID[m_num_labels])
00116 , m_stepsThisCycle(0)
00117 , m_stepsThisCycleTotal(0)
00118 {
00119 if ( nLabels <= 1 ) handleError("Number of labels must be >= 2");
00120 if ( nSites <= 0 ) handleError("Number of sites must be >= 1");
00121
00122 if ( !m_lookupSiteVar || !m_labelTable || !m_labeling ){
00123 if (m_lookupSiteVar) delete [] m_lookupSiteVar;
00124 if (m_labelTable) delete [] m_labelTable;
00125 if (m_labeling) delete [] m_labeling;
00126 if (m_labelingDataCosts) delete [] m_labelingDataCosts;
00127 if (m_labelCounts) delete [] m_labelCounts;
00128 handleError("Not enough memory.");
00129 }
00130
00131 memset(m_labeling, 0, m_num_sites*sizeof(LabelID));
00132 memset(m_lookupSiteVar,-1,m_num_sites*sizeof(SiteID));
00133 setLabelOrder(false);
00134 specializeSmoothCostFunctor(SmoothCostFnPotts());
00135 }
00136
00137
00138
00139 GCoptimization::~GCoptimization()
00140 {
00141 delete [] m_labelTable;
00142 delete [] m_lookupSiteVar;
00143 delete [] m_labeling;
00144 delete [] m_labelingDataCosts;
00145 delete [] m_labelCounts;
00146 delete [] m_activeLabelCounts;
00147
00148 if (m_datacostFnDelete) m_datacostFnDelete(m_datacostFn);
00149 if (m_smoothcostFnDelete) m_smoothcostFnDelete(m_smoothcostFn);
00150
00151 if (m_datacostIndividual) delete [] m_datacostIndividual;
00152 if (m_smoothcostIndividual) delete [] m_smoothcostIndividual;
00153
00154
00155
00156 slist_clear(m_labelcostsAll);
00157 if (m_labelcostsByLabel) {
00158 for ( LabelID i = 0; i < m_num_labels; ++i )
00159 slist_clear(m_labelcostsByLabel[i]);
00160 delete [] m_labelcostsByLabel;
00161 }
00162 }
00163
00164
00165
00166 template <>
00167 GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion<GCoptimization::DataCostFnSparse>(LabelID alpha_label,SiteID *activeSites)
00168 {
00169 return ((DataCostFnSparse*)m_datacostFn)->queryActiveSitesExpansion(alpha_label,m_labeling,activeSites);
00170 }
00171
00172
00173
00174 template <>
00175 void GCoptimization::setupDataCostsExpansion<GCoptimization::DataCostFnSparse>(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00176 {
00177 DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00178 DataCostFnSparse::iterator dciter = dc->begin(alpha_label);
00179 for ( SiteID i = 0; i < size; ++i )
00180 {
00181 SiteID site = activeSites[i];
00182 while ( dciter.site() != site )
00183 ++dciter;
00184 addterm1_checked(e,i,dciter.cost(),m_labelingDataCosts[site]);
00185 }
00186 }
00187
00188
00189
00190 template <>
00191 void GCoptimization::applyNewLabeling<GCoptimization::DataCostFnSparse>(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label)
00192 {
00193 DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00194 DataCostFnSparse::iterator dciter = dc->begin(alpha_label);
00195 for ( SiteID i = 0; i < size; i++ )
00196 {
00197 if ( e->get_var(i) == 0 )
00198 {
00199 SiteID site = activeSites[i];
00200 LabelID prev = m_labeling[site];
00201 m_labeling[site] = alpha_label;
00202 m_labelCounts[alpha_label]++;
00203 m_labelCounts[prev]--;
00204 while ( dciter.site() != site )
00205 ++dciter;
00206 m_labelingDataCosts[site] = dciter.cost();
00207 }
00208 }
00209 m_labelingInfoDirty = true;
00210 updateLabelingInfo(false,true,false);
00211 }
00212
00213
00214
00215 template <typename UserFunctor>
00216 void GCoptimization::specializeDataCostFunctor(const UserFunctor f) {
00217 if ( m_datacostFnDelete )
00218 m_datacostFnDelete(m_datacostFn);
00219 if ( m_datacostIndividual )
00220 {
00221 delete [] m_datacostIndividual;
00222 m_datacostIndividual = 0;
00223 }
00224 m_datacostFn = new UserFunctor(f);
00225 m_datacostFnDelete = &GCoptimization::deleteFunctor<UserFunctor>;
00226 m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion<UserFunctor>;
00227 m_setupDataCostsExpansion = &GCoptimization::setupDataCostsExpansion<UserFunctor>;
00228 m_setupDataCostsSwap = &GCoptimization::setupDataCostsSwap<UserFunctor>;
00229 m_applyNewLabeling = &GCoptimization::applyNewLabeling<UserFunctor>;
00230 m_updateLabelingDataCosts = &GCoptimization::updateLabelingDataCosts<UserFunctor>;
00231 m_solveSpecialCases = &GCoptimization::solveSpecialCases<UserFunctor>;
00232 }
00233
00234 template <typename UserFunctor>
00235 void GCoptimization::specializeSmoothCostFunctor(const UserFunctor f) {
00236 if ( m_smoothcostFnDelete )
00237 m_smoothcostFnDelete(m_smoothcostFn);
00238 if ( m_smoothcostIndividual )
00239 {
00240 delete [] m_smoothcostIndividual;
00241 m_smoothcostIndividual = 0;
00242 }
00243 m_smoothcostFn = new UserFunctor(f);
00244 m_smoothcostFnDelete = &GCoptimization::deleteFunctor<UserFunctor>;
00245 m_giveSmoothEnergyInternal = &GCoptimization::giveSmoothEnergyInternal<UserFunctor>;
00246 m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion<UserFunctor>;
00247 m_setupSmoothCostsSwap = &GCoptimization::setupSmoothCostsSwap<UserFunctor>;
00248 }
00249
00250
00251
00252 template <typename SmoothCostT>
00253 GCoptimization::EnergyType GCoptimization::giveSmoothEnergyInternal()
00254 {
00255 EnergyType eng = (EnergyType) 0;
00256 SiteID i,numN,*nPointer,nSite,n;
00257 EnergyTermType *weights;
00258 SmoothCostT* sc = (SmoothCostT*) m_smoothcostFn;
00259 for ( i = 0; i < m_num_sites; i++ )
00260 {
00261 giveNeighborInfo(i,&numN,&nPointer,&weights);
00262 for ( n = 0; n < numN; n++ )
00263 {
00264 nSite = nPointer[n];
00265 if ( nSite < i )
00266 eng += weights[n]*(sc->compute(i,nSite,m_labeling[i],m_labeling[nSite]));
00267 }
00268 }
00269
00270 return eng;
00271 }
00272
00273
00274
00275 OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1)
00276 {
00277 if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM )
00278 handleError("Data cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00279 m_beforeExpansionEnergy += e1;
00280 e->add_term1(i,e0,e1);
00281 }
00282
00283 OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1, EnergyTermType w)
00284 {
00285 if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM )
00286 handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00287 if ( w > GCO_MAX_ENERGYTERM )
00288 handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00289 m_beforeExpansionEnergy += e1*w;
00290 e->add_term1(i,e0*w,e1*w);
00291 }
00292
00293 OLGA_INLINE void GCoptimization::addterm2_checked(EnergyT* e, VarID i, VarID j, EnergyTermType e00, EnergyTermType e01, EnergyTermType e10, EnergyTermType e11, EnergyTermType w)
00294 {
00295 if ( e00 > GCO_MAX_ENERGYTERM || e11 > GCO_MAX_ENERGYTERM || e01 > GCO_MAX_ENERGYTERM || e10 > GCO_MAX_ENERGYTERM )
00296 handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00297 if ( w > GCO_MAX_ENERGYTERM )
00298 handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00299
00300
00301 if ( e00+e11 > e01+e10 )
00302 handleError("Non-submodular expansion term detected; smooth costs must be a metric for expansion");
00303 m_beforeExpansionEnergy += e11*w;
00304 e->add_term2(i,j,e00*w,e01*w,e10*w,e11*w);
00305 }
00306
00307
00308
00309 template <typename DataCostT>
00310 GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion(LabelID alpha_label,SiteID *activeSites)
00311 {
00312 SiteID size = 0;
00313 for ( SiteID i = 0; i < m_num_sites; i++ )
00314 if ( m_labeling[i] != alpha_label )
00315 activeSites[size++] = i;
00316 return size;
00317 }
00318
00319
00320
00321 template <typename DataCostT>
00322 void GCoptimization::setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00323 {
00324 DataCostT* dc = (DataCostT*)m_datacostFn;
00325 for ( SiteID i = 0; i < size; ++i )
00326 addterm1_checked(e,i,dc->compute(activeSites[i],alpha_label),m_labelingDataCosts[activeSites[i]]);
00327 }
00328
00329
00330
00331 template <typename SmoothCostT>
00332 void GCoptimization::setupSmoothCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
00333 {
00334 SiteID i,nSite,site,n,nNum,*nPointer;
00335 EnergyTermType *weights;
00336 SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn;
00337
00338 for ( i = size - 1; i >= 0; i-- )
00339 {
00340 site = activeSites[i];
00341 giveNeighborInfo(site,&nNum,&nPointer,&weights);
00342 for ( n = 0; n < nNum; n++ )
00343 {
00344 nSite = nPointer[n];
00345 if ( m_lookupSiteVar[nSite] == -1 )
00346 addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00347 sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]);
00348 else if ( nSite < site )
00349 {
00350 addterm2_checked(e,i,m_lookupSiteVar[nSite],
00351 sc->compute(site,nSite,alpha_label,alpha_label),
00352 sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00353 sc->compute(site,nSite,m_labeling[site],alpha_label),
00354 sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]);
00355 }
00356 }
00357 }
00358 }
00359
00360
00361
00362 template <typename DataCostT>
00363 void GCoptimization::setupDataCostsSwap(SiteID size, LabelID alpha_label, LabelID beta_label,
00364 EnergyT *e,SiteID *activeSites )
00365 {
00366 DataCostT* dc = (DataCostT*)m_datacostFn;
00367 for ( SiteID i = 0; i < size; i++ )
00368 {
00369 e->add_term1(i,dc->compute(activeSites[i],alpha_label),
00370 dc->compute(activeSites[i],beta_label) );
00371 }
00372 }
00373
00374
00375
00376 template <typename SmoothCostT>
00377 void GCoptimization::setupSmoothCostsSwap(SiteID size, LabelID alpha_label,LabelID beta_label,
00378 EnergyT *e,SiteID *activeSites )
00379 {
00380 SiteID i,nSite,site,n,nNum,*nPointer;
00381 EnergyTermType *weights;
00382 SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn;
00383
00384 for ( i = size - 1; i >= 0; i-- )
00385 {
00386 site = activeSites[i];
00387 giveNeighborInfo(site,&nNum,&nPointer,&weights);
00388 for ( n = 0; n < nNum; n++ )
00389 {
00390 nSite = nPointer[n];
00391 if ( m_lookupSiteVar[nSite] == -1 )
00392 addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]),
00393 sc->compute(site,nSite,beta_label, m_labeling[nSite]),weights[n]);
00394 else if ( nSite < site )
00395 {
00396 addterm2_checked(e,i,m_lookupSiteVar[nSite],
00397 sc->compute(site,nSite,alpha_label,alpha_label),
00398 sc->compute(site,nSite,alpha_label,beta_label),
00399 sc->compute(site,nSite,beta_label,alpha_label),
00400 sc->compute(site,nSite,beta_label,beta_label),weights[n]);
00401 }
00402 }
00403 }
00404 }
00405
00406
00407
00408 template <typename DataCostT>
00409 void GCoptimization::applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label)
00410 {
00411 DataCostT* dc = (DataCostT*)m_datacostFn;
00412 for ( SiteID i = 0; i < size; i++ )
00413 {
00414 if ( e->get_var(i) == 0 )
00415 {
00416 SiteID site = activeSites[i];
00417 LabelID prev = m_labeling[site];
00418 m_labeling[site] = alpha_label;
00419 m_labelCounts[alpha_label]++;
00420 m_labelCounts[prev]--;
00421 m_labelingDataCosts[site] = dc->compute(site,alpha_label);
00422 }
00423 }
00424 m_labelingInfoDirty = true;
00425 updateLabelingInfo(false,true,false);
00426 }
00427
00428
00429
00430 template <typename DataCostT>
00431 void GCoptimization::updateLabelingDataCosts()
00432 {
00433 DataCostT* dc = (DataCostT*)m_datacostFn;
00434 for (int i = 0; i < m_num_sites; ++i)
00435 m_labelingDataCosts[i] = dc->compute(i,m_labeling[i]);
00436 }
00437
00438
00439
00440 template <typename DataCostT>
00441 bool GCoptimization::solveSpecialCases(EnergyType& energy)
00442 {
00443 finalizeNeighbors();
00444
00445 DataCostT* dc = (DataCostT*)m_datacostFn;
00446 bool sc = m_numNeighborsTotal != 0;
00447 bool lc = m_labelcostsAll != 0;
00448
00449 if ( !dc && !sc && !lc )
00450 {
00451 energy = 0;
00452 return true;
00453 }
00454
00455 if ( dc && !sc && !lc ) {
00456
00457 energy = 0;
00458 for ( SiteID i = 0; i < m_num_sites; ++i ) {
00459 LabelID minCostLabel = 0;
00460 EnergyTermType minCost = dc->compute(i, 0);
00461 for ( LabelID l = 1; l < m_num_labels; ++l ) {
00462 EnergyTermType lcost = dc->compute(i, l);
00463 if ( lcost < minCost ) {
00464 minCostLabel = l;
00465 minCost = lcost;
00466 }
00467 }
00468 if ( minCostLabel > GCO_MAX_ENERGYTERM )
00469 handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00470 m_labeling[i] = minCostLabel;
00471 energy += minCost;
00472 }
00473 m_labelingInfoDirty = true;
00474 updateLabelingInfo();
00475 return true;
00476 }
00477
00478 if ( !dc && !sc && lc ) {
00479
00480 LabelID minLabel = 0;
00481 EnergyType minLabelCost = GCO_MAX_ENERGYTERM*(EnergyType)m_num_labels;
00482 for ( LabelID l = 0; l < m_num_labels; ++l ) {
00483 EnergyType lcsum = 0;
00484 for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00485 lcsum += lci->node->cost;
00486 if ( lcsum < minLabelCost ) {
00487 minLabel = l;
00488 minLabelCost = lcsum;
00489 }
00490 }
00491 for ( SiteID i = 0; i < m_num_sites; ++i )
00492 m_labeling[i] = minLabel;
00493 energy = minLabelCost;
00494 m_labelingInfoDirty = true;
00495 updateLabelingInfo();
00496 return true;
00497 }
00498
00499 if ( dc && !sc && lc ) {
00500 LabelCost* lc;
00501 for ( lc = m_labelcostsAll; lc; lc = lc->next )
00502 if ( lc->numLabels > 1)
00503 break;
00504 if ( !lc ) {
00505
00506 energy = solveGreedy<DataCostT>();
00507 return true;
00508 }
00509 }
00510
00511
00512 return false;
00513 }
00514
00515 template <>
00516 class GCoptimization::GreedyIter<GCoptimization::DataCostFnSparse> {
00517 public:
00518 GreedyIter(DataCostFnSparse& dc, SiteID)
00519 : m_dc(dc), m_label(0), m_labelend(0)
00520 { }
00521
00522 OLGA_INLINE void start(const LabelID* labels, LabelID labelCount=1)
00523 {
00524 m_label = labels;
00525 m_labelend = labels + labelCount;
00526 if (labelCount > 0) {
00527 m_site = m_dc.begin(*labels);
00528 m_siteend = m_dc.end(*labels);
00529 while (m_site == m_siteend) {
00530 if (++m_label == m_labelend)
00531 break;
00532 m_site = m_dc.begin(*m_label);
00533 m_siteend = m_dc.end(*m_label);
00534 }
00535 }
00536 }
00537 OLGA_INLINE SiteID site() const { return m_site.site(); }
00538 OLGA_INLINE SiteID label() const { return *m_label; }
00539 OLGA_INLINE bool done() const { return m_label >= m_labelend; }
00540 OLGA_INLINE GreedyIter& operator++()
00541 {
00542
00543
00544 if (++m_site == m_siteend) {
00545 while (++m_label < m_labelend) {
00546 m_site = m_dc.begin(*m_label);
00547 m_siteend = m_dc.end(*m_label);
00548 if (m_site != m_siteend)
00549 break;
00550 }
00551 }
00552 return *this;
00553 }
00554 OLGA_INLINE EnergyTermType compute() const { return m_site.cost(); }
00555 OLGA_INLINE SiteID feasibleSites() const { return (SiteID)(m_siteend - m_site); }
00556
00557 private:
00558 DataCostFnSparse::iterator m_site;
00559 DataCostFnSparse::iterator m_siteend;
00560 DataCostFnSparse& m_dc;
00561 const LabelID* m_label;
00562 const LabelID* m_labelend;
00563 };
00564
00565 template <typename DataCostT>
00566 GCoptimization::EnergyType GCoptimization::solveGreedy()
00567 {
00568 printStatus1("starting greedy algorithm (1 cycle only)");
00569 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
00570
00571 EnergyType estart = compute_energy();
00572 EnergyType efinal = 0;
00573 LabelID* oldLabeling = m_labeling;
00574 m_labeling = new LabelID[m_num_sites];
00575 EnergyType* e = new EnergyType[m_num_labels];
00576 LabelID* order = new LabelID[m_num_labels];
00577
00578 try {
00579 gcoclock_t ticks0all = gcoclock();
00580 gcoclock_t ticks0 = gcoclock();
00581
00582
00583 for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next)
00584 lc->active = false;
00585
00586 DataCostT* dc = (DataCostT*)m_datacostFn;
00587 GreedyIter<DataCostT> iter(*dc,m_num_sites);
00588 LabelID alpha = 0;
00589
00590
00591
00592 for ( LabelID l = 0; l < m_num_labels; ++l ) {
00593 e[l] = 0;
00594 for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00595 e[l] += lci->node->cost;
00596 iter.start(&l);
00597 e[l] += (EnergyType)(m_num_sites - iter.feasibleSites()) * GCO_MAX_ENERGYTERM;
00598 for (; !iter.done(); ++iter) {
00599 EnergyTermType dataCost = iter.compute();
00600 if ( dataCost > GCO_MAX_ENERGYTERM )
00601 handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00602 e[l] += dataCost;
00603 if ( e[l] > e[alpha] )
00604 break;
00605 }
00606 if ( e[l] < e[alpha] )
00607 alpha = l;
00608 }
00609 for ( SiteID i = 0; i < m_num_sites; ++i ) {
00610 m_labeling[i] = alpha;
00611 m_labelingDataCosts[i] = dc->compute(i,alpha);
00612 }
00613 for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next )
00614 lci->node->active = true;
00615
00616
00617 for ( LabelID l = 0; l < m_num_labels; ++l )
00618 order[l] = l;
00619 order[alpha] = 0;
00620 order[0] = alpha;
00621
00622 printStatus2(alpha,-1,m_num_sites,ticks0);
00623
00624
00625 for ( LabelID alpha_count = 1; alpha_count <= m_num_labels; ++alpha_count) {
00626 checkInterrupt();
00627 ticks0 = gcoclock();
00628
00629
00630 LabelID alpha_prev = alpha;
00631 for ( LabelID li = alpha_count; li < m_num_labels; ++li ) {
00632 LabelID l = order[li];
00633 e[l] = e[alpha_prev];
00634 for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
00635 if ( !lci->node->active )
00636 e[l] += lci->node->cost;
00637 }
00638
00639
00640 for ( iter.start(&order[alpha_count],m_num_labels-alpha_count); !iter.done(); ++iter ) {
00641 EnergyTermType dc_l = iter.compute();
00642 EnergyTermType dc_i = m_labelingDataCosts[iter.site()];
00643 EnergyTermType delta_i = dc_l - dc_i;
00644 if ( delta_i < 0 )
00645 e[iter.label()] += delta_i;
00646 }
00647
00648
00649 LabelID alpha_index = alpha_count-1;
00650 for ( LabelID li = alpha_count; li < m_num_labels; ++li ) {
00651 LabelID l = order[li];
00652 if ( e[l] < e[alpha] ) {
00653 alpha = l;
00654 alpha_index = li;
00655 }
00656 }
00657
00658 if ( alpha == alpha_prev )
00659 break;
00660
00661
00662 LabelID temp = order[alpha_count];
00663 order[alpha_count] = order[alpha_index];
00664 order[alpha_index] = temp;
00665
00666
00667 iter.start(&alpha);
00668 SiteID size = iter.feasibleSites();
00669 for ( ; !iter.done(); ++iter ) {
00670 EnergyTermType dc_l = iter.compute();
00671 EnergyTermType dc_i = m_labelingDataCosts[iter.site()];
00672 EnergyTermType delta_i = dc_l - dc_i;
00673 if ( delta_i < 0 ) {
00674 m_labeling[iter.site()] = alpha;
00675 m_labelingDataCosts[iter.site()] = dc_l;
00676 }
00677 }
00678 for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next )
00679 lci->node->active = true;
00680 printStatus2(alpha,-1,size,ticks0);
00681 }
00682
00683 efinal = e[alpha];
00684 if ( efinal < estart ) {
00685
00686 delete [] oldLabeling;
00687 m_labelingInfoDirty = true;
00688 updateLabelingInfo(true,false,false);
00689 printStatus1(1,false,ticks0all);
00690 } else {
00691
00692 efinal = estart;
00693 delete [] m_labeling;
00694 m_labeling = oldLabeling;
00695 m_labelingInfoDirty = true;
00696 updateLabelingInfo();
00697 printStatus1(1,false,ticks0all);
00698 }
00699
00700 delete [] order;
00701 delete [] e;
00702 } catch (...) {
00703 delete [] order;
00704 delete [] e;
00705 throw;
00706 }
00707 return efinal;
00708 }
00709
00710
00711
00712 void GCoptimization::setDataCost(DataCostFn fn) {
00713 specializeDataCostFunctor(DataCostFnFromFunction(fn));
00714 m_labelingInfoDirty = true;
00715 }
00716
00717
00718
00719
00720 void GCoptimization::setDataCost(DataCostFnExtra fn, void *extraData) {
00721 specializeDataCostFunctor(DataCostFnFromFunctionExtra(fn, extraData));
00722 m_labelingInfoDirty = true;
00723 }
00724
00725
00726
00727 void GCoptimization::setDataCost(EnergyTermType *dataArray) {
00728 specializeDataCostFunctor(DataCostFnFromArray(dataArray, m_num_labels));
00729 m_labelingInfoDirty = true;
00730 }
00731
00732
00733
00734 void GCoptimization::setDataCost(SiteID s, LabelID l, EnergyTermType e) {
00735 if ( !m_datacostIndividual )
00736 {
00737 EnergyTermType* table = new EnergyTermType[m_num_sites*m_num_labels];
00738 memset(table, 0, m_num_sites*m_num_labels*sizeof(EnergyTermType));
00739 specializeDataCostFunctor(DataCostFnFromArray(table, m_num_labels));
00740 m_datacostIndividual = table;
00741 m_labelingInfoDirty = true;
00742 }
00743 m_datacostIndividual[s*m_num_labels + l] = e;
00744 if ( m_labeling[s] == l )
00745 m_labelingInfoDirty = true;
00746 }
00747
00748
00749
00750 void GCoptimization::setDataCostFunctor(DataCostFunctor* f) {
00751 if ( m_datacostFnDelete )
00752 m_datacostFnDelete(m_datacostFn);
00753 if ( m_datacostIndividual )
00754 {
00755 delete [] m_datacostIndividual;
00756 m_datacostIndividual = 0;
00757 }
00758 m_datacostFn = f;
00759 m_datacostFnDelete = 0;
00760 m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion<DataCostFunctor>;
00761 m_setupDataCostsExpansion = &GCoptimization::setupDataCostsExpansion<DataCostFunctor>;
00762 m_setupDataCostsSwap = &GCoptimization::setupDataCostsSwap<DataCostFunctor>;
00763 m_applyNewLabeling = &GCoptimization::applyNewLabeling<DataCostFunctor>;
00764 m_updateLabelingDataCosts = &GCoptimization::updateLabelingDataCosts<DataCostFunctor>;
00765 m_solveSpecialCases = &GCoptimization::solveSpecialCases<DataCostFunctor>;
00766 m_labelingInfoDirty = true;
00767 }
00768
00769
00770
00771 void GCoptimization::setDataCost(LabelID l, SparseDataCost *costs, SiteID count)
00772 {
00773 if ( !m_datacostFn )
00774 specializeDataCostFunctor(DataCostFnSparse(numSites(),numLabels()));
00775 else if ( m_queryActiveSitesExpansion != (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion<DataCostFnSparse> )
00776 handleError("Cannot apply sparse data costs after dense data costs have been used.");
00777 m_labelingInfoDirty = true;
00778 DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn;
00779 dc->set(l,costs,count);
00780 }
00781
00782
00783
00784 void GCoptimization::setSmoothCost(SmoothCostFn fn) {
00785 specializeSmoothCostFunctor(SmoothCostFnFromFunction(fn));
00786 }
00787
00788
00789
00790 void GCoptimization::setSmoothCost(SmoothCostFnExtra fn, void* extraData) {
00791 specializeSmoothCostFunctor(SmoothCostFnFromFunctionExtra(fn, extraData));
00792 }
00793
00794
00795
00796 void GCoptimization::setSmoothCost(EnergyTermType *smoothArray) {
00797 specializeSmoothCostFunctor(SmoothCostFnFromArray(smoothArray, m_num_labels));
00798 }
00799
00800
00801
00802 void GCoptimization::setSmoothCost(LabelID l1, LabelID l2, EnergyTermType e){
00803 if ( !m_smoothcostIndividual )
00804 {
00805 EnergyTermType* table = new EnergyTermType[m_num_labels*m_num_labels];
00806 memset(table, 0, m_num_labels*m_num_labels*sizeof(EnergyTermType));
00807 specializeSmoothCostFunctor(SmoothCostFnFromArray(table, m_num_labels));
00808 m_smoothcostIndividual = table;
00809 }
00810 m_smoothcostIndividual[l1*m_num_labels + l2] = e;
00811 }
00812
00813
00814
00815 void GCoptimization::setSmoothCostFunctor(SmoothCostFunctor* f) {
00816 if ( m_smoothcostFnDelete )
00817 m_smoothcostFnDelete(m_smoothcostFn);
00818 if ( m_smoothcostIndividual )
00819 {
00820 delete [] m_smoothcostIndividual;
00821 m_smoothcostIndividual = 0;
00822 }
00823 m_smoothcostFn = f;
00824 m_smoothcostFnDelete = 0;
00825 m_giveSmoothEnergyInternal = &GCoptimization::giveSmoothEnergyInternal<SmoothCostFunctor>;
00826 m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion<SmoothCostFunctor>;
00827 m_setupSmoothCostsSwap = &GCoptimization::setupSmoothCostsSwap<SmoothCostFunctor>;
00828 }
00829
00830
00831
00832 void GCoptimization::setLabelCost(EnergyTermType cost)
00833 {
00834 EnergyTermType* lc = new EnergyTermType[m_num_labels];
00835 for ( LabelID i = 0; i < m_num_labels; ++i )
00836 lc[i] = cost;
00837 setLabelCost(lc);
00838 delete [] lc;
00839 }
00840
00841
00842
00843 void GCoptimization::setLabelCost(EnergyTermType *costArray)
00844 {
00845 for ( LabelID i = 0; i < m_num_labels; ++i )
00846 setLabelSubsetCost(&i, 1, costArray[i]);
00847 }
00848
00849
00850
00851 void GCoptimization::setLabelSubsetCost(LabelID* labels, LabelID numLabels, EnergyTermType cost)
00852 {
00853 if ( cost < 0 )
00854 handleError("Label costs must be non-negative.");
00855 if ( cost > GCO_MAX_ENERGYTERM )
00856 handleError("Label cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow.");
00857 for ( LabelID i = 0; i < numLabels; ++i)
00858 if ( labels[i] < 0 || labels[i] >= m_num_labels )
00859 handleError("Invalid label id was found in label subset list.");
00860
00861 if ( !m_labelcostsByLabel ) {
00862 m_labelcostsByLabel = new LabelCostIter*[m_num_labels];
00863 memset(m_labelcostsByLabel, 0, m_num_labels*sizeof(void*));
00864 }
00865
00866
00867 for ( LabelCostIter* lci = m_labelcostsByLabel[labels[0]]; lci; lci = lci->next ) {
00868 if ( numLabels == lci->node->numLabels ) {
00869 if ( !memcmp(labels, lci->node->labels, numLabels*sizeof(LabelID)) ) {
00870
00871 lci->node->cost = cost;
00872 return;
00873 }
00874 }
00875 }
00876
00877 if (cost == 0)
00878 return;
00879
00880
00881 m_labelcostCount++;
00882 LabelCost* lc = new LabelCost;
00883 lc->cost = cost;
00884 lc->active = false;
00885 lc->aux = -1;
00886 lc->numLabels = numLabels;
00887 lc->labels = new LabelID[numLabels];
00888 memcpy(lc->labels, labels, numLabels*sizeof(LabelID));
00889 slist_prepend(m_labelcostsAll, lc);
00890 for ( LabelID i = 0; i < numLabels; ++i ) {
00891 LabelCostIter* lci = new LabelCostIter;
00892 lci->node = lc;
00893 slist_prepend(m_labelcostsByLabel[labels[i]], lci);
00894 }
00895 }
00896
00897
00898
00899 void GCoptimization::whatLabel(SiteID start, SiteID count, LabelID* labeling)
00900 {
00901 assert(start >= 0 && start+count <= m_num_sites);
00902 memcpy(labeling, m_labeling+start, count*sizeof(LabelID));
00903 }
00904
00905
00906
00907 GCoptimization::EnergyType GCoptimization::giveSmoothEnergy()
00908 {
00909 finalizeNeighbors();
00910 if ( m_giveSmoothEnergyInternal )
00911 return( (this->*m_giveSmoothEnergyInternal)());
00912 return 0;
00913 }
00914
00915
00916
00917 GCoptimization::EnergyType GCoptimization::giveDataEnergy()
00918 {
00919 updateLabelingInfo();
00920 EnergyType energy = 0;
00921 for ( SiteID i = 0; i < m_num_sites; i++ )
00922 energy += m_labelingDataCosts[i];
00923 return energy;
00924 }
00925
00926 GCoptimization::EnergyType GCoptimization::giveLabelEnergy()
00927 {
00928 updateLabelingInfo();
00929 EnergyType energy = 0;
00930 for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next)
00931 if ( lc->active )
00932 energy += lc->cost;
00933 return energy;
00934 }
00935
00936
00937
00938 GCoptimization::EnergyType GCoptimization::compute_energy()
00939 {
00940 return giveDataEnergy() + giveSmoothEnergy() + giveLabelEnergy();
00941 }
00942
00943
00944
00945 void GCoptimization::permuteLabelTable()
00946 {
00947 if ( !m_random_label_order )
00948 return;
00949 for ( LabelID i = 0; i < m_num_labels; i++ )
00950 {
00951 LabelID j = i + (rand() % (m_num_labels-i));
00952 LabelID temp = m_labelTable[i];
00953 m_labelTable[i] = m_labelTable[j];
00954 m_labelTable[j] = temp;
00955 }
00956 }
00957
00958
00959
00960 GCoptimization::EnergyType GCoptimization::expansion(int max_num_iterations)
00961 {
00962 EnergyType new_energy, old_energy;
00963 if ( (this->*m_solveSpecialCases)(new_energy) )
00964 return new_energy;
00965
00966 permuteLabelTable();
00967 updateLabelingInfo();
00968
00969 try
00970 {
00971 if ( max_num_iterations == -1 )
00972 {
00973
00974 printStatus1("starting alpha-expansion w/ adaptive cycles");
00975 std::vector<LabelID> queueSizes;
00976 queueSizes.push_back(m_num_labels);
00977
00978 int cycle = 1;
00979 LabelID next = 0;
00980 do
00981 {
00982 gcoclock_t ticks0 = gcoclock();
00983 m_stepsThisCycle = 0;
00984
00985
00986 LabelID queueSize = queueSizes.back();
00987 LabelID start = next;
00988 m_stepsThisCycleTotal = queueSize - start;
00989 do
00990 {
00991 if ( !alpha_expansion(m_labelTable[next]) )
00992 std::swap(m_labelTable[next],m_labelTable[--queueSize]);
00993 else
00994 ++next;
00995 m_stepsThisCycle++;
00996 } while ( next < queueSize );
00997
00998 if ( next == start )
00999 {
01000 next = queueSizes.back();
01001 queueSizes.pop_back();
01002 }
01003 else if ( queueSize < queueSizes.back()/2 )
01004 {
01005 next = 0;
01006 queueSizes.push_back(queueSize);
01007 }
01008 else
01009 next = 0;
01010
01011 printStatus1(cycle++,false,ticks0);
01012 } while ( !queueSizes.empty() );
01013 new_energy = compute_energy();
01014 }
01015 else
01016 {
01017
01018 printStatus1("starting alpha-expansion w/ standard cycles");
01019 new_energy = compute_energy();
01020 old_energy = new_energy+1;
01021 for ( int cycle = 1; cycle <= max_num_iterations; cycle++ )
01022 {
01023 gcoclock_t ticks0 = gcoclock();
01024 old_energy = new_energy;
01025 new_energy = oneExpansionIteration();
01026 printStatus1(cycle,false,ticks0);
01027 if ( new_energy == old_energy )
01028 break;
01029 permuteLabelTable();
01030 }
01031 }
01032 }
01033 catch (...)
01034 {
01035 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01036 throw;
01037 }
01038 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01039 return new_energy;
01040 }
01041
01042
01043
01044 void GCoptimization::setLabelOrder(bool isRandom)
01045 {
01046 m_random_label_order = isRandom;
01047 for ( LabelID i = 0; i < m_num_labels; i++ )
01048 m_labelTable[i] = i;
01049 }
01050
01051
01052
01053 void GCoptimization::setLabelOrder(const LabelID* order, LabelID size)
01054 {
01055 if ( size > m_num_labels )
01056 handleError("setLabelOrder receieved too many labels");
01057 for ( LabelID i = 0; i < size; ++i )
01058 if ( order[i] < 0 || order[i] >= m_num_labels )
01059 handleError("Invalid label id in setLabelOrder");
01060 m_random_label_order = false;
01061 memcpy(m_labelTable,order,size*sizeof(LabelID));
01062 memset(m_labelTable+size,-1,(m_num_labels-size)*sizeof(LabelID));
01063 }
01064
01065
01066
01067 void GCoptimization::handleError(const char *message)
01068 {
01069 throw GCException(message);
01070 }
01071
01072
01073
01074 void GCoptimization::checkInterrupt()
01075 {
01076 if ( utIsInterruptPending() )
01077 throw GCException("Interrupted.");
01078 }
01079
01080
01081
01082
01083
01084
01085 GCoptimization::EnergyType GCoptimization::setupLabelCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites)
01086 {
01087 EnergyType alphaCostCorrection = 0;
01088 if ( !m_labelcostsAll )
01089 return alphaCostCorrection;
01090
01091 const SiteID DISABLE = -2;
01092 const SiteID UNINIT = -1;
01093 for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next )
01094 lc->aux = UNINIT;
01095
01096
01097
01098 if ( m_queryActiveSitesExpansion == (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion<DataCostFnSparse> )
01099 {
01100
01101
01102 memset(m_activeLabelCounts,0,m_num_labels*sizeof(SiteID));
01103 for ( SiteID i = 0; i < size; ++i )
01104 m_activeLabelCounts[m_labeling[activeSites[i]]]++;
01105
01106 for ( LabelID l = 0; l < m_num_labels; ++l )
01107 {
01108 if ( m_activeLabelCounts[l] != m_labelCounts[l] )
01109 {
01110 for ( LabelCostIter* lcj = m_labelcostsByLabel[l]; lcj; lcj = lcj->next )
01111 lcj->node->aux = DISABLE;
01112 }
01113 }
01114 }
01115 for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next )
01116 lci->node->aux = DISABLE;
01117
01118
01119
01120 if ( !m_labelCounts[alpha_label] )
01121 {
01122 for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next )
01123 if ( !lci->node->active )
01124 alphaCostCorrection += lci->node->cost;
01125 }
01126
01127
01128 for ( SiteID i = 0; i < size; i++ )
01129 {
01130 LabelID label_i = m_labeling[activeSites[i]];
01131 for ( LabelCostIter* lci = m_labelcostsByLabel[label_i]; lci; lci = lci->next )
01132 {
01133 LabelCost* lc = lci->node;
01134 if ( lc->aux == DISABLE )
01135 continue;
01136
01137
01138 if ( lc->aux == UNINIT )
01139 {
01140 lc->aux = e->add_variable();
01141 e->add_term1(lc->aux,0,lc->cost);
01142 m_beforeExpansionEnergy += lc->cost;
01143 }
01144 e->add_term2(i,lc->aux,0,0,lc->cost,0);
01145 }
01146 }
01147
01148 return alphaCostCorrection;
01149 }
01150
01151
01152 void GCoptimization::updateLabelingInfo(bool updateCounts, bool updateActive, bool updateCosts)
01153 {
01154 if ( !m_labelingInfoDirty )
01155 return;
01156
01157 m_labelingInfoDirty = false;
01158
01159 if ( m_labelcostsAll )
01160 {
01161 if ( updateCounts )
01162 {
01163 memset(m_labelCounts,0,m_num_labels*sizeof(SiteID));
01164 for ( SiteID i = 0; i < m_num_sites; ++i )
01165 m_labelCounts[m_labeling[i]]++;
01166 }
01167
01168 if ( updateActive )
01169 {
01170 for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next )
01171 lc->active = false;
01172
01173 EnergyType energy = 0;
01174 for ( LabelID l = 0; l < m_num_labels; ++l )
01175 if ( m_labelCounts[l] )
01176 for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next )
01177 lci->node->active = true;
01178 }
01179 }
01180
01181 if ( updateCosts )
01182 {
01183 if (m_updateLabelingDataCosts)
01184 (this->*m_updateLabelingDataCosts)();
01185 else
01186 memset(m_labelingDataCosts,0,m_num_sites*sizeof(EnergyTermType));
01187 }
01188 }
01189
01190
01191
01192
01193 bool GCoptimization::alpha_expansion(LabelID alpha_label)
01194 {
01195 if (alpha_label < 0)
01196 return false;
01197
01198 finalizeNeighbors();
01199 gcoclock_t ticks0 = gcoclock();
01200
01201 if ( m_stepsThisCycleTotal == 0 )
01202 m_labelingInfoDirty = true;
01203 updateLabelingInfo();
01204
01205
01206 SiteID size = 0;
01207 SiteID *activeSites = new SiteID[m_num_sites];
01208 EnergyType afterExpansionEnergy = 0;
01209 try
01210 {
01211
01212 if ( m_queryActiveSitesExpansion )
01213 size = (this->*m_queryActiveSitesExpansion)(alpha_label,activeSites);
01214 if ( size == 0 )
01215 {
01216 delete [] activeSites;
01217 printStatus2(alpha_label,-1,size,ticks0);
01218 return false;
01219 }
01220
01221
01222
01223 for ( SiteID i = 0; i < size; i++ )
01224 m_lookupSiteVar[activeSites[i]] = i;
01225
01226
01227
01228 EnergyT e(size+m_labelcostCount,
01229 m_numNeighborsTotal+(m_labelcostCount?size+m_labelcostCount : 0),
01230 (void(*)(char*))handleError);
01231 e.add_variable(size);
01232 m_beforeExpansionEnergy = 0;
01233 if ( m_setupDataCostsExpansion ) (this->*m_setupDataCostsExpansion )(size,alpha_label,&e,activeSites);
01234 if ( m_setupSmoothCostsExpansion ) (this->*m_setupSmoothCostsExpansion)(size,alpha_label,&e,activeSites);
01235 EnergyType alphaCorrection = setupLabelCostsExpansion(size,alpha_label,&e,activeSites);
01236 checkInterrupt();
01237 afterExpansionEnergy = e.minimize() + alphaCorrection;
01238 checkInterrupt();
01239
01240 if ( afterExpansionEnergy < m_beforeExpansionEnergy )
01241 (this->*m_applyNewLabeling)(&e,activeSites,size,alpha_label);
01242
01243 for ( SiteID i = 0; i < size; i++ )
01244 m_lookupSiteVar[activeSites[i]] = -1;
01245
01246 printStatus2(alpha_label,-1,size,ticks0);
01247 }
01248 catch (...)
01249 {
01250 delete [] activeSites;
01251 throw;
01252 }
01253 delete [] activeSites;
01254 return afterExpansionEnergy < m_beforeExpansionEnergy;
01255 }
01256
01257
01258
01259 GCoptimization::EnergyType GCoptimization::oneExpansionIteration()
01260 {
01261 permuteLabelTable();
01262 m_stepsThisCycle = 0;
01263 m_stepsThisCycleTotal = m_num_labels;
01264
01265
01266 for (LabelID next = 0; next < m_num_labels; next++, m_stepsThisCycle++ )
01267 alpha_expansion(m_labelTable[next]);
01268
01269 return compute_energy();
01270 }
01271
01272
01273
01274
01275
01276 GCoptimization::EnergyType GCoptimization::swap(int max_num_iterations)
01277 {
01278 EnergyType new_energy,old_energy;
01279 if ( (this->*m_solveSpecialCases)(new_energy) )
01280 return new_energy;
01281
01282 new_energy = compute_energy();
01283 old_energy = new_energy+1;
01284 printStatus1("starting alpha/beta-swap");
01285
01286 if ( max_num_iterations == -1 )
01287 max_num_iterations = 10000000;
01288 int curr_cycle = 1;
01289 m_stepsThisCycleTotal = (m_num_labels*(m_num_labels-1))/2;
01290 try
01291 {
01292 while ( old_energy > new_energy && curr_cycle <= max_num_iterations)
01293 {
01294 gcoclock_t ticks0 = gcoclock();
01295 old_energy = new_energy;
01296 new_energy = oneSwapIteration();
01297 printStatus1(curr_cycle,true,ticks0);
01298 curr_cycle++;
01299 }
01300 }
01301 catch (...)
01302 {
01303 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01304 throw;
01305 }
01306 m_stepsThisCycle = m_stepsThisCycleTotal = 0;
01307
01308 return(new_energy);
01309 }
01310
01311
01312
01313 GCoptimization::EnergyType GCoptimization::oneSwapIteration()
01314 {
01315 LabelID next,next1;
01316 permuteLabelTable();
01317 m_stepsThisCycle = 0;
01318
01319 for (next = 0; next < m_num_labels; next++ )
01320 for (next1 = m_num_labels - 1; next1 >= 0; next1-- )
01321 if ( m_labelTable[next] < m_labelTable[next1] )
01322 {
01323 alpha_beta_swap(m_labelTable[next],m_labelTable[next1]);
01324 m_stepsThisCycle++;
01325 }
01326
01327 return(compute_energy());
01328 }
01329
01330
01331
01332 void GCoptimization::alpha_beta_swap(LabelID alpha_label, LabelID beta_label)
01333 {
01334 assert( alpha_label >= 0 && alpha_label < m_num_labels && beta_label >= 0 && beta_label < m_num_labels);
01335 if ( m_labelcostsAll )
01336 handleError("Label costs only implemented for alpha-expansion.");
01337
01338 finalizeNeighbors();
01339 gcoclock_t ticks0 = gcoclock();
01340
01341
01342 SiteID size = 0;
01343 SiteID *activeSites = new SiteID[m_num_sites];
01344 try
01345 {
01346 for ( SiteID i = 0; i < m_num_sites; i++ )
01347 {
01348 if ( m_labeling[i] == alpha_label || m_labeling[i] == beta_label )
01349 {
01350 activeSites[size] = i;
01351 m_lookupSiteVar[i] = size;
01352 size++;
01353 }
01354 }
01355 if ( size == 0 )
01356 {
01357 delete [] activeSites;
01358 printStatus2(alpha_label,beta_label,size,ticks0);
01359 return;
01360 }
01361
01362
01363
01364 EnergyT e(size,m_numNeighborsTotal,(void(*)(char*))handleError);
01365 e.add_variable(size);
01366 if ( m_setupDataCostsSwap ) (this->*m_setupDataCostsSwap )(size,alpha_label,beta_label,&e,activeSites);
01367 if ( m_setupSmoothCostsSwap ) (this->*m_setupSmoothCostsSwap)(size,alpha_label,beta_label,&e,activeSites);
01368 checkInterrupt();
01369 e.minimize();
01370 checkInterrupt();
01371
01372
01373 for ( SiteID i = 0; i < size; i++ )
01374 {
01375 m_labeling[activeSites[i]] = (e.get_var(i) == 0) ? alpha_label : beta_label;
01376 m_lookupSiteVar[activeSites[i]] = -1;
01377 }
01378 m_labelingInfoDirty = true;
01379 }
01380 catch (...)
01381 {
01382 delete [] activeSites;
01383 throw;
01384 }
01385 delete [] activeSites;
01386
01387 printStatus2(alpha_label,beta_label,size,ticks0);
01388 }
01389
01390
01392
01394
01395 GCoptimizationGridGraph::GCoptimizationGridGraph(SiteID width, SiteID height,LabelID num_labels)
01396 :GCoptimization(width*height,num_labels)
01397 {
01398 assert( (width > 1) && (height > 1) && (num_labels > 1 ));
01399
01400 m_weightedGraph = 0;
01401 for (int i = 0; i < 4; i ++ ) m_unityWeights[i] = 1;
01402
01403 m_width = width;
01404 m_height = height;
01405
01406 m_numNeighbors = new SiteID[m_num_sites];
01407 m_neighbors = new SiteID[4*m_num_sites];
01408
01409 SiteID indexes[4] = {-1,1,-m_width,m_width};
01410
01411 SiteID indexesL[3] = {1,-m_width,m_width};
01412 SiteID indexesR[3] = {-1,-m_width,m_width};
01413 SiteID indexesU[3] = {1,-1,m_width};
01414 SiteID indexesD[3] = {1,-1,-m_width};
01415
01416 SiteID indexesUL[2] = {1,m_width};
01417 SiteID indexesUR[2] = {-1,m_width};
01418 SiteID indexesDL[2] = {1,-m_width};
01419 SiteID indexesDR[2] = {-1,-m_width};
01420
01421 setupNeighbData(1,m_height-1,1,m_width-1,4,indexes);
01422
01423 setupNeighbData(1,m_height-1,0,1,3,indexesL);
01424 setupNeighbData(1,m_height-1,m_width-1,m_width,3,indexesR);
01425 setupNeighbData(0,1,1,width-1,3,indexesU);
01426 setupNeighbData(m_height-1,m_height,1,m_width-1,3,indexesD);
01427
01428 setupNeighbData(0,1,0,1,2,indexesUL);
01429 setupNeighbData(0,1,m_width-1,m_width,2,indexesUR);
01430 setupNeighbData(m_height-1,m_height,0,1,2,indexesDL);
01431 setupNeighbData(m_height-1,m_height,m_width-1,m_width,2,indexesDR);
01432 }
01433
01434
01435
01436 GCoptimizationGridGraph::~GCoptimizationGridGraph()
01437 {
01438 delete [] m_numNeighbors;
01439 if ( m_neighbors )
01440 delete [] m_neighbors;
01441 if (m_weightedGraph) delete [] m_neighborsWeights;
01442 }
01443
01444
01445
01446
01447 void GCoptimizationGridGraph::setupNeighbData(SiteID startY,SiteID endY,SiteID startX,
01448 SiteID endX,SiteID maxInd,SiteID *indexes)
01449 {
01450 SiteID x,y,pix;
01451 SiteID n;
01452
01453 for ( y = startY; y < endY; y++ )
01454 for ( x = startX; x < endX; x++ )
01455 {
01456 pix = x+y*m_width;
01457 m_numNeighbors[pix] = maxInd;
01458 m_numNeighborsTotal += maxInd;
01459
01460 for (n = 0; n < maxInd; n++ )
01461 m_neighbors[pix*4+n] = pix+indexes[n];
01462 }
01463 }
01464
01465
01466
01467 void GCoptimizationGridGraph::finalizeNeighbors()
01468 {
01469 }
01470
01471
01472
01473 void GCoptimizationGridGraph::setSmoothCostVH(EnergyTermType *smoothArray, EnergyTermType *vCosts, EnergyTermType *hCosts)
01474 {
01475 setSmoothCost(smoothArray);
01476 m_weightedGraph = 1;
01477 computeNeighborWeights(vCosts,hCosts);
01478 }
01479
01480
01481
01482 void GCoptimizationGridGraph::giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights)
01483 {
01484 *numSites = m_numNeighbors[site];
01485 *neighbors = &m_neighbors[site*4];
01486
01487 if (m_weightedGraph) *weights = &m_neighborsWeights[site*4];
01488 else *weights = m_unityWeights;
01489 }
01490
01491
01492
01493 void GCoptimizationGridGraph::computeNeighborWeights(EnergyTermType *vCosts,EnergyTermType *hCosts)
01494 {
01495 SiteID i,n,nSite;
01496 GCoptimization::EnergyTermType weight;
01497
01498 m_neighborsWeights = new EnergyTermType[m_num_sites*4];
01499
01500 for ( i = 0; i < m_num_sites; i++ )
01501 {
01502 for ( n = 0; n < m_numNeighbors[i]; n++ )
01503 {
01504 nSite = m_neighbors[4*i+n];
01505 if ( i-nSite == 1 ) weight = hCosts[nSite];
01506 else if (i-nSite == -1 ) weight = hCosts[i];
01507 else if ( i-nSite == m_width ) weight = vCosts[nSite];
01508 else if (i-nSite == -m_width ) weight = vCosts[i];
01509
01510 m_neighborsWeights[i*4+n] = weight;
01511 }
01512 }
01513
01514 }
01516
01518
01519 GCoptimizationGeneralGraph::GCoptimizationGeneralGraph(SiteID num_sites,LabelID num_labels):GCoptimization(num_sites,num_labels)
01520 {
01521 assert( num_sites > 1 && num_labels > 1 );
01522
01523 m_neighborsIndexes = 0;
01524 m_neighborsWeights = 0;
01525 m_numNeighbors = 0;
01526 m_neighbors = 0;
01527
01528 m_needTodeleteNeighbors = true;
01529 m_needToFinishSettingNeighbors = true;
01530 }
01531
01532
01533
01534 GCoptimizationGeneralGraph::~GCoptimizationGeneralGraph()
01535 {
01536
01537 if ( m_neighbors )
01538 delete [] m_neighbors;
01539
01540 if ( m_numNeighbors && m_needTodeleteNeighbors )
01541 {
01542 for ( SiteID i = 0; i < m_num_sites; i++ )
01543 {
01544 if (m_numNeighbors[i] != 0 )
01545 {
01546 delete [] m_neighborsIndexes[i];
01547 delete [] m_neighborsWeights[i];
01548 }
01549 }
01550
01551 delete [] m_numNeighbors;
01552 delete [] m_neighborsIndexes;
01553 delete [] m_neighborsWeights;
01554 }
01555 }
01556
01557
01558
01559 void GCoptimizationGeneralGraph::finalizeNeighbors()
01560 {
01561 if ( !m_needToFinishSettingNeighbors )
01562 return;
01563 m_needToFinishSettingNeighbors = false;
01564
01565 Neighbor *tmp;
01566 SiteID i,site,count;
01567
01568 EnergyTermType *tempWeights = new EnergyTermType[m_num_sites];
01569 SiteID *tempIndexes = new SiteID[m_num_sites];
01570
01571 if ( !tempWeights || !tempIndexes ) handleError("Not enough memory");
01572
01573 m_numNeighbors = new SiteID[m_num_sites];
01574 m_neighborsIndexes = new SiteID*[m_num_sites];
01575 m_neighborsWeights = new EnergyTermType*[m_num_sites];
01576
01577 if ( !m_numNeighbors || !m_neighborsIndexes || !m_neighborsWeights ) handleError("Not enough memory.");
01578
01579 for ( site = 0; site < m_num_sites; site++ )
01580 {
01581 if ( m_neighbors && !m_neighbors[site].isEmpty() )
01582 {
01583 m_neighbors[site].setCursorFront();
01584 count = 0;
01585
01586 while ( m_neighbors[site].hasNext() )
01587 {
01588 tmp = (Neighbor *) (m_neighbors[site].next());
01589 tempIndexes[count] = tmp->to_node;
01590 tempWeights[count] = tmp->weight;
01591 delete tmp;
01592 count++;
01593 }
01594 m_numNeighbors[site] = count;
01595 m_numNeighborsTotal += count;
01596 m_neighborsIndexes[site] = new SiteID[count];
01597 m_neighborsWeights[site] = new EnergyTermType[count];
01598
01599 if ( !m_neighborsIndexes[site] || !m_neighborsWeights[site] ) handleError("Not enough memory.");
01600
01601 for ( i = 0; i < count; i++ )
01602 {
01603 m_neighborsIndexes[site][i] = tempIndexes[i];
01604 m_neighborsWeights[site][i] = tempWeights[i];
01605 }
01606 }
01607 else m_numNeighbors[site] = 0;
01608
01609 }
01610
01611 delete [] tempIndexes;
01612 delete [] tempWeights;
01613 if (m_neighbors) {
01614 delete [] m_neighbors;
01615 m_neighbors = 0;
01616 }
01617 }
01618
01619
01620 void GCoptimizationGeneralGraph::giveNeighborInfo(SiteID site, SiteID *numSites,
01621 SiteID **neighbors, EnergyTermType **weights)
01622 {
01623 if (m_numNeighbors) {
01624 (*numSites) = m_numNeighbors[site];
01625 (*neighbors) = m_neighborsIndexes[site];
01626 (*weights) = m_neighborsWeights[site];
01627 } else {
01628 *numSites = 0;
01629 *neighbors = 0;
01630 *weights = 0;
01631 }
01632 }
01633
01634
01635
01636
01637 void GCoptimizationGeneralGraph::setNeighbors(SiteID site1, SiteID site2, EnergyTermType weight)
01638 {
01639
01640 assert( site1 < m_num_sites && site1 >= 0 && site2 < m_num_sites && site2 >= 0);
01641 if ( m_needToFinishSettingNeighbors == false )
01642 handleError("Already set up neighborhood system.");
01643
01644 if ( !m_neighbors )
01645 {
01646 m_neighbors = (LinkedBlockList *) new LinkedBlockList[m_num_sites];
01647 if ( !m_neighbors ) handleError("Not enough memory.");
01648 }
01649
01650 Neighbor *temp1 = (Neighbor *) new Neighbor;
01651 Neighbor *temp2 = (Neighbor *) new Neighbor;
01652
01653 temp1->weight = weight;
01654 temp1->to_node = site2;
01655
01656 temp2->weight = weight;
01657 temp2->to_node = site1;
01658
01659 m_neighbors[site1].addFront(temp1);
01660 m_neighbors[site2].addFront(temp2);
01661
01662 }
01663
01664
01665 void GCoptimizationGeneralGraph::setAllNeighbors(SiteID *numNeighbors,SiteID **neighborsIndexes,
01666 EnergyTermType **neighborsWeights)
01667 {
01668 m_needTodeleteNeighbors = false;
01669 m_needToFinishSettingNeighbors = false;
01670 if ( m_numNeighborsTotal > 0 )
01671 handleError("Already set up neighborhood system.");
01672 m_numNeighbors = numNeighbors;
01673 m_numNeighborsTotal = 0;
01674 for (int site = 0; site < m_num_sites; site++ ) m_numNeighborsTotal += m_numNeighbors[site];
01675 m_neighborsIndexes = neighborsIndexes;
01676 m_neighborsWeights = neighborsWeights;
01677 }
01678
01679
01680
01681
01682
01683
01684 void GCoptimization::printStatus1(const char* extraMsg)
01685 {
01686 if ( m_verbosity < 1 )
01687 return;
01688 if ( extraMsg )
01689 printf("gco>> %s\n",extraMsg);
01690 printf("gco>> initial energy: \tE=%lld (E=%lld+%lld+%lld)\n",(long long)compute_energy(),
01691 (long long)giveDataEnergy(), (long long)giveSmoothEnergy(), (long long)giveLabelEnergy());
01692 flushnow();
01693 }
01694
01695 void GCoptimization::printStatus1(int cycle, bool isSwap, gcoclock_t ticks0)
01696 {
01697 if ( m_verbosity < 1 )
01698 return;
01699 gcoclock_t ticks1 = gcoclock();
01700 printf("gco>> after cycle %2d: \tE=%lld (E=%lld+%lld+%lld);",cycle,(long long)compute_energy(),
01701 (long long)giveDataEnergy(),(long long)giveSmoothEnergy(),(long long)giveLabelEnergy());
01702 if ( m_stepsThisCycleTotal > 0 )
01703 printf(isSwap ? " \t%d swaps(s);" : " \t%d expansions(s);",m_stepsThisCycleTotal);
01704 if ( m_verbosity == 1 )
01705 {
01706
01707
01708 int ms = (int)(1000*(ticks1 - ticks0) / GCO_CLOCKS_PER_SEC);
01709 printf(" \t%d ms",ms);
01710 }
01711 printf("\n");
01712 flushnow();
01713 }
01714
01715 void GCoptimization::printStatus2(int alpha, int beta, int numVars, gcoclock_t ticks0)
01716 {
01717 if ( m_verbosity < 2 )
01718 return;
01719 int microsec = (int)(1000000*(gcoclock() - ticks0) / GCO_CLOCKS_PER_SEC);
01720 if ( beta >= 0 )
01721 printf("gco>> after swap(%d,%d):",alpha+INDEX0,beta+INDEX0);
01722 else
01723 printf("gco>> after expansion(%d):",alpha+INDEX0);
01724 printf(" \tE=%lld (E=%lld+%lld+%lld);\t %lld vars;",
01725 (long long)compute_energy(),(long long)giveDataEnergy(),
01726 (long long)giveSmoothEnergy(),(long long)giveLabelEnergy(),(long long)numVars);
01727 if ( m_stepsThisCycleTotal > 0 )
01728 printf(" \t(%d of %d);",m_stepsThisCycle+1,m_stepsThisCycleTotal);
01729
01730 printf(microsec > 100 ? "\t %.2f ms\n" : "\t %.3f ms\n",(double)microsec/1000.0);
01731 flushnow();
01732 }
01733
01734
01735
01736
01737
01738
01739
01740
01741
01742 GCoptimization::DataCostFnSparse::DataCostFnSparse(SiteID num_sites, LabelID num_labels)
01743 : m_num_sites(num_sites)
01744 , m_num_labels(num_labels)
01745 , m_buckets_per_label((m_num_sites + cSitesPerBucket-1)/cSitesPerBucket)
01746 , m_buckets(0)
01747 {
01748 }
01749
01750 GCoptimization::DataCostFnSparse::DataCostFnSparse(const DataCostFnSparse& src)
01751 : m_num_sites(src.m_num_sites)
01752 , m_num_labels(src.m_num_labels)
01753 , m_buckets_per_label(src.m_buckets_per_label)
01754 , m_buckets(0)
01755 {
01756 assert(!src.m_buckets);
01757 }
01758
01759 GCoptimization::DataCostFnSparse::~DataCostFnSparse()
01760 {
01761 if (m_buckets) {
01762 for (LabelID l = 0; l < m_num_labels; ++l)
01763 if (m_buckets[l*m_buckets_per_label].begin)
01764 delete [] m_buckets[l*m_buckets_per_label].begin;
01765 delete [] m_buckets;
01766 }
01767 }
01768
01769 void GCoptimization::DataCostFnSparse::set(LabelID l, const SparseDataCost* costs, SiteID count)
01770 {
01771
01772
01773 if (!m_buckets) {
01774 m_buckets = new DataCostBucket[m_num_labels*m_buckets_per_label];
01775 memset(m_buckets, 0, m_num_labels*m_buckets_per_label*sizeof(DataCostBucket));
01776 }
01777
01778 DataCostBucket* b = &m_buckets[l*m_buckets_per_label];
01779 if (b->begin)
01780 delete [] b->begin;
01781 SparseDataCost* next = new SparseDataCost[count];
01782 memcpy(next,costs,count*sizeof(SparseDataCost));
01783
01784
01785
01786
01787
01788 const SparseDataCost* end = next+count;
01789 SiteID prev_site = -1;
01790 for (int i = 0; i < m_buckets_per_label; ++i) {
01791 b[i].begin = b[i].predict = next;
01792 SiteID end_site = (i+1)*cSitesPerBucket;
01793 while (next < end && next->site < end_site) {
01794 if (next->site < 0 || next->site >= m_num_sites)
01795 throw GCException("Invalid site id given for sparse data cost; must be within range.");
01796 if (next->site <= prev_site)
01797 throw GCException("Sparse data costs must be sorted in increasing order of SiteID");
01798 prev_site = next->site;
01799 ++next;
01800 }
01801 b[i].end = next;
01802 }
01803 }
01804
01805 GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::search(DataCostBucket& b, SiteID s)
01806 {
01807
01808
01809 const SparseDataCost* L = b.begin;
01810 const SparseDataCost* R = b.end-1;
01811 if ( R - L == m_num_sites )
01812 return b.begin[s].cost;
01813 do {
01814 const SparseDataCost* mid = (const SparseDataCost*)((((size_t)L+(size_t)R) >> 1) & cDataCostPtrMask);
01815 if (s < mid->site)
01816 R = mid-1;
01817 else if (mid->site < s)
01818 L = mid+1;
01819 else {
01820 b.predict = mid+1;
01821 return mid->cost;
01822 }
01823 } while (R - L > cLinearSearchSize);
01824
01825
01826
01827 do {
01828 if (L->site >= s) {
01829 if (L->site == s) {
01830 b.predict = L+1;
01831 return L->cost;
01832 }
01833 break;
01834 }
01835 } while (++L <= R);
01836 b.predict = L;
01837
01838 return GCO_MAX_ENERGYTERM;
01839 }
01840
01841 OLGA_INLINE GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::compute(SiteID s, LabelID l)
01842 {
01843 DataCostBucket& b = m_buckets[l*m_buckets_per_label + (s >> cLogSitesPerBucket)];
01844 if (b.begin == b.end)
01845 return GCO_MAX_ENERGYTERM;
01846 if (b.predict < b.end) {
01847
01848 if (b.predict->site == s)
01849 return (b.predict++)->cost;
01850
01851
01852
01853 if (b.predict->site > s && b.predict > b.begin && (b.predict-1)->site < s)
01854 return GCO_MAX_ENERGYTERM;
01855 }
01856 if ( (size_t)b.end - (size_t)b.begin == cSitesPerBucket*sizeof(SparseDataCost) )
01857 return b.begin[s-b.begin->site].cost;
01858
01859 return search(b,s);
01860 }
01861
01862 GCoptimization::SiteID GCoptimization::DataCostFnSparse::queryActiveSitesExpansion(LabelID alpha_label, const LabelID* labeling, SiteID* activeSites)
01863 {
01864 const SparseDataCost* next = m_buckets[alpha_label*m_buckets_per_label].begin;
01865 const SparseDataCost* end = m_buckets[alpha_label*m_buckets_per_label + m_buckets_per_label-1].end;
01866 SiteID count = 0;
01867 for (; next < end; ++next) {
01868 if ( labeling[next->site] != alpha_label )
01869 activeSites[count++] = next->site;
01870 }
01871 return count;
01872 }
01873