00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
00034 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
00035
00036 namespace Eigen {
00037
00038 namespace internal {
00039
00040
00041 template <typename Scalar, typename Index,
00042 int Mode, bool LhsIsTriangular,
00043 int LhsStorageOrder, bool ConjugateLhs,
00044 int RhsStorageOrder, bool ConjugateRhs,
00045 int ResStorageOrder>
00046 struct product_triangular_matrix_matrix_trmm :
00047 product_triangular_matrix_matrix<Scalar,Index,Mode,
00048 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
00049 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
00050
00051
00052
// Hooks the MKL path into Eigen's triangular * general product: for the given
// scalar type and triangular side, the `Specialized` column-major-result
// variant of product_triangular_matrix_matrix forwards every call unchanged to
// product_triangular_matrix_matrix_trmm (whose primary template uses BuiltIn).
//   EIGTYPE            - scalar type of both operands and of the result
//   TRIANGULAR_ON_LEFT - true when the triangular factor is the lhs
#define EIGEN_MKL_TRMM_SPECIALIZE(EIGTYPE, TRIANGULAR_ON_LEFT) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<EIGTYPE, Index, Mode, TRIANGULAR_ON_LEFT, \
                                        LhsStorageOrder, ConjugateLhs, \
                                        RhsStorageOrder, ConjugateRhs, \
                                        ColMajor, Specialized> \
{ \
  static inline void run(Index rows, Index cols, Index depth, \
                         const EIGTYPE* lhs, Index lhsStride, \
                         const EIGTYPE* rhs, Index rhsStride, \
                         EIGTYPE* res, Index resStride, \
                         EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    typedef product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, TRIANGULAR_ON_LEFT, \
                                                  LhsStorageOrder, ConjugateLhs, \
                                                  RhsStorageOrder, ConjugateRhs, \
                                                  ColMajor> Impl; \
    Impl::run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride, \
              res, resStride, alpha, blocking); \
  } \
};
00067
00068 EIGEN_MKL_TRMM_SPECIALIZE(double, true)
00069 EIGEN_MKL_TRMM_SPECIALIZE(double, false)
00070 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
00071 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
00072 EIGEN_MKL_TRMM_SPECIALIZE(float, true)
00073 EIGEN_MKL_TRMM_SPECIALIZE(float, false)
00074 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
00075 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
00076
00077
// MKL specialization of product_triangular_matrix_matrix_trmm for a triangular
// LEFT-hand side and a column-major result. Accumulates
//   res += alpha * op(tri(lhs)) * op(rhs)
// where tri() is the triangular view selected by Mode and op() applies the
// conjugation requested by ConjugateLhs / ConjugateRhs.
//
// Strategy:
//  * empty product: nothing to add, return immediately;
//  * trapezoidal (non-square) lhs: ?trmm cannot be used, so either Eigen's
//    BuiltIn kernel (single MKL thread and nearly square lhs) or ?gemm on an
//    explicitly materialized triangular copy of lhs;
//  * square lhs: MKL ?trmm on a column-major copy of rhs, then added into res.
//
// Macro arguments:
//   EIGTYPE   - Eigen scalar type (double, float, dcomplex, scomplex)
//   MKLTYPE   - matching MKL scalar type
//   EIGPREFIX - suffix of the Eigen MatrixX* typedef (d, f, cd, cf)
//   MKLPREFIX - MKL routine prefix (d, s, z, c), giving ?trmm
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    /* Columns (lower) or rows (upper) of tri(lhs) past diagSize are zero: drop them. */ \
    Index diagSize = (std::min)(_rows,_depth); \
    Index rows     = IsLower ? _rows : diagSize; \
    Index depth    = IsLower ? diagSize : _depth; \
    Index cols     = _cols; \
\
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
    /* Empty product: nothing to accumulate into res. Returning here also keeps */ \
    /* diagSize > 0 for the ratio below and every ?trmm leading dimension >= 1. */ \
    if (rows == 0 || cols == 0 || depth == 0) return; \
\
    /* Trapezoidal lhs: ?trmm requires a square triangular operand. */ \
    if (rows != depth) { \
      int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
\
      if ((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5)) { \
        /* Single-threaded and nearly square: use Eigen's own kernel. */ \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
          LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
      } else { \
        /* Materialize tri(lhs) as a dense matrix and run a general product. */ \
        Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
        MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
        /* Index, not MKL_INT: the stride is consumed by an Eigen kernel taking Index. */ \
        Index aStride = aa_tmp.outerStride(); \
        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
          rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
      } \
      return; \
    } \
\
    char side = 'L', transa, uplo, diag = 'N'; \
    EIGTYPE *b; \
    const EIGTYPE *a; \
    MKL_INT m, n, lda, ldb; \
    MKLTYPE alpha_; \
\
    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
    /* ?trmm (side 'L') computes B := alpha * op(A) * B, A m x m, B m x n. */ \
    m = (MKL_INT)diagSize; \
    n = (MKL_INT)cols; \
\
    /* MKL sees row-major storage as the transpose; 'C' also conjugates. */ \
    transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
\
    /* ?trmm overwrites B, so work on a column-major (optionally conjugated) copy of rhs. */ \
    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
    MatrixX##EIGPREFIX b_tmp; \
\
    if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
    b = b_tmp.data(); \
    ldb = (MKL_INT)b_tmp.outerStride(); \
\
    /* Viewing row-major storage as its transpose swaps the stored triangle. */ \
    uplo = IsLower ? 'L' : 'U'; \
    if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
    /* BLAS trans has no conjugate-only option and diag is always 'N' here, so a  */ \
    /* conjugated column-major lhs or a unit/zero diagonal is handled on a copy.  */ \
    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
    MatrixLhs a_tmp; \
\
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
      if (IsZeroDiag) \
        a_tmp.diagonal().setZero(); \
      else if (IsUnitDiag) \
        a_tmp.diagonal().setOnes(); \
      a = a_tmp.data(); \
      lda = (MKL_INT)a_tmp.outerStride(); \
    } else { \
      a = _lhs; \
      lda = (MKL_INT)lhsStride; \
    } \
