00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
00034 #define EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
00035
00036 namespace Eigen {
00037
00038 namespace internal {
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs>
00050 struct general_matrix_vector_product_gemv :
00051 general_matrix_vector_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,ConjugateRhs,BuiltIn> {};
00052
// Generates the Specialized general_matrix_vector_product entry points for
// one scalar type, for both lhs storage orders.
//  - ColMajor: forwards to the MKL-backed general_matrix_vector_product_gemv,
//    except when the lhs must be conjugated. The ColMajor MKL path calls gemv
//    with trans='N', which cannot conjugate, so that case uses the BuiltIn kernel.
//  - RowMajor: always forwards to general_matrix_vector_product_gemv, which
//    expresses a conjugated lhs through trans='C'.
#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs,Specialized> { \
  static void run( \
    Index rows, Index cols, \
    const Scalar* lhs, Index lhsStride, \
    const Scalar* rhs, Index rhsIncr, \
    Scalar* res, Index resIncr, Scalar alpha) \
  { \
    typedef general_matrix_vector_product_gemv<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs> MklKernel; \
    typedef general_matrix_vector_product<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs,BuiltIn> EigenKernel; \
    if (!ConjugateLhs) { \
      MklKernel::run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \
      return; \
    } \
    /* trans='N' cannot conjugate a column-major lhs: use Eigen's kernel. */ \
    EigenKernel::run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \
  } \
}; \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,RowMajor,ConjugateLhs,Scalar,ConjugateRhs,Specialized> { \
  static void run( \
    Index rows, Index cols, \
    const Scalar* lhs, Index lhsStride, \
    const Scalar* rhs, Index rhsIncr, \
    Scalar* res, Index resIncr, Scalar alpha) \
  { \
    typedef general_matrix_vector_product_gemv<Index,Scalar,RowMajor,ConjugateLhs,Scalar,ConjugateRhs> MklKernel; \
    MklKernel::run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \
  } \
};
00083
00084 EIGEN_MKL_GEMV_SPECIALIZE(double)
00085 EIGEN_MKL_GEMV_SPECIALIZE(float)
00086 EIGEN_MKL_GEMV_SPECIALIZE(dcomplex)
00087 EIGEN_MKL_GEMV_SPECIALIZE(scomplex)
00088
// Defines the MKL-backed general_matrix_vector_product_gemv for one scalar
// type. It computes res += alpha * op(lhs) * rhs with a single MKL ?gemv call.
//   EIGTYPE   - Eigen scalar type (double, float, dcomplex, scomplex)
//   MKLTYPE   - matching MKL scalar type handed to the BLAS routine
//   MKLPREFIX - BLAS precision prefix; MKLPREFIX##gemv names the routine
// A RowMajor lhs is passed to MKL as the transpose of a column-major matrix
// ('T', or 'C' when ConjugateLhs), so m/n are swapped in that case. A ColMajor
// lhs is passed with 'N', which ignores ConjugateLhs; callers must route a
// conjugated column-major lhs elsewhere.
// Comments inside the macro body use /* */ because every line ends in a
// backslash continuation.
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
{ \
  typedef Matrix<EIGTYPE,Dynamic,1,ColMajor> GEMVVector; \
\
  static void run( \
    Index rows, Index cols, \
    const EIGTYPE* lhs, Index lhsStride, \
    const EIGTYPE* rhs, Index rhsIncr, \
    EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
  { \
    /* An empty lhs adds nothing to res. Returning early also keeps MKL from \
       seeing an lda below max(1,m) (e.g. a zero stride on empty storage), \
       which BLAS rejects as an invalid argument before its quick return. */ \
    if (rows == 0 || cols == 0) return; \
    /* NOTE(review): Index values are narrowed to MKL_INT here; sizes beyond \
       the MKL_INT range (e.g. with a 32-bit LP64 interface) would be \
       truncated - confirm the MKL interface in use. */ \
    MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \
    MKLTYPE alpha_, beta_; \
    const EIGTYPE myone(1); \
    const EIGTYPE *x_ptr; \
    /* Row-major lhs == transpose of a column-major one; 'C' folds in conj. */ \
    char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
    if (LhsStorageOrder==RowMajor) { \
      m=cols; \
      n=rows; \
    } \
    assign_scalar_eig2mkl(alpha_, alpha); \
    /* beta = 1: gemv accumulates into res instead of overwriting it. */ \
    assign_scalar_eig2mkl(beta_, myone); \
    GEMVVector x_tmp; \
    if (ConjugateRhs) { \
      /* gemv has no conjugate-x option: copy conj(rhs) into a contiguous \
         temporary of length cols and pass it with unit increment. */ \
      Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \
      x_tmp=map_x.conjugate(); \
      x_ptr=x_tmp.data(); \
      incx=1; \
    } else { \
      x_ptr=rhs; \
    } \
    MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
  } \
};
00121
00122 EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d)
00123 EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s)
00124 EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z)
00125 EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c)
00126
00127 }
00128
00129 }
00130
00131 #endif // EIGEN_GENERAL_MATRIX_VECTOR_MKL_H