// Include guard for EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H, followed by the head of a
// class template parameterised on Scalar, Index, Mode, LhsIsTriangular, the storage
// order and conjugation flag of each operand, and (per the argument list below) a
// ResStorageOrder.
// The visible tail passes all of those parameters plus `BuiltIn` to another template
// and closes with an empty body `{}`.
// NOTE(review): the leading integers (33, 34, 41, ...) are listing line numbers fused
// into the text, and listing lines 45-47 are absent here, so the struct name and the
// template being instantiated are not visible. Presumably this is the generic
// product_triangular_matrix_matrix_trmm deferring to the BuiltIn
// product_triangular_matrix_matrix - confirm against the upstream header.
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 41 template <
typename Scalar,
typename Index,
42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
// Three macros that route triangular * dense products to the BLAS xTRMM routine.
//
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
//   Partially specialises product_triangular_matrix_matrix for a ColMajor result and
//   the `Specialized` tag. Its run() forwards every argument unchanged to
//   product_triangular_matrix_matrix_trmm<..., ColMajor>::run().
//
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
//   Specialises product_triangular_matrix_matrix_trmm with LhsIsTriangular == true
//   (triangular operand on the left, side = 'L').
//   - diagSize = min(_rows, _depth); rows/depth are clipped to it depending on IsLower.
//   - When rows != depth (non-square triangular operand) two alternatives are visible:
//     a call to the BuiltIn product_triangular_matrix_matrix, guarded by nthr == 1 and
//     (max(rows,depth) - diagSize) / diagSize < 0.5; and a path that copies
//     triangularView<Mode>() of the lhs into a dense temporary (aa_tmp) and runs
//     general_matrix_matrix_product on it.
//     NOTE(review): the definition of `nthr` and the else/return structure joining
//     these two paths are on listing lines not present here - confirm upstream.
//   - Otherwise: m = diagSize, n = cols; transa is 'N' for ColMajor lhs, 'T' or 'C'
//     (when ConjugateLhs) for RowMajor lhs; uplo is flipped when the lhs is RowMajor.
//   - The rhs is copied into b_tmp (conjugated when ConjugateRhs); b_tmp is the buffer
//     passed to trmm_ as `b`.
//   - When conjA != 0 or SetDiag == 0 the lhs is copied into a_tmp (conjugated when
//     conjA) and its diagonal is overwritten with zeros or ones (IsUnitDiag);
//     `diag` is initialised to 'N'. In the other visible branch lda is taken from
//     lhsStride.
//     NOTE(review): the declarations of `a` / `a_tmp` and the IsZeroDiag test are not
//     visible in this listing.
//   - After BLASPREFIX##trmm_ returns, b_tmp is added into the result map:
//     res_tmp = res_tmp + b_tmp.
//
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
//   Mirror image for LhsIsTriangular == false (triangular operand on the right,
//   side = 'R'): diagSize = min(_cols, _depth), m = rows, n = diagSize, the triangular
//   data comes from _rhs (conjA / transa / uplo derive from RhsStorageOrder and
//   ConjugateRhs) and b_tmp is a copy of the lhs (conjugated when ConjugateLhs).
//
// NOTE(review): the integers preceding each statement are listing line numbers fused
// into the text, and the `#define EIGEN_BLAS_TRMM_L(...)` after the closing #endif
// looks like cross-reference residue from the listing rather than code - confirm.
53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 78 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 79 template <typename Index, int Mode, \ 80 int LhsStorageOrder, bool ConjugateLhs, \ 81 int RhsStorageOrder, bool ConjugateRhs> \ 82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 86 IsLower = (Mode&Lower) == Lower, \ 87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 90 LowUp = IsLower ? Lower : Upper, \ 91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 95 Index _rows, Index _cols, Index _depth, \ 96 const EIGTYPE* _lhs, Index lhsStride, \ 97 const EIGTYPE* _rhs, Index rhsStride, \ 98 EIGTYPE* res, Index resStride, \ 99 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 101 Index diagSize = (std::min)(_rows,_depth); \ 102 Index rows = IsLower ? _rows : diagSize; \ 103 Index depth = IsLower ? 
diagSize : _depth; \ 104 Index cols = _cols; \ 106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 110 if (rows != depth) { \ 115 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 117 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 118 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 119 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 123 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 124 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 125 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 126 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 127 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 128 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ 134 char side = 'L', transa, uplo, diag = 'N'; \ 137 BlasIndex m, n, lda, ldb; \ 140 m = convert_index<BlasIndex>(diagSize); \ 141 n = convert_index<BlasIndex>(cols); \ 144 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 147 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 148 MatrixX##EIGPREFIX b_tmp; \ 150 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 152 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 155 uplo = IsLower ? 'L' : 'U'; \ 156 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 
'U' : 'L'; \ 158 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 161 if ((conjA!=0) || (SetDiag==0)) { \ 162 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 164 a_tmp.diagonal().setZero(); \ 165 else if (IsUnitDiag) \ 166 a_tmp.diagonal().setOnes();\ 168 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 171 lda = convert_index<BlasIndex>(lhsStride); \ 175 BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 178 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 179 res_tmp=res_tmp+b_tmp; \ 189 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 190 template <typename Index, int Mode, \ 191 int LhsStorageOrder, bool ConjugateLhs, \ 192 int RhsStorageOrder, bool ConjugateRhs> \ 193 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 194 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 197 IsLower = (Mode&Lower) == Lower, \ 198 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 199 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 200 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 201 LowUp = IsLower ? Lower : Upper, \ 202 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 206 Index _rows, Index _cols, Index _depth, \ 207 const EIGTYPE* _lhs, Index lhsStride, \ 208 const EIGTYPE* _rhs, Index rhsStride, \ 209 EIGTYPE* res, Index resStride, \ 210 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 212 Index diagSize = (std::min)(_cols,_depth); \ 213 Index rows = _rows; \ 214 Index depth = IsLower ? _depth : diagSize; \ 215 Index cols = IsLower ? 
diagSize : _cols; \ 217 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 218 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 221 if (cols != depth) { \ 225 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 227 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 228 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 229 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 233 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 234 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 235 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 236 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 237 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 238 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ 244 char side = 'R', transa, uplo, diag = 'N'; \ 247 BlasIndex m, n, lda, ldb; \ 250 m = convert_index<BlasIndex>(rows); \ 251 n = convert_index<BlasIndex>(diagSize); \ 254 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 257 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 258 MatrixX##EIGPREFIX b_tmp; \ 260 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 262 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 265 uplo = IsLower ? 'L' : 'U'; \ 266 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 
'U' : 'L'; \ 268 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 271 if ((conjA!=0) || (SetDiag==0)) { \ 272 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 274 a_tmp.diagonal().setZero(); \ 275 else if (IsUnitDiag) \ 276 a_tmp.diagonal().setOnes();\ 278 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 281 lda = convert_index<BlasIndex>(rhsStride); \ 285 BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 288 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 289 res_tmp=res_tmp+b_tmp; \ 302 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
std::complex< float > scomplex
std::complex< double > dcomplex
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)