\
    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
\
    /* b_tmp now holds alpha * op(tri(lhs)) * op(rhs); accumulate it into res. */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
    res_tmp=res_tmp+b_tmp; \
  } \
};
00185
00186 EIGEN_MKL_TRMM_L(double, double, d, d)
00187 EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
00188 EIGEN_MKL_TRMM_L(float, float, f, s)
00189 EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
00190
00191
// MKL specialization of product_triangular_matrix_matrix_trmm for a triangular
// RIGHT-hand side and a column-major result. Accumulates
//   res += alpha * op(lhs) * op(tri(rhs))
// where tri() is the triangular view selected by Mode and op() applies the
// conjugation requested by ConjugateLhs / ConjugateRhs.
//
// Strategy:
//  * empty product: nothing to add, return immediately;
//  * trapezoidal (non-square) rhs: ?trmm cannot be used, so either Eigen's
//    BuiltIn kernel (single MKL thread and nearly square rhs) or ?gemm on an
//    explicitly materialized triangular copy of rhs;
//  * square rhs: MKL ?trmm (side 'R') on a column-major copy of lhs, then
//    added into res.
//
// Macro arguments:
//   EIGTYPE   - Eigen scalar type (double, float, dcomplex, scomplex)
//   MKLTYPE   - matching MKL scalar type
//   EIGPREFIX - suffix of the Eigen MatrixX* typedef (d, f, cd, cf)
//   MKLPREFIX - MKL routine prefix (d, s, z, c), giving ?trmm
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    /* Columns (lower) or rows (upper) of tri(rhs) past diagSize are zero: drop them. */ \
    Index diagSize = (std::min)(_cols,_depth); \
    Index rows     = _rows; \
    Index depth    = IsLower ? _depth : diagSize; \
    Index cols     = IsLower ? diagSize : _cols; \
\
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
    /* Empty product: nothing to accumulate into res. Returning here also keeps */ \
    /* diagSize > 0 for the ratio below and every ?trmm leading dimension >= 1. */ \
    if (rows == 0 || cols == 0 || depth == 0) return; \
\
    /* Trapezoidal rhs: ?trmm requires a square triangular operand. */ \
    if (cols != depth) { \
      int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
\
      if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
        /* Single-threaded and nearly square: use Eigen's own kernel. */ \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
          LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
      } else { \
        /* Materialize tri(rhs) as a dense matrix and run a general product. */ \
        Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
        MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
        /* Index, not MKL_INT: the stride is consumed by an Eigen kernel taking Index. */ \
        Index aStride = aa_tmp.outerStride(); \
        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
          rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
      } \
      return; \
    } \
\
    char side = 'R', transa, uplo, diag = 'N'; \
    EIGTYPE *b; \
    const EIGTYPE *a; \
    MKL_INT m, n, lda, ldb; \
    MKLTYPE alpha_; \
\
    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
    /* ?trmm (side 'R') computes B := alpha * B * op(A), B m x n, A n x n. */ \
    m = (MKL_INT)rows; \
    n = (MKL_INT)diagSize; \
\
    /* MKL sees row-major storage as the transpose; 'C' also conjugates. */ \
    transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
    /* ?trmm overwrites B, so work on a column-major (optionally conjugated) copy of lhs. */ \
    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
    MatrixX##EIGPREFIX b_tmp; \
\
    if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
    b = b_tmp.data(); \
    ldb = (MKL_INT)b_tmp.outerStride(); \
\
    /* Viewing row-major storage as its transpose swaps the stored triangle. */ \
    uplo = IsLower ? 'L' : 'U'; \
    if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
    /* BLAS trans has no conjugate-only option and diag is always 'N' here, so a  */ \
    /* conjugated column-major rhs or a unit/zero diagonal is handled on a copy.  */ \
    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
    MatrixRhs a_tmp; \
\
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
      if (IsZeroDiag) \
        a_tmp.diagonal().setZero(); \
      else if (IsUnitDiag) \
        a_tmp.diagonal().setOnes(); \
      a = a_tmp.data(); \
      lda = (MKL_INT)a_tmp.outerStride(); \
    } else { \
      a = _rhs; \
      lda = (MKL_INT)rhsStride; \
    } \
\
    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
\
    /* b_tmp now holds alpha * op(lhs) * op(tri(rhs)); accumulate it into res. */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
    res_tmp=res_tmp+b_tmp; \
  } \
};
00299
00300 EIGEN_MKL_TRMM_R(double, double, d, d)
00301 EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
00302 EIGEN_MKL_TRMM_R(float, float, f, s)
00303 EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
00304
00305 }
00306
00307 }
00308
00309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H