26 #define DISABLE_TIMING 33 using namespace gtsam;
81 #ifndef DISABLE_TIMING 82 cout << s <<
": " << std::setw(3) <<
muls <<
" muls, " <<
83 std::setw(3) <<
adds <<
" adds, " << 1000 *
elapsed <<
" ms." 88 double mul(
const double&
a,
const double&
b) {
92 double add_(
const double&
a,
const double&
b) {
108 ADT note(E, 0.9, 0.1);
110 ADT cnotb = c * notb;
111 dot(cnotb,
"ADT-cnotb");
115 ADT acnotb = a * cnotb;
119 dot(acnotb,
"ADT-acnotb");
132 static size_t count = 0;
134 std::stringstream
ss;
135 ss <<
"CPT-" << std::setw(3) << std::setfill(
'0') << ++count <<
"-" << key.first;
136 string DOTfile = ss.str();
144 DiscreteKey A(0, 2),
S(1, 2),
T(2, 2),
L(3, 2),
B(4, 2),
E(5, 2),
X(6, 2),
159 elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
167 dot(joint,
"Asia-A");
169 dot(joint,
"Asia-AS");
171 dot(joint,
"Asia-AST");
173 dot(joint,
"Asia-ASTL");
175 dot(joint,
"Asia-ASTLB");
177 dot(joint,
"Asia-ASTLBE");
179 dot(joint,
"Asia-ASTLBEX");
181 dot(joint,
"Asia-ASTLBEXD");
185 elapsed = asiaJointNode->secs() + asiaJointNode->wall();
206 B(2, 2),
L(3, 2),
E(4, 2),
S(5, 2),
T(6, 2),
X(7, 2);
220 elapsed = infCPTsNode->secs() + infCPTsNode->wall();
228 dot(joint,
"Joint-Product-A");
230 dot(joint,
"Joint-Product-AS");
232 dot(joint,
"Joint-Product-AST");
234 dot(joint,
"Joint-Product-ASTL");
236 dot(joint,
"Joint-Product-ASTLB");
238 dot(joint,
"Joint-Product-ASTLBE");
240 dot(joint,
"Joint-Product-ASTLBEX");
242 dot(joint,
"Joint-Product-ASTLBEXD");
246 elapsed = asiaProdNode->secs() + asiaProdNode->wall();
252 ADT marginal = joint;
254 dot(marginal,
"Joint-Sum-ADBLEST");
256 dot(marginal,
"Joint-Sum-ADBLES");
258 dot(marginal,
"Joint-Sum-ADBLE");
260 dot(marginal,
"Joint-Sum-ADBL");
264 elapsed = asiaSumNode->secs() + asiaSumNode->wall();
271 DiscreteKey B(0, 2),
L(1, 2),
E(2, 2),
S(3, 2),
T(4, 2),
X(5, 2);
284 elapsed = createCPTsNode->secs() + createCPTsNode->wall();
298 dot(fg,
"FactorGraph");
302 elapsed = asiaFGNode->secs() + asiaFGNode->wall();
309 dot(fg,
"Marginalized-6X");
311 dot(fg,
"Marginalized-5T");
313 dot(fg,
"Marginalized-4S");
315 dot(fg,
"Marginalized-3E");
317 dot(fg,
"Marginalized-2L");
321 elapsed = margNode->secs() + margNode->wall();
331 dot(fE,
"Eliminate-01-fEX");
333 dot(fE,
"Eliminate-02-fE");
336 elapsed = elimXNode->secs() + elimXNode->wall();
345 dot(fLE,
"Eliminate-03-fLET");
347 dot(fLE,
"Eliminate-04-fLE");
350 elapsed = elimTNode->secs() + elimTNode->wall();
360 dot(fBL,
"Eliminate-05-fBLS");
362 dot(fBL,
"Eliminate-06-fBL");
365 elapsed = elimSNode->secs() + elimSNode->wall();
375 dot(fBL2,
"Eliminate-07-fBLE");
377 dot(fBL2,
"Eliminate-08-fBL2");
380 elapsed = elimENode->secs() + elimENode->wall();
389 dot(fB,
"Eliminate-09-fBL");
391 dot(fB,
"Eliminate-10-fB");
394 elapsed = elimLNode->secs() + elimLNode->wall();
440 x00[0] = 0, x00[1] = 0;
441 x01[0] = 0, x01[1] = 1;
442 x02[0] = 0, x02[1] = 2;
443 x10[0] = 1, x10[1] = 0;
444 x11[0] = 1, x11[1] = 1;
445 x12[0] = 1, x12[1] = 2;
447 ADT f1(
v0 & v1,
"0 1 2 3 4 5");
455 ADT f2(v1 &
v0,
"0 1 2 3 4 5");
464 vector<double>
table(5 * 4 * 3 * 2);
466 for (
double&
t : table)
t = x++;
481 ADT fDiscreteKey(
X & Y,
"0.2 0.5 0.3 0.6");
482 dot(fDiscreteKey,
"conversion-f1");
484 std::map<Key, Key> keyMap;
491 dot(fIndexKey,
"conversion-f2");
494 x00[5] = 0, x00[2] = 0;
495 x01[5] = 0, x01[2] = 1;
496 x10[5] = 1, x10[2] = 0;
497 x11[5] = 1, x11[2] = 1;
508 ADT f1(
A &
B & C,
"1 2 3 4 5 6 1 8 3 3 5 5");
509 dot(f1,
"elimination-f1");
513 ADT actualSum = f1.
sum(C);
514 ADT expectedSum(
A &
B,
"3 7 11 9 6 10");
518 ADT actual = f1 / actualSum;
519 const vector<double> cpt{
520 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11,
521 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
529 ADT expectedSum(
A, 21, 25);
533 ADT actual = f1 / actualSum;
534 const vector<double> cpt{
535 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21,
536 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25};
550 ADT expected_a_div_b(
A & B,
"4 2 8 4");
551 ADT expected_b_div_a(
A & B,
"0.25 0.5 0.125 0.25");
564 ADT anotb = a * notb;
567 x00[0] = 0, x00[1] = 0;
568 x01[0] = 0, x01[1] = 1;
569 x10[0] = 1, x10[1] = 0;
570 x11[0] = 1, x11[1] = 1;
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
const gtsam::Symbol key('X', 0)
AlgebraicDecisionTree sum(const L &label, size_t cardinality) const
#define tictoc_getNode(variable, label)
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
void printCounts(const string &s)
double add_(const double &a, const double &b)
Concept check for values that can be used in unit tests.
static int runAllTests(TestResult &result)
signatures for conditional densities
double mul(const double &a, const double &b)
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
EIGEN_DONT_INLINE Scalar zero()
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
AlgebraicDecisionTree< Key > ADT
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
const DiscreteKey & key() const
Algebraic Decision Trees.
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
DiscreteKeys discreteKeys() const
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
#define EXPECT(condition)
Eigen::Triplet< double > T
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
std::vector< double > Row
static std::stringstream ss
Matrix< Scalar, Dynamic, Dynamic > C
#define LONGS_EQUAL(expected, actual)
std::vector< double > cpt() const
specialized key for discrete variables
ADT create(const Signature &signature)
#define EXPECT_LONGS_EQUAL(expected, actual)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
std::pair< Key, size_t > DiscreteKey
double f3(double x1, double x2)
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
void dot(const T &f, const string &filename)