testCallRecord.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------1------------------------------------------- */
11 
22 #include <gtsam/base/Matrix.h>
23 #include <gtsam/base/Testable.h>
24 
26 #include <iostream>
27 
28 using namespace std;
29 using namespace gtsam;
30 
31 /* ************************************************************************* */
32 static const int Cols = 3; // Column count shared by every test matrix and by Record's CallRecordImplementor<Record, Cols> base.
33 
34 
37  return Eigen::Dynamic;
38  }
39  else return i;
40 }
41 struct CallConfig {
46  CallConfig(int rows, int cols):
47  compTimeRows(dynamicIfAboveMax(rows)),
48  compTimeCols(cols),
49  runTimeRows(rows),
50  runTimeCols(cols)
51  {
52  }
53  CallConfig(int compTimeRows, int compTimeCols, int runTimeRows, int runTimeCols):
54  compTimeRows(compTimeRows),
55  compTimeCols(compTimeCols),
56  runTimeRows(runTimeRows),
57  runTimeCols(runTimeCols)
58  {
59  }
60 
61  bool equals(const CallConfig & c, double /*tol*/) const {
62  return
63  this->compTimeRows == c.compTimeRows &&
64  this->compTimeCols == c.compTimeCols &&
65  this->runTimeRows == c.runTimeRows &&
66  this->runTimeCols == c.runTimeCols;
67  }
68  void print(const std::string & prefix) const {
69  std::cout << prefix << "{" << compTimeRows << ", " << compTimeCols << ", " << runTimeRows << ", " << runTimeCols << "}\n" ;
70  }
71 };
72 
74 namespace gtsam {
75 template<> struct traits<CallConfig> : public Testable<CallConfig> {}; // Makes CallConfig usable with assert_equal()/EXPECT in the test (presumably routed through CallConfig::equals/print — see Testable.h).
76 }
77 
78 struct Record: public internal::CallRecordImplementor<Record, Cols> { // Test double (CRTP over CallRecordImplementor): captures the matrix type that reverseAD2 dispatches down to reverseAD4.
79  Record() : cc(0, 0) {} // cc is built with the two-argument CallConfig(rows, cols) constructor, passing 0, 0.
80  ~Record() override { // Empty; `override` here requires the base class destructor to be virtual.
81  }
82  void print(const std::string& indent) const { // No-op: this test double prints nothing (indent is unused).
83  }
84  void startReverseAD4(internal::JacobianMap& jacobians) const { // No-op: the test drives reverseAD2, not startReverseAD4.
85  }
86 
87  mutable CallConfig cc; // Shape seen by the most recent reverseAD4 call; mutable because reverseAD4 is a const member.
88  private:
89  template<typename SomeMatrix>
90  void reverseAD4(const SomeMatrix & dFdT, internal::JacobianMap& jacobians) const { // Records dFdT's compile-time (type-level) and run-time dimensions into cc; jacobians is never read.
91  cc.compTimeRows = SomeMatrix::RowsAtCompileTime;
92  cc.compTimeCols = SomeMatrix::ColsAtCompileTime;
93  cc.runTimeRows = dFdT.rows();
94  cc.runTimeCols = dFdT.cols();
95  }
96 
97  template<typename Derived, int Rows>
98  friend struct internal::CallRecordImplementor; // Presumably lets the CRTP base call the private reverseAD4 — confirm in CallRecord.h.
99 };
100 
101 internal::JacobianMap* NJM_ptr = static_cast<internal::JacobianMap *>(nullptr); // Deliberately null: Record::reverseAD4 never reads its jacobians argument.
102 internal::JacobianMap & NJM = *NJM_ptr; // NOTE(review): binding a reference through a null pointer is undefined behaviour even if NJM is never accessed; assumes CallRecord::reverseAD2 only forwards it — confirm, and consider passing a real JacobianMap instance instead.
103 
104 /* ************************************************************************* */
106 
107 TEST(CallRecord, virtualReverseAdDispatching) {
108  Record record;
109  {
110  const int Rows = 1;
111  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>(), NJM);
112  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
113  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
114  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
115  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
117  }
118  {
119  const int Rows = 2;
120  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>(), NJM);
121  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
122  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
123  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
124  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
126  }
127  {
128  const int Rows = 3;
129  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>(), NJM);
130  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
131  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
132  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
133  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
135  }
136  {
137  const int Rows = 4;
138  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>(), NJM);
139  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
140  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
141  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
142  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
144  }
145  {
146  const int Rows = 5;
147  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>(), NJM);
148  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
149  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
150  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
151  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
153  }
154  {
155  const int Rows = 6;
156  record.CallRecord::reverseAD2(Eigen::Matrix<double, Rows, Cols>::Zero(), NJM);
157  EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
158  record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM);
159  EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
160  record.CallRecord::reverseAD2(Eigen::MatrixXd(Rows, Cols), NJM);
162  }
163 }
164 
165 /* ************************************************************************* */
166 int main() { // Runs all registered TESTs and returns TestRegistry::runAllTests' result as the process exit status.
167  TestResult tr;
168  return TestRegistry::runAllTests(tr);
169 }
170 /* ************************************************************************* */
CallConfig::runTimeCols
int runTimeCols
Definition: testCallRecord.cpp:45
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
CallConfig::equals
bool equals(const CallConfig &c, double) const
Definition: testCallRecord.cpp:61
Record::Record
Record()
Definition: testCallRecord.cpp:79
Testable.h
Concept check for values that can be used in unit tests.
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
CallConfig::print
void print(const std::string &prefix) const
Definition: testCallRecord.cpp:68
TestHarness.h
c
const CallConfig & c (parameter of CallConfig::equals)
Definition: testCallRecord.cpp:61
Matrix.h
typedef and functions to augment Eigen's MatrixXd
dynamicIfAboveMax
int dynamicIfAboveMax(int i)
Definition: testCallRecord.cpp:35
NJM_ptr
internal::JacobianMap * NJM_ptr
Definition: testCallRecord.cpp:101
TEST
TEST(CallRecord, virtualReverseAdDispatching)
Definition: testCallRecord.cpp:107
CallConfig::compTimeRows
int compTimeRows
Definition: testCallRecord.cpp:42
rows
int rows (parameter of CallConfig::CallConfig)
Definition: testCallRecord.cpp:46
Record::cc
CallConfig cc
Definition: testCallRecord.cpp:87
CallConfig::compTimeCols
int compTimeCols
Definition: testCallRecord.cpp:43
Record::~Record
~Record() override
Definition: testCallRecord.cpp:80
Record::reverseAD4
void reverseAD4(const SomeMatrix &dFdT, internal::JacobianMap &jacobians) const
Definition: testCallRecord.cpp:90
CallConfig::CallConfig
CallConfig(int rows, int cols)
Definition: testCallRecord.cpp:46
Record::print
void print(const std::string &indent) const
Definition: testCallRecord.cpp:82
CallConfig::CallConfig
CallConfig(int compTimeRows, int compTimeCols, int runTimeRows, int runTimeCols)
Definition: testCallRecord.cpp:53
Eigen::Dynamic
const int Dynamic
Definition: Constants.h:22
NJM
internal::JacobianMap & NJM
Definition: testCallRecord.cpp:102
DynRowMat
Eigen::Matrix< double, Eigen::Dynamic, Cols > DynRowMat
Definition: testCallRecord.cpp:105
Record::startReverseAD4
void startReverseAD4(internal::JacobianMap &jacobians) const
Definition: testCallRecord.cpp:84
Record
Definition: testCallRecord.cpp:78
main
int main()
Definition: testCallRecord.cpp:166
CallConfig
Definition: testCallRecord.cpp:41
TestResult
Definition: TestResult.h:26
gtsam
traits
Definition: SFMdata.h:40
gtsam::Testable
Definition: Testable.h:152
gtsam::traits
Definition: Group.h:36
std
Definition: BFloat16.h:88
gtsam::internal::CallRecordMaxVirtualStaticRows
const int CallRecordMaxVirtualStaticRows
Definition: CallRecord.h:133
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:41
Eigen::Matrix
The matrix class, also used for vectors and row-vectors.
Definition: 3rdparty/Eigen/Eigen/src/Core/Matrix.h:178
cols
int cols (parameter of CallConfig::CallConfig)
Definition: testCallRecord.cpp:46
i
int i (parameter of dynamicIfAboveMax)
Definition: testCallRecord.cpp:35
CallConfig::runTimeRows
int runTimeRows
Definition: testCallRecord.cpp:44
Cols
static const int Cols
Definition: testCallRecord.cpp:32
CallRecord.h
Internals for Expression.h, not for general consumption.


gtsam
Author(s):
autogenerated on Thu Dec 19 2024 04:06:09