Program Listing for File basic-preconditioners.hpp

Return to documentation for file (include/nanoeigenpy/solvers/basic-preconditioners.hpp)

#pragma once

#include "nanoeigenpy/fwd.hpp"
#include <Eigen/IterativeLinearSolvers>

namespace nanoeigenpy {
namespace nb = nanobind;

template <typename Preconditioner>
struct PreconditionerBaseVisitor
    : nb::def_visitor<PreconditionerBaseVisitor<Preconditioner>> {
  using MatrixType = Eigen::MatrixXd;
  using VectorType = Eigen::VectorXd;

  template <typename... Ts>
  void execute(nb::class_<Preconditioner, Ts...>& cl) {
    using namespace nb::literals;
    cl.def(nb::init<>())
        .def(nb::init<MatrixType>(), "A"_a)
        .def("info", &Preconditioner::info,
             "Returns success if the Preconditioner has been well initialized.")
        .def("solve", &solve, "b"_a,
             "Returns the solution A * z = b where the preconditioner is an "
             "estimate of A^-1.")

        .def("compute", &Preconditioner::template compute<MatrixType>, "mat"_a,
             "Initialize the preconditioner from the matrix value.",
             nb::rv_policy::reference)
        .def("factorize", &Preconditioner::template factorize<MatrixType>,
             "mat"_a,
             "Initialize the preconditioner from the matrix value, i.e "
             "factorize the mat given as input to approximate its inverse.",
             nb::rv_policy::reference);
  }

 private:
  static VectorType solve(const Preconditioner& self, const VectorType& vec) {
    return self.solve(vec);
  }
};

template <typename Scalar>
struct DiagonalPreconditionerVisitor
    : PreconditionerBaseVisitor<DiagonalPreconditionerVisitor<Scalar>> {
  using Preconditioner = Eigen::DiagonalPreconditioner<Scalar>;

  template <typename... Ts>
  void execute(nb::class_<Scalar, Ts...>& cl) {
    using namespace nb::literals;
    cl.def(PreconditionerBaseVisitor<Preconditioner>())
        .def("rows", &Preconditioner::rows,
             "Returns the number of rows in the preconditioner.")
        .def("cols", &Preconditioner::rows,
             "Returns the number of cols in the preconditioner.");
  }

  static void expose(nb::module_& m, const char* name) {
    if (check_registration_alias<Preconditioner>(m)) {
      return;
    }
    nb::class_<Preconditioner>(m, name).def(IdVisitor());
  }
};

template <typename Scalar>
void exposeDiagonalPreconditioner(nb::module_& m, const char* name) {
  DiagonalPreconditionerVisitor<Scalar>::expose(m, name);
}

#if EIGEN_VERSION_AT_LEAST(3, 3, 5)
template <typename Scalar>
struct LeastSquareDiagonalPreconditionerVisitor
    : PreconditionerBaseVisitor<
          LeastSquareDiagonalPreconditionerVisitor<Scalar>> {
  using Preconditioner = Eigen::LeastSquareDiagonalPreconditioner<Scalar>;

  template <typename... Ts>
  void execute(nb::class_<Scalar, Ts...>& cl) {
    cl.def(PreconditionerBaseVisitor<Preconditioner>())
        .def("rows", &Preconditioner::rows,
             "Returns the number of rows in the preconditioner.")
        .def("cols", &Preconditioner::rows,
             "Returns the number of cols in the preconditioner.");
  }

  static void expose(nb::module_& m, const char* name) {
    if (check_registration_alias<Preconditioner>(m)) {
      return;
    }
    nb::class_<Preconditioner>(m, name).def(IdVisitor());
  }
};

template <typename Scalar>
void exposeLeastSquareDiagonalPreconditioner(nb::module_& m, const char* name) {
  LeastSquareDiagonalPreconditionerVisitor<Scalar>::expose(m, name);
}
#endif

template <typename Scalar>
struct IdentityPreconditionerVisitor
    : PreconditionerBaseVisitor<Eigen::IdentityPreconditioner> {
  using Preconditioner = Eigen::IdentityPreconditioner;

  template <typename... Ts>
  void execute(nb::class_<Scalar, Ts...>&) {}

  static void expose(nb::module_& m, const char* name) {
    if (check_registration_alias<Preconditioner>(m)) {
      return;
    }
    nb::class_<Preconditioner>(m, name)
        .def(PreconditionerBaseVisitor<Preconditioner>())
        .def(IdVisitor());
  }
};

template <typename Scalar>
void exposeIdentityPreconditioner(nb::module_& m, const char* name) {
  IdentityPreconditionerVisitor<Scalar>::expose(m, name);
}

}  // namespace nanoeigenpy