33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 41 template <
typename Scalar,
typename Index,
42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
53 #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 78 #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 79 template <typename Index, int Mode, \ 80 int LhsStorageOrder, bool ConjugateLhs, \ 81 int RhsStorageOrder, bool ConjugateRhs> \ 82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 86 IsLower = (Mode&Lower) == Lower, \ 87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 90 LowUp = IsLower ? Lower : Upper, \ 91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 95 Index _rows, Index _cols, Index _depth, \ 96 const EIGTYPE* _lhs, Index lhsStride, \ 97 const EIGTYPE* _rhs, Index rhsStride, \ 98 EIGTYPE* res, Index resStride, \ 99 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 101 Index diagSize = (std::min)(_rows,_depth); \ 102 Index rows = IsLower ? _rows : diagSize; \ 103 Index depth = IsLower ? diagSize : _depth; \ 104 Index cols = _cols; \ 106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 110 if (rows != depth) { \ 112 int nthr = mkl_domain_get_max_threads(MKL_BLAS); \ 114 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 116 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 117 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 118 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 122 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 123 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 124 MKL_INT aStride = aa_tmp.outerStride(); \ 125 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ 126 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 127 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ 133 char side = 'L', transa, uplo, diag = 'N'; \ 136 MKL_INT m, n, lda, ldb; \ 140 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 143 m = (MKL_INT)diagSize; \ 147 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 150 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 151 MatrixX##EIGPREFIX b_tmp; \ 153 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 155 ldb = b_tmp.outerStride(); \ 158 uplo = IsLower ? 'L' : 'U'; \ 159 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 161 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 164 if ((conjA!=0) || (SetDiag==0)) { \ 165 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 167 a_tmp.diagonal().setZero(); \ 168 else if (IsUnitDiag) \ 169 a_tmp.diagonal().setOnes();\ 171 lda = a_tmp.outerStride(); \ 178 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 182 res_tmp=res_tmp+b_tmp; \ 192 #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 193 template <typename Index, int Mode, \ 194 int LhsStorageOrder, bool ConjugateLhs, \ 195 int RhsStorageOrder, bool ConjugateRhs> \ 196 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 197 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 200 IsLower = (Mode&Lower) == Lower, \ 201 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 202 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 203 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 204 LowUp = IsLower ? Lower : Upper, \ 205 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 209 Index _rows, Index _cols, Index _depth, \ 210 const EIGTYPE* _lhs, Index lhsStride, \ 211 const EIGTYPE* _rhs, Index rhsStride, \ 212 EIGTYPE* res, Index resStride, \ 213 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 215 Index diagSize = (std::min)(_cols,_depth); \ 216 Index rows = _rows; \ 217 Index depth = IsLower ? _depth : diagSize; \ 218 Index cols = IsLower ? diagSize : _cols; \ 220 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 224 if (cols != depth) { \ 226 int nthr = mkl_domain_get_max_threads(MKL_BLAS); \ 228 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 230 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 231 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 232 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 236 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 237 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 238 MKL_INT aStride = aa_tmp.outerStride(); \ 239 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ 240 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 241 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ 247 char side = 'R', transa, uplo, diag = 'N'; \ 250 MKL_INT m, n, lda, ldb; \ 254 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 258 n = (MKL_INT)diagSize; \ 261 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 264 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 265 MatrixX##EIGPREFIX b_tmp; \ 267 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 269 ldb = b_tmp.outerStride(); \ 272 uplo = IsLower ? 'L' : 'U'; \ 273 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 275 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 278 if ((conjA!=0) || (SetDiag==0)) { \ 279 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 281 a_tmp.diagonal().setZero(); \ 282 else if (IsUnitDiag) \ 283 a_tmp.diagonal().setOnes();\ 285 lda = a_tmp.outerStride(); \ 292 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 295 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 296 res_tmp=res_tmp+b_tmp; \ 309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)