00001
00029 #include "isam/ChowLiuTree.h"
00030 #include "isam/util.h"
00031
00032 using namespace std;
00033 using namespace Eigen;
00034
00035 namespace isam {
00036
00037 MatrixXd matslice (const MatrixXd A, vector<int> ii, vector<int> jj) {
00038 MatrixXd B(ii.size(), jj.size());
00039 for (size_t i=0; i<ii.size(); i++) {
00040 for (size_t j=0; j<jj.size(); j++) {
00041 B(i,j) = A(ii[i], jj[j]);
00042 }
00043 }
00044 return B;
00045 }
00046
00047
00048 MatrixXd ChowLiuTreeInfo::marginal(int id) {
00049
00050
00051 vector<int> iia(0);
00052 vector<int> iib(0);
00053 int off = 0;
00054 for (size_t i=0; i<_nodes.size(); i++) {
00055 if ((int)i == id)
00056 for (int j=0; j<_nodes[i]->dim(); j++) iia.push_back(off+j);
00057 else
00058 for (int j=0; j<_nodes[i]->dim(); j++) iib.push_back(off+j);
00059 off += _nodes[i]->dim();
00060 }
00061
00062 if (iib.size() > 0) {
00063 MatrixXd Laa = matslice(_L, iia, iia);
00064 MatrixXd Lbb = matslice(_L, iib, iib);
00065 MatrixXd Lab = matslice(_L, iia, iib);
00066 MatrixXd Lbbinv = posdef_pinv(Lbb);
00067 return Laa - Lab * Lbbinv * Lab.transpose();
00068
00069 } else
00070 return _L;
00071 }
00072
00073 MatrixXd ChowLiuTreeInfo::joint(int ida, int idb) {
00074
00075 vector<int> iia(0);
00076 vector<int> iib(0);
00077 vector<int> iic(0);
00078 int off = 0;
00079 for (size_t i=0; i<_nodes.size(); i++) {
00080 if ((int)i == ida)
00081 for (int j=0; j<_nodes[i]->dim(); j++) iia.push_back(off+j);
00082 else if ((int)i == idb)
00083 for (int j=0; j<_nodes[i]->dim(); j++) iib.push_back(off+j);
00084 else
00085 for (int j=0; j<_nodes[i]->dim(); j++) iic.push_back(off+j);
00086 off += _nodes[i]->dim();
00087 }
00088 vector<int> iiab(0);
00089 iiab.insert(iiab.end(), iia.begin(), iia.end());
00090 iiab.insert(iiab.end(), iib.begin(), iib.end());
00091
00092 if (iic.size() > 0) {
00093 MatrixXd Labab = matslice(_L, iiab, iiab);
00094 MatrixXd Lcc = matslice(_L, iic, iic);
00095 MatrixXd Labc = matslice(_L, iiab, iic);
00096 MatrixXd Lccinv = posdef_pinv(Lcc);
00097 return Labab - Labc * Lccinv * Labc.transpose();
00098
00099 } else
00100 return matslice(_L, iiab, iiab);
00101 }
00102
00103 MatrixXd ChowLiuTreeInfo::conditional(int ida, int idb) {
00104
00105 MatrixXd Lj = joint(ida, idb);
00106 return Lj.block(0, 0, _nodes[ida]->dim(), _nodes[ida]->dim());
00107
00108 }
00109
00110 ChowLiuTree::ChowLiuTree (const Eigen::MatrixXd &L, const std::vector<Node *>& nodes)
00111 : _clt_info(L, nodes)
00112 {
00113
00114 if (nodes.size() == 1) {
00115
00116 ChowLiuTreeNode node;
00117 node.id = 0;
00118 node.pid = -1;
00119 node.marginal = _clt_info.marginal(0);
00120 node.conditional = node.marginal;
00121 node.joint = node.marginal;
00122 tree[node.id] = node;
00123
00124 } else {
00125
00126
00127 _calc_edges();
00128 _max_span_tree();
00129 tree.clear();
00130 _build_tree_rec(_edges.front().id1, -1);
00131
00132 }
00133 }
00134
00135 bool mi_sort(MI &first, MI &second) {
00136 return first.mi > second.mi;
00137 }
00138
00139 void
00140 ChowLiuTree::_calc_edges() {
00141
00142 int nn = _clt_info.num_nodes();
00143
00144
00145 for (int i=0; i<nn; i++) {
00146 for (int j=(i+1); j<nn; j++) {
00147 MI mi_tmp (i,j, _calc_mi(i, j));
00148 _edges.push_back(mi_tmp);
00149 }
00150 }
00151 _edges.sort(mi_sort);
00152
00153 }
00154
00155 double ChowLiuTree::_calc_mi(int ida, int idb) {
00156
00157 MatrixXd L_agb = _clt_info.conditional (ida, idb);
00158 MatrixXd L_a = _clt_info.marginal (ida);
00159
00160
00161
00162
00163
00164
00165
00166 double ldL_agb = log ((L_agb + MatrixXd::Identity(L_agb.rows(), L_agb.cols())).determinant());
00167 double ldL_a = log ((L_a + MatrixXd::Identity(L_a.rows(), L_a.cols())).determinant());
00168 double mi = 0.5*(ldL_agb - ldL_a);
00169
00170 return mi;
00171 }
00172
00173 void ChowLiuTree::_max_span_tree() {
00174
00175
00176 map<int, int> groups;
00177 for (int i=0; i<_clt_info.num_nodes(); i++) {
00178 groups[i] = i;
00179 }
00180
00181 int group1, group2;
00182 list<MI>::iterator edge = _edges.begin();
00183 while(edge != _edges.end()) {
00184 if(groups[edge->id1] != groups[edge->id2]) {
00185 group1 = groups[edge->id1];
00186 group2 = groups[edge->id2];
00187
00188
00189 map<int, int>::iterator groupIt;
00190 for(groupIt = groups.begin(); groupIt != groups.end(); groupIt++)
00191 if(groupIt->second == group2) groupIt->second = group1;
00192
00193 edge++;
00194 } else {
00195 edge = _edges.erase(edge);
00196 }
00197 }
00198
00199 }
00200
00201 void ChowLiuTree::_build_tree_rec(int id, int pid) {
00202
00203 ChowLiuTreeNode new_node;
00204
00205 new_node.id = id;
00206 new_node.pid = pid;
00207 new_node.marginal = _clt_info.marginal(id);
00208 if (pid == -1) {
00209 new_node.conditional = new_node.marginal;
00210 new_node.joint = new_node.marginal;
00211 } else {
00212 new_node.conditional = _clt_info.conditional(id, pid);
00213 new_node.joint = _clt_info.joint(id, pid);
00214 }
00215
00216 vector<int> cids;
00217 list<MI>::iterator edge = _edges.begin();
00218 while(edge != _edges.end()) {
00219 if(edge->id1 == new_node.id) {
00220 cids.push_back(edge->id2);
00221 edge = _edges.erase(edge);
00222 continue;
00223 }
00224 if(edge->id2 == new_node.id) {
00225 cids.push_back(edge->id1);
00226 edge = _edges.erase(edge);
00227 continue;
00228 }
00229 edge++;
00230 }
00231 for(size_t i=0; i < cids.size(); i++) {
00232 new_node.cids.push_back(cids[i]);
00233 _build_tree_rec(cids[i], new_node.id);
00234 }
00235
00236 tree[new_node.id] = new_node;
00237 }
00238
00239 }