33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 78 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 79 template <typename Index, int Mode, \ 80 int LhsStorageOrder, bool ConjugateLhs, \ 81 int RhsStorageOrder, bool ConjugateRhs> \ 82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 86 IsLower = (Mode&Lower) == Lower, \ 87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 90 LowUp = IsLower ? Lower : Upper, \ 91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 95 Index _rows, Index _cols, Index _depth, \ 96 const EIGTYPE* _lhs, Index lhsStride, \ 97 const EIGTYPE* _rhs, Index rhsStride, \ 98 EIGTYPE* res, Index resStride, \ 99 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 101 Index diagSize = (std::min)(_rows,_depth); \ 102 Index rows = IsLower ? _rows : diagSize; \ 103 Index depth = IsLower ? diagSize : _depth; \ 104 Index cols = _cols; \ 106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 110 if (rows != depth) { \ 115 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 117 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 118 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 119 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 123 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 124 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 125 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 126 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 127 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 128 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ 134 char side = 'L', transa, uplo, diag = 'N'; \ 137 BlasIndex m, n, lda, ldb; \ 140 m = convert_index<BlasIndex>(diagSize); \ 141 n = convert_index<BlasIndex>(cols); \ 144 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 147 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 148 MatrixX##EIGPREFIX b_tmp; \ 150 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 152 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 155 uplo = IsLower ? 'L' : 'U'; \ 156 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 158 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 161 if ((conjA!=0) || (SetDiag==0)) { \ 162 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 164 a_tmp.diagonal().setZero(); \ 165 else if (IsUnitDiag) \ 166 a_tmp.diagonal().setOnes();\ 168 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 171 lda = convert_index<BlasIndex>(lhsStride); \ 175 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 178 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 179 res_tmp=res_tmp+b_tmp; \ 196 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 197 template <typename Index, int Mode, \ 198 int LhsStorageOrder, bool ConjugateLhs, \ 199 int RhsStorageOrder, bool ConjugateRhs> \ 200 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 201 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 204 IsLower = (Mode&Lower) == Lower, \ 205 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 206 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 207 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 208 LowUp = IsLower ? Lower : Upper, \ 209 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 213 Index _rows, Index _cols, Index _depth, \ 214 const EIGTYPE* _lhs, Index lhsStride, \ 215 const EIGTYPE* _rhs, Index rhsStride, \ 216 EIGTYPE* res, Index resStride, \ 217 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 219 Index diagSize = (std::min)(_cols,_depth); \ 220 Index rows = _rows; \ 221 Index depth = IsLower ? _depth : diagSize; \ 222 Index cols = IsLower ? diagSize : _cols; \ 224 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 225 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 228 if (cols != depth) { \ 232 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 234 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 235 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 236 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 240 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 241 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 242 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 243 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 244 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 245 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ 251 char side = 'R', transa, uplo, diag = 'N'; \ 254 BlasIndex m, n, lda, ldb; \ 257 m = convert_index<BlasIndex>(rows); \ 258 n = convert_index<BlasIndex>(diagSize); \ 261 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 264 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 265 MatrixX##EIGPREFIX b_tmp; \ 267 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 269 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 272 uplo = IsLower ? 'L' : 'U'; \ 273 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 275 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 278 if ((conjA!=0) || (SetDiag==0)) { \ 279 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 281 a_tmp.diagonal().setZero(); \ 282 else if (IsUnitDiag) \ 283 a_tmp.diagonal().setOnes();\ 285 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 288 lda = convert_index<BlasIndex>(rhsStride); \ 292 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 295 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 296 res_tmp=res_tmp+b_tmp; \ 315 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
int BLASFUNC() strmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
int BLASFUNC() ztrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
Namespace containing all symbols from the Eigen library.
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
std::complex< float > scomplex
std::complex< double > dcomplex
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
int BLASFUNC() dtrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
int BLASFUNC() ctrmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)