Go to the documentation of this file.
33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
46 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
50 #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
51 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
52 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
53 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
54 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
55 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
56 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
59 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
60 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
61 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
62 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
63 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
64 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
74 #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
75 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
76 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
78 IsLower = (Mode&Lower) == Lower, \
79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
82 LowUp = IsLower ? Lower : Upper \
84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
87 if (ConjLhs || IsZeroDiag) { \
88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
92 Index size = (std::min)(_rows,_cols); \
93 Index rows = IsLower ? _rows : size; \
94 Index cols = IsLower ? size : _cols; \
96 typedef VectorX##EIGPREFIX VectorRhs; \
100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
107 char trans, uplo, diag; \
108 BlasIndex m, n, lda, incx, incy; \
113 n = convert_index<BlasIndex>(size); \
114 lda = convert_index<BlasIndex>(lhsStride); \
116 incy = convert_index<BlasIndex>(resIncr); \
120 uplo = IsLower ? 'L' : 'U'; \
121 diag = IsUnitDiag ? 'U' : 'N'; \
124 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
127 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
129 if (size<(std::max)(rows,cols)) { \
130 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
133 y = _res + size*resIncr; \
135 m = convert_index<BlasIndex>(rows-size); \
136 n = convert_index<BlasIndex>(size); \
141 a = _lhs + size*lda; \
142 m = convert_index<BlasIndex>(size); \
143 n = convert_index<BlasIndex>(cols-size); \
145 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
163 #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
164 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
165 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
167 IsLower = (Mode&Lower) == Lower, \
168 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
169 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
170 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
171 LowUp = IsLower ? Lower : Upper \
173 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
174 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
177 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
178 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
181 Index size = (std::min)(_rows,_cols); \
182 Index rows = IsLower ? _rows : size; \
183 Index cols = IsLower ? size : _cols; \
185 typedef VectorX##EIGPREFIX VectorRhs; \
189 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
191 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
196 char trans, uplo, diag; \
197 BlasIndex m, n, lda, incx, incy; \
202 n = convert_index<BlasIndex>(size); \
203 lda = convert_index<BlasIndex>(lhsStride); \
205 incy = convert_index<BlasIndex>(resIncr); \
208 trans = ConjLhs ? 'C' : 'T'; \
209 uplo = IsLower ? 'U' : 'L'; \
210 diag = IsUnitDiag ? 'U' : 'N'; \
213 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
216 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
218 if (size<(std::max)(rows,cols)) { \
219 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
222 y = _res + size*resIncr; \
223 a = _lhs + size*lda; \
224 m = convert_index<BlasIndex>(rows-size); \
225 n = convert_index<BlasIndex>(size); \
231 m = convert_index<BlasIndex>(size); \
232 n = convert_index<BlasIndex>(cols-size); \
234 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
255 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
Namespace containing all symbols from the Eigen library.
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
static const double d[K][N]
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
std::complex< float > scomplex
std::complex< double > dcomplex
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
constexpr descr< N - 1 > _(char const (&text)[N])
gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:08:11