PartialSVDSolver.h
Go to the documentation of this file.
1 // Copyright (C) 2018-2025 Yixuan Qiu <yixuan.qiu@cos.name>
2 //
3 // This Source Code Form is subject to the terms of the Mozilla
4 // Public License v. 2.0. If a copy of the MPL was not distributed
5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
6 
7 #ifndef SPECTRA_PARTIAL_SVD_SOLVER_H
8 #define SPECTRA_PARTIAL_SVD_SOLVER_H
9 
#include <algorithm>
#include <memory>

#include <Eigen/Core>

#include "../SymEigsSolver.h"
12 
13 namespace Spectra {
14 
15 // Abstract class for matrix operation
16 template <typename Scalar_>
17 class SVDMatOp
18 {
19 public:
20  using Scalar = Scalar_;
21 
22 private:
24 
25 public:
26  virtual Index rows() const = 0;
27  virtual Index cols() const = 0;
28 
29  // y_out = A' * A * x_in or y_out = A * A' * x_in
30  virtual void perform_op(const Scalar* x_in, Scalar* y_out) const = 0;
31 
32  virtual ~SVDMatOp() {}
33 };
34 
35 // Operation of a tall matrix in SVD
36 // We compute the eigenvalues of A' * A
37 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
38 template <typename Scalar, typename MatrixType>
39 class SVDTallMatOp : public SVDMatOp<Scalar>
40 {
41 private:
47 
49  const Index m_dim;
50  mutable Vector m_cache;
51 
52 public:
53  // Constructor
55  m_mat(mat),
56  m_dim((std::min)(mat.rows(), mat.cols())),
57  m_cache(mat.rows())
58  {}
59 
60  // These are the rows and columns of A' * A
61  Index rows() const override { return m_dim; }
62  Index cols() const override { return m_dim; }
63 
64  // y_out = A' * A * x_in
65  void perform_op(const Scalar* x_in, Scalar* y_out) const override
66  {
67  MapConstVec x(x_in, m_mat.cols());
68  MapVec y(y_out, m_mat.cols());
69  m_cache.noalias() = m_mat * x;
70  y.noalias() = m_mat.transpose() * m_cache;
71  }
72 };
73 
74 // Operation of a wide matrix in SVD
75 // We compute the eigenvalues of A * A'
76 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
77 template <typename Scalar, typename MatrixType>
78 class SVDWideMatOp : public SVDMatOp<Scalar>
79 {
80 private:
86 
88  const Index m_dim;
89  mutable Vector m_cache;
90 
91 public:
92  // Constructor
94  m_mat(mat),
95  m_dim((std::min)(mat.rows(), mat.cols())),
96  m_cache(mat.cols())
97  {}
98 
99  // These are the rows and columns of A * A'
100  Index rows() const override { return m_dim; }
101  Index cols() const override { return m_dim; }
102 
103  // y_out = A * A' * x_in
104  void perform_op(const Scalar* x_in, Scalar* y_out) const override
105  {
106  MapConstVec x(x_in, m_mat.rows());
107  MapVec y(y_out, m_mat.rows());
108  m_cache.noalias() = m_mat.transpose() * x;
109  y.noalias() = m_mat * m_cache;
110  }
111 };
112 
113 // Partial SVD solver
114 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
115 template <typename MatrixType = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
117 {
118 private:
119  using Scalar = typename MatrixType::Scalar;
124 
126  const Index m_m;
127  const Index m_n;
132 
133 public:
134  // Constructor
136  m_mat(mat), m_m(mat.rows()), m_n(mat.cols()), m_evecs(0, 0)
137  {
138  // Determine the matrix type, tall or wide
139  if (m_m > m_n)
140  {
142  }
143  else
144  {
146  }
147 
148  // Solver object
149  m_eigs = new SymEigsSolver<SVDMatOp<Scalar>>(*m_op, ncomp, ncv);
150  }
151 
152  // Destructor
154  {
155  delete m_eigs;
156  delete m_op;
157  }
158 
159  // Computation
160  Index compute(Index maxit = 1000, Scalar tol = 1e-10)
161  {
162  m_eigs->init();
163  m_nconv = m_eigs->compute(SortRule::LargestAlge, maxit, tol);
164 
165  return m_nconv;
166  }
167 
168  // The converged singular values
170  {
171  Vector svals = m_eigs->eigenvalues().cwiseSqrt();
172 
173  return svals;
174  }
175 
176  // The converged left singular vectors
178  {
179  if (m_evecs.cols() < 1)
180  {
181  m_evecs = m_eigs->eigenvectors();
182  }
183  nu = (std::min)(nu, m_nconv);
184  if (m_m <= m_n)
185  {
186  return m_evecs.leftCols(nu);
187  }
188 
189  return m_mat * (m_evecs.leftCols(nu).array().rowwise() / m_eigs->eigenvalues().head(nu).transpose().array().sqrt()).matrix();
190  }
191 
192  // The converged right singular vectors
194  {
195  if (m_evecs.cols() < 1)
196  {
197  m_evecs = m_eigs->eigenvectors();
198  }
199  nv = (std::min)(nv, m_nconv);
200  if (m_m > m_n)
201  {
202  return m_evecs.leftCols(nv);
203  }
204 
205  return m_mat.transpose() * (m_evecs.leftCols(nv).array().rowwise() / m_eigs->eigenvalues().head(nv).transpose().array().sqrt()).matrix();
206  }
207 };
208 
209 } // namespace Spectra
210 
211 #endif // SPECTRA_PARTIAL_SVD_SOLVER_H
Spectra::PartialSVDSolver::m_op
SVDMatOp< Scalar > * m_op
Definition: PartialSVDSolver.h:128
Spectra::SVDWideMatOp::rows
Index rows() const override
Definition: PartialSVDSolver.h:100
Spectra::SVDMatOp< Scalar >::Scalar
Scalar Scalar
Definition: PartialSVDSolver.h:20
Spectra::SVDWideMatOp::m_cache
Vector m_cache
Definition: PartialSVDSolver.h:89
Spectra::SVDWideMatOp::Index
Eigen::Index Index
Definition: PartialSVDSolver.h:81
Spectra::SVDWideMatOp::cols
Index cols() const override
Definition: PartialSVDSolver.h:101
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
Spectra::SVDTallMatOp::m_mat
ConstGenericMatrix m_mat
Definition: PartialSVDSolver.h:48
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
Spectra::PartialSVDSolver::m_eigs
SymEigsSolver< SVDMatOp< Scalar > > * m_eigs
Definition: PartialSVDSolver.h:129
Spectra::SVDMatOp< Scalar >::Index
Eigen::Index Index
Definition: PartialSVDSolver.h:23
Spectra::SVDMatOp::rows
virtual Index rows() const =0
Spectra::SymEigsSolver
Definition: SymEigsSolver.h:134
mat
MatrixXf mat
Definition: Tutorial_AdvancedInitialization_CommaTemporary.cpp:1
Spectra::SVDTallMatOp::SVDTallMatOp
SVDTallMatOp(ConstGenericMatrix &mat)
Definition: PartialSVDSolver.h:54
rows
int rows
Definition: Tutorial_commainit_02.cpp:1
Spectra::SortRule::LargestAlge
@ LargestAlge
Spectra::SVDWideMatOp::m_mat
ConstGenericMatrix m_mat
Definition: PartialSVDSolver.h:87
Spectra::PartialSVDSolver
Definition: PartialSVDSolver.h:116
Spectra::SVDWideMatOp
Definition: PartialSVDSolver.h:78
Spectra::PartialSVDSolver::matrix_V
Matrix matrix_V(Index nv)
Definition: PartialSVDSolver.h:193
Spectra::PartialSVDSolver::PartialSVDSolver
PartialSVDSolver(ConstGenericMatrix &mat, Index ncomp, Index ncv)
Definition: PartialSVDSolver.h:135
Spectra::SVDTallMatOp::m_dim
const Index m_dim
Definition: PartialSVDSolver.h:49
Spectra::SVDMatOp::cols
virtual Index cols() const =0
Spectra::SVDMatOp::~SVDMatOp
virtual ~SVDMatOp()
Definition: PartialSVDSolver.h:32
Spectra::PartialSVDSolver::m_mat
ConstGenericMatrix m_mat
Definition: PartialSVDSolver.h:125
Spectra::SVDTallMatOp::m_cache
Vector m_cache
Definition: PartialSVDSolver.h:50
Spectra::SVDTallMatOp
Definition: PartialSVDSolver.h:39
Spectra::PartialSVDSolver::m_nconv
Index m_nconv
Definition: PartialSVDSolver.h:130
Spectra::PartialSVDSolver::Index
Eigen::Index Index
Definition: PartialSVDSolver.h:120
Spectra::PartialSVDSolver::Scalar
typename MatrixType::Scalar Scalar
Definition: PartialSVDSolver.h:119
Spectra::PartialSVDSolver::singular_values
Vector singular_values() const
Definition: PartialSVDSolver.h:169
Eigen::Map
A matrix or vector expression mapping an existing array of data.
Definition: Map.h:94
y
Scalar * y
Definition: level1_cplx_impl.h:124
matrix
Map< Matrix< T, Dynamic, Dynamic, ColMajor >, 0, OuterStride<> > matrix(T *data, int rows, int cols, int stride)
Definition: gtsam/3rdparty/Eigen/blas/common.h:110
Spectra::SVDMatOp::perform_op
virtual void perform_op(const Scalar *x_in, Scalar *y_out) const =0
Spectra::SVDMatOp
Definition: PartialSVDSolver.h:17
Eigen::Ref< const MatrixType >
Eigen::PlainObjectBase::cols
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: PlainObjectBase.h:145
std
Definition: BFloat16.h:88
Spectra::SVDTallMatOp::rows
Index rows() const override
Definition: PartialSVDSolver.h:61
Spectra::SVDTallMatOp::Index
Eigen::Index Index
Definition: PartialSVDSolver.h:42
Spectra::PartialSVDSolver::~PartialSVDSolver
virtual ~PartialSVDSolver()
Definition: PartialSVDSolver.h:153
Spectra
Definition: LOBPCGSolver.h:19
Spectra::PartialSVDSolver::matrix_U
Matrix matrix_U(Index nu)
Definition: PartialSVDSolver.h:177
min
#define min(a, b)
Definition: datatypes.h:19
gtsam::tol
const G double tol
Definition: Group.h:79
Spectra::SVDTallMatOp::cols
Index cols() const override
Definition: PartialSVDSolver.h:62
Eigen::Matrix< Scalar, Eigen::Dynamic, 1 >
Spectra::PartialSVDSolver::m_m
const Index m_m
Definition: PartialSVDSolver.h:126
Spectra::SVDTallMatOp::perform_op
void perform_op(const Scalar *x_in, Scalar *y_out) const override
Definition: PartialSVDSolver.h:65
cols
int cols
Definition: Tutorial_commainit_02.cpp:1
Spectra::PartialSVDSolver::compute
Index compute(Index maxit=1000, Scalar tol=1e-10)
Definition: PartialSVDSolver.h:160
Spectra::SVDWideMatOp::perform_op
void perform_op(const Scalar *x_in, Scalar *y_out) const override
Definition: PartialSVDSolver.h:104
Spectra::PartialSVDSolver::m_n
const Index m_n
Definition: PartialSVDSolver.h:127
Spectra::SVDWideMatOp::SVDWideMatOp
SVDWideMatOp(ConstGenericMatrix &mat)
Definition: PartialSVDSolver.h:93
Spectra::SVDWideMatOp::m_dim
const Index m_dim
Definition: PartialSVDSolver.h:88
Scalar
SCALAR Scalar
Definition: bench_gemm.cpp:46
Spectra::PartialSVDSolver::m_evecs
Matrix m_evecs
Definition: PartialSVDSolver.h:131
Eigen::Index
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74


gtsam
Author(s):
autogenerated on Sun Feb 16 2025 04:02:32