00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
00034 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
00035 
00036 namespace Eigen { 
00037 
00038 namespace internal {
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
00047 struct triangular_matrix_vector_product_trmv :
00048   triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
00049 
// Helper for EIGEN_MKL_TRMV_SPECIALIZE: makes the Specialized flavour of
// triangular_matrix_vector_product for one scalar type and one storage order
// forward every call to triangular_matrix_vector_product_trmv.
// The storage order must stay a fixed argument (ColMajor or RowMajor) here;
// leaving it generic would make this partial specialization ambiguous with
// the per-storage-order specializations of triangular_matrix_vector_product.
#define EIGEN_MKL_TRMV_SPECIALIZE_ORDER(Scalar, Order) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,Order,Specialized> \
{ \
  static void run(Index rows, Index cols, const Scalar* lhs, Index lhsStride, \
                  const Scalar* rhs, Index rhsIncr, Scalar* res, Index resIncr, Scalar alpha) \
  { \
    typedef triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,Order> Impl; \
    Impl::run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \
  } \
};

// Routes the Specialized triangular matrix * vector product for Scalar, in
// both column-major and row-major storage, to triangular_matrix_vector_product_trmv.
#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
  EIGEN_MKL_TRMV_SPECIALIZE_ORDER(Scalar, ColMajor) \
  EIGEN_MKL_TRMV_SPECIALIZE_ORDER(Scalar, RowMajor)
00067 
00068 EIGEN_MKL_TRMV_SPECIALIZE(double)
00069 EIGEN_MKL_TRMV_SPECIALIZE(float)
00070 EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
00071 EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
00072 
00073 
// EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
// Specializes triangular_matrix_vector_product_trmv for a column-major
// triangular lhs T of scalar type EIGTYPE, accumulating
//   res += alpha * T * rhs        (rhs conjugated when ConjRhs)
// through MKL:
//   - the leading size x size triangle is handled by ?trmv followed by ?axpy;
//   - a rectangular remainder (extra rows for Lower, extra columns for Upper)
//     is added with ?gemv.
// MKLTYPE is the matching MKL scalar type, EIGPREFIX the suffix of Eigen's
// VectorX* typedef (d, f, cd, cf), MKLPREFIX the BLAS routine prefix (d, s, z, c).
// A conjugated lhs and ZeroDiag modes are delegated to the built-in kernel.
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper \
  }; \
 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
                 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
 { \
   /* BLAS ?trmv offers no conjugate-without-transpose op and no zero-diagonal option. */ \
   if (ConjLhs || IsZeroDiag) { \
     triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
       _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
     return; \
   } \
   /* size x size is the square triangle; Lower may have extra rows below it, Upper extra columns to its right. */ \
   Index size = (std::min)(_rows,_cols); \
   Index rows = IsLower ? _rows : size; \
   Index cols = IsLower ? size : _cols; \
\
   typedef VectorX##EIGPREFIX VectorRhs; \
   EIGTYPE *x, *y; \
\
   /* Contiguous, optionally conjugated copy of rhs: ?trmv overwrites its vector argument. */ \
   Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
   VectorRhs x_tmp; \
   if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
   x = x_tmp.data(); \
\
   char trans, uplo, diag; \
   MKL_INT m, n, lda, incx, incy; \
   EIGTYPE const *a; \
   MKLTYPE alpha_, beta_; \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\
   /* Index may be wider than MKL_INT: narrow explicitly. */ \
   n = (MKL_INT)size; \
   lda = (MKL_INT)lhsStride; \
   incx = 1; \
   incy = (MKL_INT)resIncr; \
\
   trans = 'N'; \
   uplo = IsLower ? 'L' : 'U'; \
   diag = IsUnitDiag ? 'U' : 'N'; \
\
   /* x := T * x on the leading triangle, then res += alpha * x. */ \
   MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
   MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
\
   /* Non-square case: add the rectangular remainder with ?gemv; beta = 1 accumulates into res. */ \
   if (size<(std::max)(rows,cols)) { \
     /* x was overwritten by ?trmv: take a fresh copy of rhs. */ \
     if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
     x = x_tmp.data(); \
     if (size<rows) { \
       /* Lower: lhs rows [size, rows) times rhs, into res rows [size, rows). */ \
       y = _res + size*resIncr; \
       a = _lhs + size; \
       m = (MKL_INT)(rows-size); \
       n = (MKL_INT)size; \
     } \
     else { \
       /* Upper: lhs columns [size, cols) times rhs entries [size, cols). */ \
       x += size; \
       y = _res; \
       a = _lhs + size*lda; \
       m = (MKL_INT)size; \
       n = (MKL_INT)(cols-size); \
     } \
     MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
   } \
  } \
};
00152 
00153 EIGEN_MKL_TRMV_CM(double, double, d, d)
00154 EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
00155 EIGEN_MKL_TRMV_CM(float, float, f, s)
00156 EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
00157 
00158 
// EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
// Specializes triangular_matrix_vector_product_trmv for a row-major
// triangular lhs of scalar type EIGTYPE, accumulating alpha * lhs * rhs into
// res (lhs conjugated when ConjLhs, rhs conjugated when ConjRhs).
// The row-major lhs is handed to MKL as the column-major storage of its
// transpose, so ?trmv/?gemv are called with trans = 'T' ('C' when ConjLhs)
// and with uplo swapped. The leading size x size triangle goes through ?trmv
// followed by ?axpy; a rectangular remainder is added with ?gemv.
// MKLTYPE is the matching MKL scalar type, EIGPREFIX the suffix of Eigen's
// VectorX* typedef (d, f, cd, cf), MKLPREFIX the BLAS routine prefix (d, s, z, c).
// ZeroDiag modes are delegated to the built-in kernel.
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper \
  }; \
 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
                 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
 { \
   /* BLAS ?trmv has no zero-diagonal option. */ \
   if (IsZeroDiag) { \
     triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
       _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
     return; \
   } \
   /* size x size is the square triangle; Lower may have extra rows below it, Upper extra columns to its right. */ \
   Index size = (std::min)(_rows,_cols); \
   Index rows = IsLower ? _rows : size; \
   Index cols = IsLower ? size : _cols; \
\
   typedef VectorX##EIGPREFIX VectorRhs; \
   EIGTYPE *x, *y; \
\
   /* Contiguous, optionally conjugated copy of rhs: ?trmv overwrites its vector argument. */ \
   Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
   VectorRhs x_tmp; \
   if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
   x = x_tmp.data(); \
\
   char trans, uplo, diag; \
   MKL_INT m, n, lda, incx, incy; \
   EIGTYPE const *a; \
   MKLTYPE alpha_, beta_; \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\
   /* Index may be wider than MKL_INT: narrow explicitly. */ \
   n = (MKL_INT)size; \
   lda = (MKL_INT)lhsStride; \
   incx = 1; \
   incy = (MKL_INT)resIncr; \
\
   /* Row-major lhs read as column-major is its transpose: swap uplo, use T (C folds in ConjLhs). */ \
   trans = ConjLhs ? 'C' : 'T'; \
   uplo = IsLower ? 'U' : 'L'; \
   diag = IsUnitDiag ? 'U' : 'N'; \
\
   /* x := op(A) * x on the leading triangle, then res += alpha * x. */ \
   MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
   MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
\
   /* Non-square case: add the rectangular remainder with ?gemv; beta = 1 accumulates into res. */ \
   if (size<(std::max)(rows,cols)) { \
     /* x was overwritten by ?trmv: take a fresh copy of rhs. */ \
     if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
     x = x_tmp.data(); \
     if (size<rows) { \
       /* Lower: lhs rows [size, rows) times rhs, into res rows [size, rows). */ \
       y = _res + size*resIncr; \
       a = _lhs + size*lda; \
       m = (MKL_INT)(rows-size); \
       n = (MKL_INT)size; \
     } \
     else { \
       /* Upper: lhs columns [size, cols) times rhs entries [size, cols). */ \
       x += size; \
       y = _res; \
       a = _lhs + size; \
       m = (MKL_INT)size; \
       n = (MKL_INT)(cols-size); \
     } \
     /* The column-major view of the m x n row-major block is n x m, hence the (n, m) order. */ \
     MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
   } \
  } \
};
00237 
00238 EIGEN_MKL_TRMV_RM(double, double, d, d)
00239 EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
00240 EIGEN_MKL_TRMV_RM(float, float, f, s)
00241 EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
00242 
00243 } 
00244 
00245 } 
00246 
00247 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H