ChowLiuTree.cpp
Go to the documentation of this file.
00001 
00029 #include "isam/ChowLiuTree.h"
00030 #include "isam/util.h"
00031 
00032 using namespace std;
00033 using namespace Eigen;
00034 
00035 namespace isam {
00036 
00037 MatrixXd matslice (const MatrixXd A, vector<int> ii, vector<int> jj) {
00038   MatrixXd B(ii.size(), jj.size());
00039   for (size_t i=0; i<ii.size(); i++) {
00040     for (size_t j=0; j<jj.size(); j++) {
00041       B(i,j) = A(ii[i], jj[j]);
00042     }
00043   }
00044   return B;
00045 }
00046 
00047 
00048 MatrixXd ChowLiuTreeInfo::marginal(int id) {
00049 
00050   // get the indicies associated with the node
00051   vector<int> iia(0);
00052   vector<int> iib(0);
00053   int off = 0;
00054   for (size_t i=0; i<_nodes.size(); i++) {
00055     if ((int)i == id)
00056       for (int j=0; j<_nodes[i]->dim(); j++) iia.push_back(off+j);
00057     else
00058       for (int j=0; j<_nodes[i]->dim(); j++) iib.push_back(off+j);
00059     off += _nodes[i]->dim();
00060   }
00061 
00062   if (iib.size() > 0) {
00063       MatrixXd Laa = matslice(_L, iia, iia);
00064       MatrixXd Lbb = matslice(_L, iib, iib);
00065       MatrixXd Lab = matslice(_L, iia, iib);
00066       MatrixXd Lbbinv = posdef_pinv(Lbb);
00067       return Laa - Lab * Lbbinv * Lab.transpose();
00068       //return Laa - Lab * Lbb.inverse() * Lab.transpose();
00069   } else
00070       return _L;
00071 }
00072 
00073 MatrixXd ChowLiuTreeInfo::joint(int ida, int idb) {
00074 
00075   vector<int> iia(0);
00076   vector<int> iib(0);
00077   vector<int> iic(0);
00078   int off = 0;
00079   for (size_t i=0; i<_nodes.size(); i++) {
00080     if ((int)i == ida)
00081       for (int j=0; j<_nodes[i]->dim(); j++) iia.push_back(off+j);
00082     else if ((int)i == idb)
00083       for (int j=0; j<_nodes[i]->dim(); j++) iib.push_back(off+j);
00084     else
00085       for (int j=0; j<_nodes[i]->dim(); j++) iic.push_back(off+j);
00086     off += _nodes[i]->dim();
00087   }
00088   vector<int> iiab(0);
00089   iiab.insert(iiab.end(), iia.begin(), iia.end());
00090   iiab.insert(iiab.end(), iib.begin(), iib.end());
00091   
00092   if (iic.size() > 0) {
00093     MatrixXd Labab = matslice(_L, iiab, iiab);
00094     MatrixXd Lcc = matslice(_L, iic, iic);
00095     MatrixXd Labc = matslice(_L, iiab, iic);
00096     MatrixXd Lccinv = posdef_pinv(Lcc);
00097     return Labab - Labc * Lccinv * Labc.transpose();
00098     //return Labab - Labc * Lcc.inverse() * Labc.transpose();
00099   } else
00100       return matslice(_L, iiab, iiab);
00101 }
00102 
00103 MatrixXd ChowLiuTreeInfo::conditional(int ida, int idb) {
00104 
00105   MatrixXd Lj = joint(ida, idb);
00106   return Lj.block(0, 0, _nodes[ida]->dim(), _nodes[ida]->dim());
00107 
00108 }
00109 
00110 ChowLiuTree::ChowLiuTree (const Eigen::MatrixXd &L, const std::vector<Node *>& nodes)
00111   : _clt_info(L, nodes)
00112 {
00113   // make sure we have at least two nodes otherwise return a trival tree
00114   if (nodes.size() == 1) {
00115 
00116     ChowLiuTreeNode node;
00117     node.id = 0;
00118     node.pid = -1;
00119     node.marginal = _clt_info.marginal(0);
00120     node.conditional = node.marginal;
00121     node.joint = node.marginal;
00122     tree[node.id] = node;
00123 
00124   } else {
00125 
00126     //calculate the parent nodes based on maximising mutual information
00127     _calc_edges();
00128     _max_span_tree();
00129     tree.clear();
00130     _build_tree_rec(_edges.front().id1, -1);
00131 
00132   }
00133 }
00134 
00135 bool mi_sort(MI &first, MI &second) {
00136   return first.mi > second.mi;
00137 }
00138 
00139 void
00140 ChowLiuTree::_calc_edges()  {
00141 
00142   int nn = _clt_info.num_nodes();
00143   //double npairs = floor(pow((double)bb, 2.0) / 2) -  floor((double)nn/2);
00144 
00145   for (int i=0; i<nn; i++) {
00146     for (int j=(i+1); j<nn; j++) {
00147       MI mi_tmp (i,j, _calc_mi(i, j));
00148       _edges.push_back(mi_tmp);
00149     }
00150   }
00151   _edges.sort(mi_sort);
00152   
00153 }
00154 
00155 double ChowLiuTree::_calc_mi(int ida, int idb) {
00156 
00157   MatrixXd L_agb = _clt_info.conditional (ida, idb);
00158   MatrixXd L_a = _clt_info.marginal (ida);
00159 
00160   // use pdet
00161   //double ldL_agb = plogdet(L_agb);
00162   //double ldL_a = plogdet(L_a);
00163   //double mi = 0.5*(ldL_agb - ldL_a);
00164 
00165   // use normal det, must be pinned
00166   double ldL_agb = log ((L_agb +  MatrixXd::Identity(L_agb.rows(), L_agb.cols())).determinant());
00167   double ldL_a = log ((L_a +  MatrixXd::Identity(L_a.rows(), L_a.cols())).determinant());
00168   double mi = 0.5*(ldL_agb - ldL_a);
00169 
00170   return mi;
00171 }
00172 
00173 void ChowLiuTree::_max_span_tree() {
00174 
00175   // init groups: assign each id to a different group initially
00176   map<int, int> groups; // map <node index, group index>
00177   for (int i=0; i<_clt_info.num_nodes(); i++) {
00178       groups[i] = i;
00179   }
00180     
00181   int group1, group2;
00182   list<MI>::iterator edge = _edges.begin();
00183   while(edge != _edges.end()) {
00184     if(groups[edge->id1] != groups[edge->id2]) {
00185       group1 = groups[edge->id1];
00186       group2 = groups[edge->id2];
00187 
00188       // merge group2 into group 1
00189       map<int, int>::iterator groupIt;
00190       for(groupIt = groups.begin(); groupIt != groups.end(); groupIt++)
00191         if(groupIt->second == group2) groupIt->second = group1;
00192 
00193       edge++;
00194     } else {
00195       edge = _edges.erase(edge);
00196     }
00197   }
00198   
00199 }
00200 
00201 void ChowLiuTree::_build_tree_rec(int id, int pid) {
00202 
00203   ChowLiuTreeNode new_node;
00204 
00205   new_node.id = id;
00206   new_node.pid = pid;
00207   new_node.marginal = _clt_info.marginal(id);
00208   if (pid == -1) {
00209     new_node.conditional = new_node.marginal;
00210     new_node.joint = new_node.marginal;
00211   } else {
00212     new_node.conditional = _clt_info.conditional(id, pid);
00213     new_node.joint = _clt_info.joint(id, pid);
00214   }
00215 
00216   vector<int> cids; // vector of child ids
00217   list<MI>::iterator edge = _edges.begin();
00218   while(edge != _edges.end()) {
00219     if(edge->id1 == new_node.id) {
00220       cids.push_back(edge->id2);
00221       edge = _edges.erase(edge);
00222       continue;
00223     }
00224     if(edge->id2 == new_node.id) {
00225       cids.push_back(edge->id1);
00226       edge = _edges.erase(edge);
00227       continue;
00228     }
00229     edge++;
00230   }
00231   for(size_t i=0; i < cids.size(); i++) {
00232       new_node.cids.push_back(cids[i]);
00233       _build_tree_rec(cids[i], new_node.id);
00234   }
00235 
00236   tree[new_node.id] = new_node;
00237 }
00238 
00239 } //namespace isam


demo_rgbd
Author(s): Ji Zhang
autogenerated on Sun Oct 5 2014 23:25:11