33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 61 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ 62 eigen_assert(resIncr == 1); \ 63 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 64 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 65 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 66 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 80 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 81 template <typename Index, int Mode, \ 82 int LhsStorageOrder, bool ConjugateLhs, \ 83 int RhsStorageOrder, bool ConjugateRhs> \ 84 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 85 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 88 IsLower = (Mode&Lower) == Lower, \ 89 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 90 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 91 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 92 LowUp = IsLower ? Lower : Upper, \ 93 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 97 Index _rows, Index _cols, Index _depth, \ 98 const EIGTYPE* _lhs, Index lhsStride, \ 99 const EIGTYPE* _rhs, Index rhsStride, \ 100 EIGTYPE* res, Index resStride, \ 101 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 103 Index diagSize = (std::min)(_rows,_depth); \ 104 Index rows = IsLower ? _rows : diagSize; \ 105 Index depth = IsLower ? diagSize : _depth; \ 106 Index cols = _cols; \ 108 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 109 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 112 if (rows != depth) { \ 117 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 119 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 120 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \ 121 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \ 125 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 126 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 127 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 128 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 129 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \ 130 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \ 136 char side = 'L', transa, uplo, diag = 'N'; \ 139 BlasIndex m, n, lda, ldb; \ 142 m = convert_index<BlasIndex>(diagSize); \ 143 n = convert_index<BlasIndex>(cols); \ 146 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 149 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 150 MatrixX##EIGPREFIX b_tmp; \ 152 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 154 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 157 uplo = IsLower ? 'L' : 'U'; \ 158 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 160 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 163 if ((conjA!=0) || (SetDiag==0)) { \ 164 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 166 a_tmp.diagonal().setZero(); \ 167 else if (IsUnitDiag) \ 168 a_tmp.diagonal().setOnes();\ 170 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 173 lda = convert_index<BlasIndex>(lhsStride); \ 177 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 180 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 181 res_tmp=res_tmp+b_tmp; \ 198 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 199 template <typename Index, int Mode, \ 200 int LhsStorageOrder, bool ConjugateLhs, \ 201 int RhsStorageOrder, bool ConjugateRhs> \ 202 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 203 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 206 IsLower = (Mode&Lower) == Lower, \ 207 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 208 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 209 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 210 LowUp = IsLower ? Lower : Upper, \ 211 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 215 Index _rows, Index _cols, Index _depth, \ 216 const EIGTYPE* _lhs, Index lhsStride, \ 217 const EIGTYPE* _rhs, Index rhsStride, \ 218 EIGTYPE* res, Index resStride, \ 219 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 221 Index diagSize = (std::min)(_cols,_depth); \ 222 Index rows = _rows; \ 223 Index depth = IsLower ? _depth : diagSize; \ 224 Index cols = IsLower ? diagSize : _cols; \ 226 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 227 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 230 if (cols != depth) { \ 234 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 236 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 237 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \ 238 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \ 242 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 243 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 244 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 245 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 246 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \ 247 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \ 253 char side = 'R', transa, uplo, diag = 'N'; \ 256 BlasIndex m, n, lda, ldb; \ 259 m = convert_index<BlasIndex>(rows); \ 260 n = convert_index<BlasIndex>(diagSize); \ 263 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 266 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 267 MatrixX##EIGPREFIX b_tmp; \ 269 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 271 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 274 uplo = IsLower ? 'L' : 'U'; \ 275 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 277 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 280 if ((conjA!=0) || (SetDiag==0)) { \ 281 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 283 a_tmp.diagonal().setZero(); \ 284 else if (IsUnitDiag) \ 285 a_tmp.diagonal().setOnes();\ 287 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 290 lda = convert_index<BlasIndex>(rhsStride); \ 294 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 297 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 298 res_tmp=res_tmp+b_tmp; \ 317 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
int BLASFUNC() strmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
int BLASFUNC() ztrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
Namespace containing all symbols from the Eigen library.
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
std::complex< float > scomplex
std::complex< double > dcomplex
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
int BLASFUNC() dtrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
int BLASFUNC() ctrmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)