33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H 46 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
50 #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \ 51 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 52 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \ 53 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ 54 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \ 55 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \ 56 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ 59 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 60 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \ 61 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ 62 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \ 63 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \ 64 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ 74 #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 75 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 76 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \ 78 IsLower = (Mode&Lower) == Lower, \ 79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 82 LowUp = IsLower ? Lower : Upper \ 84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \ 85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \ 87 if (ConjLhs || IsZeroDiag) { \ 88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \ 89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ 92 Index size = (std::min)(_rows,_cols); \ 93 Index rows = IsLower ? _rows : size; \ 94 Index cols = IsLower ? size : _cols; \ 96 typedef VectorX##EIGPREFIX VectorRhs; \ 100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \ 102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ 107 char trans, uplo, diag; \ 108 BlasIndex m, n, lda, incx, incy; \ 113 n = convert_index<BlasIndex>(size); \ 114 lda = convert_index<BlasIndex>(lhsStride); \ 116 incy = convert_index<BlasIndex>(resIncr); \ 120 uplo = IsLower ? 'L' : 'U'; \ 121 diag = IsUnitDiag ? 'U' : 'N'; \ 124 BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ 127 BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ 129 if (size<(std::max)(rows,cols)) { \ 130 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ 133 y = _res + size*resIncr; \ 135 m = convert_index<BlasIndex>(rows-size); \ 136 n = convert_index<BlasIndex>(size); \ 141 a = _lhs + size*lda; \ 142 m = convert_index<BlasIndex>(size); \ 143 n = convert_index<BlasIndex>(cols-size); \ 145 BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ 156 #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 157 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 158 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \ 160 IsLower = (Mode&Lower) == Lower, \ 161 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 162 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 163 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 164 LowUp = IsLower ? Lower : Upper \ 166 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \ 167 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \ 170 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \ 171 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ 174 Index size = (std::min)(_rows,_cols); \ 175 Index rows = IsLower ? _rows : size; \ 176 Index cols = IsLower ? size : _cols; \ 178 typedef VectorX##EIGPREFIX VectorRhs; \ 182 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \ 184 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ 189 char trans, uplo, diag; \ 190 BlasIndex m, n, lda, incx, incy; \ 195 n = convert_index<BlasIndex>(size); \ 196 lda = convert_index<BlasIndex>(lhsStride); \ 198 incy = convert_index<BlasIndex>(resIncr); \ 201 trans = ConjLhs ? 'C' : 'T'; \ 202 uplo = IsLower ? 'U' : 'L'; \ 203 diag = IsUnitDiag ? 'U' : 'N'; \ 206 BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ 209 BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ 211 if (size<(std::max)(rows,cols)) { \ 212 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ 215 y = _res + size*resIncr; \ 216 a = _lhs + size*lda; \ 217 m = convert_index<BlasIndex>(rows-size); \ 218 n = convert_index<BlasIndex>(size); \ 224 m = convert_index<BlasIndex>(size); \ 225 n = convert_index<BlasIndex>(cols-size); \ 227 BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ 241 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
std::complex< float > scomplex
std::complex< double > dcomplex
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)