PartialSVDSolver.h
Go to the documentation of this file.
1 // Copyright (C) 2018 Yixuan Qiu <yixuan.qiu@cos.name>
2 //
3 // This Source Code Form is subject to the terms of the Mozilla
4 // Public License v. 2.0. If a copy of the MPL was not distributed
5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
6 
7 #ifndef PARTIAL_SVD_SOLVER_H
8 #define PARTIAL_SVD_SOLVER_H
9 
10 #include <Eigen/Core>
11 #include "../SymEigsSolver.h"
12 
13 namespace Spectra {
14 
// Abstract interface for the implicit, square, symmetric operator that the
// partial SVD solver hands to the eigen solver.
// Implementations apply either A' * A or A * A' to a dense vector.
template <typename Scalar>
class SVDMatOp
{
public:
    // Polymorphic base: derived operators are deleted through this type
    virtual ~SVDMatOp() { }

    // Apply the operator to the vector stored at x_in and write the result
    // to y_out, i.e. y_out = A' * A * x_in or y_out = A * A' * x_in
    virtual void perform_op(const Scalar* x_in, Scalar* y_out) = 0;

    // Dimensions of the (square) operator
    virtual int rows() const = 0;
    virtual int cols() const = 0;
};
28 
29 // Operation of a tall matrix in SVD
30 // We compute the eigenvalues of A' * A
31 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
32 template <typename Scalar, typename MatrixType>
33 class SVDTallMatOp : public SVDMatOp<Scalar>
34 {
35 private:
40 
41  ConstGenericMatrix m_mat;
42  const int m_dim;
43  Vector m_cache;
44 
45 public:
46  // Constructor
47  SVDTallMatOp(ConstGenericMatrix& mat) :
48  m_mat(mat),
49  m_dim(std::min(mat.rows(), mat.cols())),
50  m_cache(mat.rows())
51  {}
52 
53  // These are the rows and columns of A' * A
54  int rows() const { return m_dim; }
55  int cols() const { return m_dim; }
56 
57  // y_out = A' * A * x_in
58  void perform_op(const Scalar* x_in, Scalar* y_out)
59  {
60  MapConstVec x(x_in, m_mat.cols());
61  MapVec y(y_out, m_mat.cols());
62  m_cache.noalias() = m_mat * x;
63  y.noalias() = m_mat.transpose() * m_cache;
64  }
65 };
66 
67 // Operation of a wide matrix in SVD
68 // We compute the eigenvalues of A * A'
69 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
70 template <typename Scalar, typename MatrixType>
71 class SVDWideMatOp : public SVDMatOp<Scalar>
72 {
73 private:
78 
79  ConstGenericMatrix m_mat;
80  const int m_dim;
81  Vector m_cache;
82 
83 public:
84  // Constructor
85  SVDWideMatOp(ConstGenericMatrix& mat) :
86  m_mat(mat),
87  m_dim(std::min(mat.rows(), mat.cols())),
88  m_cache(mat.cols())
89  {}
90 
91  // These are the rows and columns of A * A'
92  int rows() const { return m_dim; }
93  int cols() const { return m_dim; }
94 
95  // y_out = A * A' * x_in
96  void perform_op(const Scalar* x_in, Scalar* y_out)
97  {
98  MapConstVec x(x_in, m_mat.rows());
99  MapVec y(y_out, m_mat.rows());
100  m_cache.noalias() = m_mat.transpose() * x;
101  y.noalias() = m_mat * m_cache;
102  }
103 };
104 
105 // Partial SVD solver
106 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
107 template <typename Scalar = double,
110 {
111 private:
115 
116  ConstGenericMatrix m_mat;
117  const int m_m;
118  const int m_n;
121  int m_nconv;
122  Matrix m_evecs;
123 
124 public:
125  // Constructor
126  PartialSVDSolver(ConstGenericMatrix& mat, int ncomp, int ncv) :
127  m_mat(mat), m_m(mat.rows()), m_n(mat.cols()), m_evecs(0, 0)
128  {
129  // Determine the matrix type, tall or wide
130  if (m_m > m_n)
131  {
133  }
134  else
135  {
137  }
138 
139  // Solver object
140  m_eigs = new SymEigsSolver<Scalar, LARGEST_ALGE, SVDMatOp<Scalar> >(m_op, ncomp, ncv);
141  }
142 
143  // Destructor
145  {
146  delete m_eigs;
147  delete m_op;
148  }
149 
150  // Computation
151  int compute(int maxit = 1000, Scalar tol = 1e-10)
152  {
153  m_eigs->init();
154  m_nconv = m_eigs->compute(maxit, tol);
155 
156  return m_nconv;
157  }
158 
159  // The converged singular values
160  Vector singular_values() const
161  {
162  Vector svals = m_eigs->eigenvalues().cwiseSqrt();
163 
164  return svals;
165  }
166 
167  // The converged left singular vectors
168  Matrix matrix_U(int nu)
169  {
170  if (m_evecs.cols() < 1)
171  {
172  m_evecs = m_eigs->eigenvectors();
173  }
174  nu = std::min(nu, m_nconv);
175  if (m_m <= m_n)
176  {
177  return m_evecs.leftCols(nu);
178  }
179 
180  return m_mat * (m_evecs.leftCols(nu).array().rowwise() / m_eigs->eigenvalues().head(nu).transpose().array().sqrt()).matrix();
181  }
182 
183  // The converged right singular vectors
184  Matrix matrix_V(int nv)
185  {
186  if (m_evecs.cols() < 1)
187  {
188  m_evecs = m_eigs->eigenvectors();
189  }
190  nv = std::min(nv, m_nconv);
191  if (m_m > m_n)
192  {
193  return m_evecs.leftCols(nv);
194  }
195 
196  return m_mat.transpose() * (m_evecs.leftCols(nv).array().rowwise() / m_eigs->eigenvalues().head(nv).transpose().array().sqrt()).matrix();
197  }
198 };
199 
200 } // namespace Spectra
201 
202 #endif // PARTIAL_SVD_SOLVER_H
SVDMatOp< Scalar > * m_op
SCALAR Scalar
Definition: bench_gemm.cpp:46
int compute(int maxit=1000, Scalar tol=1e-10)
ConstGenericMatrix m_mat
Scalar * y
const Eigen::Ref< const MatrixType > ConstGenericMatrix
virtual void perform_op(const Scalar *x_in, Scalar *y_out)=0
#define min(a, b)
Definition: datatypes.h:19
A matrix or vector expression mapping an existing array of data.
Definition: Map.h:94
virtual int rows() const =0
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic > Matrix
Definition: BFloat16.h:88
MatrixXf MatrixType
Eigen::Matrix< Scalar, Eigen::Dynamic, 1 > Vector
void perform_op(const Scalar *x_in, Scalar *y_out)
const Eigen::Ref< const MatrixType > ConstGenericMatrix
SVDTallMatOp(ConstGenericMatrix &mat)
Eigen::Map< const Vector > MapConstVec
Eigen::Map< Vector > MapVec
Eigen::Map< const Vector > MapConstVec
ConstGenericMatrix m_mat
SymEigsSolver< Scalar, LARGEST_ALGE, SVDMatOp< Scalar > > * m_eigs
Array< double, 1, 3 > e(1./3., 0.5, 2.)
Eigen::Matrix< Scalar, Eigen::Dynamic, 1 > Vector
Index compute(Index maxit=1000, Scalar tol=1e-10, int sort_rule=LARGEST_ALGE)
Definition: SymEigsBase.h:334
SVDWideMatOp(ConstGenericMatrix &mat)
void perform_op(const Scalar *x_in, Scalar *y_out)
Eigen::Map< Vector > MapVec
ConstGenericMatrix m_mat
PartialSVDSolver(ConstGenericMatrix &mat, int ncomp, int ncv)
Eigen::Matrix< Scalar, Eigen::Dynamic, 1 > Vector
const Eigen::Ref< const MatrixType > ConstGenericMatrix
const G double tol
Definition: Group.h:86
Map< Matrix< T, Dynamic, Dynamic, ColMajor >, 0, OuterStride<> > matrix(T *data, int rows, int cols, int stride)
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
virtual int cols() const =0


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:35:12