Extending MatrixBase (and other classes)

In this section we will see how to add custom methods to MatrixBase. Since all expressions and matrix types inherit MatrixBase, adding a method to MatrixBase makes it immediately available to all expressions! A typical use case is making Eigen compatible with another API.

You certainly know that in C++ it is not possible to add methods to an existing class. So how is that possible? The trick is to include, inside the declaration of MatrixBase, a file named by the preprocessor token EIGEN_MATRIXBASE_PLUGIN:

class MatrixBase {
// ...
#ifdef EIGEN_MATRIXBASE_PLUGIN
#include EIGEN_MATRIXBASE_PLUGIN
#endif
};
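The same idea can be tried without Eigen. The sketch below is a single-file analogue, not Eigen code: the hypothetical macro VEC3_PLUGIN_MEMBERS expands directly to member declarations instead of naming a header to #include, and the hypothetical struct Vec3 plays the role of MatrixBase.

```cpp
#include <cmath>

// Hypothetical hook macro: a single-file stand-in for EIGEN_MATRIXBASE_PLUGIN.
// Eigen's macro names a header that is #included inside the class; this one
// expands directly to the extra member declarations. It must be defined before
// the class, just as the Eigen macro must be defined before any Eigen header.
#define VEC3_PLUGIN_MEMBERS \
  double length() const { return std::sqrt(x * x + y * y + z * z); }

// Hypothetical host class: it only provides the hook and knows nothing about
// the members a user may inject through it.
struct Vec3 {
  double x, y, z;

#ifdef VEC3_PLUGIN_MEMBERS
  VEC3_PLUGIN_MEMBERS
#endif
};
```

If the #define is removed, Vec3 still compiles but has no length() member, which mirrors how MatrixBase behaves when EIGEN_MATRIXBASE_PLUGIN is left undefined.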

Therefore, to extend MatrixBase with your own methods, you just have to create a file with your method declarations and define EIGEN_MATRIXBASE_PLUGIN before you include any Eigen header file.

You can extend many of the other classes used in Eigen by defining similarly named preprocessor symbols. For instance, define EIGEN_ARRAYBASE_PLUGIN if you want to extend the ArrayBase class. A full list of the classes that can be extended this way, with the corresponding preprocessor symbols, is given on the Preprocessor directives page.

Here is an example of an extension file for adding methods to MatrixBase:
MatrixBaseAddons.h

// Element access using Eigen's signed Index type (uint is not standard C++).
inline Scalar at(Index i, Index j) const { return this->operator()(i,j); }
inline Scalar& at(Index i, Index j) { return this->operator()(i,j); }
inline Scalar at(Index i) const { return this->operator[](i); }
inline Scalar& at(Index i) { return this->operator[](i); }
// Aliases for the Euclidean norm and its square.
inline RealScalar squaredLength() const { return squaredNorm(); }
inline RealScalar length() const { return norm(); }
// Reciprocal of the norm; a fast approximate inverse square root could be substituted here.
inline RealScalar invLength() const { return RealScalar(1) / norm(); }
// Squared and plain Euclidean distance to another matrix of the same size.
template<typename OtherDerived>
inline RealScalar squaredDistanceTo(const MatrixBase<OtherDerived>& other) const
{ return (derived() - other.derived()).squaredNorm(); }
template<typename OtherDerived>
inline RealScalar distanceTo(const MatrixBase<OtherDerived>& other) const
{ return numext::sqrt(derived().squaredDistanceTo(other)); }
// Rescale in place to length l; left unchanged if the current norm is (nearly) zero.
inline void scaleTo(RealScalar l) { RealScalar vl = norm(); if (vl>1e-9) derived() *= (l/vl); }
// Named aliases for transpose().
inline Transpose<Derived> transposed() {return this->transpose();}
inline const Transpose<const Derived> transposed() const {return this->transpose();}
// Index of the smallest / largest coefficient (vectors only).
inline Index minComponentId() const { Index i; this->minCoeff(&i); return i; }
inline Index maxComponentId() const { Index i; this->maxCoeff(&i); return i; }
// Coefficient-wise minimum / maximum with another matrix, stored in *this.
template<typename OtherDerived>
void makeFloor(const MatrixBase<OtherDerived>& other) { derived() = derived().cwiseMin(other.derived()); }
template<typename OtherDerived>
void makeCeil(const MatrixBase<OtherDerived>& other) { derived() = derived().cwiseMax(other.derived()); }
// matrix + scalar and scalar + matrix: add the scalar to every coefficient.
const CwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const ConstantReturnType>
operator+(const Scalar& scalar) const
{ return CwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const ConstantReturnType>(derived(), Constant(rows(),cols(),scalar)); }
// A friend has no implicit object, so the sizes must come from mat.
friend const CwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ConstantReturnType, const Derived>
operator+(const Scalar& scalar, const MatrixBase<Derived>& mat)
{ return CwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ConstantReturnType, const Derived>(Constant(mat.rows(),mat.cols(),scalar), mat.derived()); }
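The addons file refers to Scalar, RealScalar, derived(), norm() and other MatrixBase members without any qualification. That works because its text ends up inside the MatrixBase class template. As a rough, stdlib-only sketch of that arrangement (assuming nothing about Eigen's internals beyond the CRTP idea), the hypothetical ToyBase below stands in for MatrixBase, the hypothetical Vec2 for a concrete matrix type, and a macro for the included file:

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical plugin body. In the real mechanism this text would sit in its own
// header (like MatrixBaseAddons.h) and be pulled in with #include. Since it is
// pasted inside ToyBase, it can use Scalar and derived() unqualified.
#define TOYBASE_PLUGIN_MEMBERS                               \
  Scalar squaredLength() const {                             \
    Scalar s(0);                                             \
    for (std::size_t k = 0; k < derived().size(); ++k)       \
      s += derived()[k] * derived()[k];                      \
    return s;                                                \
  }                                                          \
  Scalar length() const { return std::sqrt(squaredLength()); }

// Minimal CRTP base standing in for MatrixBase<Derived>. The scalar type is
// passed explicitly because Derived is still incomplete when the base is built.
template <typename Derived, typename ScalarT>
struct ToyBase {
  typedef ScalarT Scalar;
  const Derived& derived() const { return *static_cast<const Derived*>(this); }

#ifdef TOYBASE_PLUGIN_MEMBERS
  TOYBASE_PLUGIN_MEMBERS
#endif
};

// Concrete 2-vector that inherits the injected methods through ToyBase.
struct Vec2 : ToyBase<Vec2, double> {
  double v[2];
  Vec2(double a, double b) { v[0] = a; v[1] = b; }
  std::size_t size() const { return 2; }
  double operator[](std::size_t k) const { return v[k]; }
};
```

Because the injected members live in the base class template, every type deriving from ToyBase gets them, much as every Eigen expression gets the methods added to MatrixBase.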

Then one can add the following declaration to config.h, or to whichever prerequisite header file the project includes before Eigen:

#define EIGEN_MATRIXBASE_PLUGIN "MatrixBaseAddons.h"
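One possible layout for such a prerequisites header is sketched below. The file name config.h and its include guard are illustrative assumptions. The #error check relies on EIGEN_WORLD_VERSION, a version macro that Eigen's headers define via Eigen/Core, to catch the case where an Eigen header was included first and MatrixBase was therefore declared without the plugin.

```cpp
// config.h: hypothetical project-wide prerequisites header.
#ifndef MYPROJECT_CONFIG_H
#define MYPROJECT_CONFIG_H

// If an Eigen header was already included, MatrixBase has been declared
// without the plugin, and defining the macro now would have no effect.
#ifdef EIGEN_WORLD_VERSION
#error "Include config.h before any Eigen header so the plugin is picked up"
#endif

// Path of the extension file; it is #included inside the MatrixBase declaration.
#define EIGEN_MATRIXBASE_PLUGIN "MatrixBaseAddons.h"
#include <Eigen/Core>

#endif  // MYPROJECT_CONFIG_H
```

Including this header first in every translation unit keeps the definition of MatrixBase identical across the program, which the one-definition rule requires.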


gtsam
Author(s):
autogenerated on Sat Jan 4 2025 04:08:19