00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
00026 #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
00027
00028 namespace internal {
00029
00030
00031
00032
00033
00034
00035
00036
00037
/** \internal
  * Forward declaration of the kernel that adds the UpLo triangle of a packed
  * (blockA * blockB) product into a square diagonal block of the result.
  * Template arguments, in order: lhs scalar, rhs scalar, index type, the gebp
  * micro-kernel sizes mr and nr, the lhs/rhs conjugation flags, and the UpLo
  * triangle selector. The definition appears later in this file.
  */
template<class LhsScalar, class RhsScalar, class Index,
         int mr, int nr,
         bool ConjLhs, bool ConjRhs,
         int UpLo>
struct tribb_kernel;
00040
00041
/** \internal
  * Primary template (declared only) for accumulating alpha * lhs * rhs into
  * just the UpLo triangular half of a square result matrix.
  * Only the partial specializations selected by ResStorageOrder
  * (RowMajor / ColMajor) are defined.
  */
template <class Index,
          class LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
          class RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
          int ResStorageOrder,
          int UpLo>
struct general_matrix_matrix_triangular_product;
00047
00048
00049 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00050 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo>
00051 struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo>
00052 {
00053 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00054 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
00055 const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha)
00056 {
00057 general_matrix_matrix_triangular_product<Index,
00058 RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
00059 LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
00060 ColMajor, UpLo==Lower?Upper:Lower>
00061 ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha);
00062 }
00063 };
00064
00065 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00066 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo>
00067 struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo>
00068 {
00069 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00070 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
00071 const RhsScalar* _rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha)
00072 {
00073 const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
00074 const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
00075
00076 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
00077
00078 Index kc = depth;
00079 Index mc = size;
00080 Index nc = size;
00081 computeProductBlockingSizes<LhsScalar,RhsScalar>(kc, mc, nc);
00082
00083 if(mc > Traits::nr)
00084 mc = (mc/Traits::nr)*Traits::nr;
00085
00086 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00087 std::size_t sizeB = sizeW + kc*size;
00088 ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, kc*mc, 0);
00089 ei_declare_aligned_stack_constructed_variable(RhsScalar, allocatedBlockB, sizeB, 0);
00090 RhsScalar* blockB = allocatedBlockB + sizeW;
00091
00092 gemm_pack_lhs<LhsScalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
00093 gemm_pack_rhs<RhsScalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
00094 gebp_kernel <LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
00095 tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb;
00096
00097 for(Index k2=0; k2<depth; k2+=kc)
00098 {
00099 const Index actual_kc = std::min(k2+kc,depth)-k2;
00100
00101
00102 pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, size);
00103
00104 for(Index i2=0; i2<size; i2+=mc)
00105 {
00106 const Index actual_mc = std::min(i2+mc,size)-i2;
00107
00108 pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
00109
00110
00111
00112
00113
00114 if (UpLo==Lower)
00115 gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2), alpha,
00116 -1, -1, 0, 0, allocatedBlockB);
00117
00118 sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha, allocatedBlockB);
00119
00120 if (UpLo==Upper)
00121 {
00122 Index j2 = i2+actual_mc;
00123 gebp(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(Index(0), size-j2), alpha,
00124 -1, -1, 0, 0, allocatedBlockB);
00125 }
00126 }
00127 }
00128 }
00129 };
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
00141 struct tribb_kernel
00142 {
00143 typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
00144 typedef typename Traits::ResScalar ResScalar;
00145
00146 enum {
00147 BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr)
00148 };
00149 void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, ResScalar alpha, RhsScalar* workspace)
00150 {
00151 gebp_kernel<LhsScalar, RhsScalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel;
00152 Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer;
00153
00154
00155
00156 for (Index j=0; j<size; j+=BlockSize)
00157 {
00158 Index actualBlockSize = std::min<Index>(BlockSize,size - j);
00159 const RhsScalar* actual_b = blockB+j*depth;
00160
00161 if(UpLo==Upper)
00162 gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize, alpha,
00163 -1, -1, 0, 0, workspace);
00164
00165
00166 {
00167 Index i = j;
00168 buffer.setZero();
00169
00170 gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
00171 -1, -1, 0, 0, workspace);
00172
00173 for(Index j1=0; j1<actualBlockSize; ++j1)
00174 {
00175 ResScalar* r = res + (j+j1)*resStride + i;
00176 for(Index i1=UpLo==Lower ? j1 : 0;
00177 UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
00178 r[i1] += buffer(i1,j1);
00179 }
00180 }
00181
00182 if(UpLo==Lower)
00183 {
00184 Index i = j+actualBlockSize;
00185 gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize, alpha,
00186 -1, -1, 0, 0, workspace);
00187 }
00188 }
00189 }
00190 };
00191
00192 }
00193
00194
00195
00196 template<typename MatrixType, unsigned int UpLo>
00197 template<typename ProductDerived, typename _Lhs, typename _Rhs>
00198 TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
00199 {
00200 typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs;
00201 typedef internal::blas_traits<Lhs> LhsBlasTraits;
00202 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
00203 typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
00204 const ActualLhs actualLhs = LhsBlasTraits::extract(prod.lhs());
00205
00206 typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs;
00207 typedef internal::blas_traits<Rhs> RhsBlasTraits;
00208 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
00209 typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
00210 const ActualRhs actualRhs = RhsBlasTraits::extract(prod.rhs());
00211
00212 typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
00213
00214 internal::general_matrix_matrix_triangular_product<Index,
00215 typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
00216 typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
00217 MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
00218 ::run(m_matrix.cols(), actualLhs.cols(),
00219 &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
00220 const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha);
00221
00222 return *this;
00223 }
00224
00225 #endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H