00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
00026 #define EIGEN_TRIANGULARMATRIXVECTOR_H
00027
00028 namespace internal {
00029
// Primary template of the triangular-matrix * vector kernel. It is only declared:
// the StorageOrder argument picks one of the ColMajor / RowMajor partial
// specializations that provide the actual run() implementation.
template<typename Index, int Mode,
         typename LhsScalar, bool ConjLhs,
         typename RhsScalar, bool ConjRhs,
         int StorageOrder>
struct product_triangular_matrix_vector;
00032
00033 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
00034 struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor>
00035 {
00036 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00037 enum {
00038 IsLower = ((Mode&Lower)==Lower),
00039 HasUnitDiag = (Mode & UnitDiag)==UnitDiag
00040 };
00041 static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
00042 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
00043 {
00044 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
00045
00046 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
00047 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
00048 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
00049
00050 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
00051 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
00052 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
00053
00054 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
00055 ResMap res(_res,rows);
00056
00057 for (Index pi=0; pi<cols; pi+=PanelWidth)
00058 {
00059 Index actualPanelWidth = (std::min)(PanelWidth, cols-pi);
00060 for (Index k=0; k<actualPanelWidth; ++k)
00061 {
00062 Index i = pi + k;
00063 Index s = IsLower ? (HasUnitDiag ? i+1 : i ) : pi;
00064 Index r = IsLower ? actualPanelWidth-k : k+1;
00065 if ((!HasUnitDiag) || (--r)>0)
00066 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
00067 if (HasUnitDiag)
00068 res.coeffRef(i) += alpha * cjRhs.coeff(i);
00069 }
00070 Index r = IsLower ? cols - pi - actualPanelWidth : pi;
00071 if (r>0)
00072 {
00073 Index s = IsLower ? pi+actualPanelWidth : 0;
00074 general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
00075 r, actualPanelWidth,
00076 &lhs.coeffRef(s,pi), lhsStride,
00077 &rhs.coeffRef(pi), rhsIncr,
00078 &res.coeffRef(s), resIncr, alpha);
00079 }
00080 }
00081 }
00082 };
00083
00084 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
00085 struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor>
00086 {
00087 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00088 enum {
00089 IsLower = ((Mode&Lower)==Lower),
00090 HasUnitDiag = (Mode & UnitDiag)==UnitDiag
00091 };
00092 static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
00093 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
00094 {
00095 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
00096
00097 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
00098 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
00099 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
00100
00101 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
00102 const RhsMap rhs(_rhs,cols);
00103 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
00104
00105 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
00106 ResMap res(_res,rows,InnerStride<>(resIncr));
00107
00108 for (Index pi=0; pi<cols; pi+=PanelWidth)
00109 {
00110 Index actualPanelWidth = (std::min)(PanelWidth, cols-pi);
00111 for (Index k=0; k<actualPanelWidth; ++k)
00112 {
00113 Index i = pi + k;
00114 Index s = IsLower ? pi : (HasUnitDiag ? i+1 : i);
00115 Index r = IsLower ? k+1 : actualPanelWidth-k;
00116 if ((!HasUnitDiag) || (--r)>0)
00117 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
00118 if (HasUnitDiag)
00119 res.coeffRef(i) += alpha * cjRhs.coeff(i);
00120 }
00121 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
00122 if (r>0)
00123 {
00124 Index s = IsLower ? 0 : pi + actualPanelWidth;
00125 general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
00126 actualPanelWidth, r,
00127 &lhs.coeffRef(pi,s), lhsStride,
00128 &rhs.coeffRef(s), rhsIncr,
00129 &res.coeffRef(pi), resIncr, alpha);
00130 }
00131 }
00132 }
00133 };
00134
00135
00136
00137
00138
00139 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00140 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
00141 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
00142 {};
00143
00144 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00145 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
00146 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
00147 {};
00148
00149
// Dispatcher from a TriangularProduct expression to the low-level kernel; declared
// here and specialized per storage order (ColMajor / RowMajor) further down.
template<int StorageOrder> struct trmv_selector;
00152
00153 }
00154
00155 template<int Mode, typename Lhs, typename Rhs>
00156 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
00157 : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
00158 {
00159 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00160
00161 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00162
00163 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00164 {
00165 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00166
00167 internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
00168 }
00169 };
00170
00171 template<int Mode, typename Lhs, typename Rhs>
00172 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
00173 : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
00174 {
00175 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00176
00177 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00178
00179 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00180 {
00181 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00182
00183 typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
00184 Transpose<Dest> dstT(dst);
00185 internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
00186 TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
00187 }
00188 };
00189
00190 namespace internal {
00191
00192
00193
00194 template<> struct trmv_selector<ColMajor>
00195 {
00196 template<int Mode, typename Lhs, typename Rhs, typename Dest>
00197 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00198 {
00199 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00200 typedef typename ProductType::Index Index;
00201 typedef typename ProductType::LhsScalar LhsScalar;
00202 typedef typename ProductType::RhsScalar RhsScalar;
00203 typedef typename ProductType::Scalar ResScalar;
00204 typedef typename ProductType::RealScalar RealScalar;
00205 typedef typename ProductType::ActualLhsType ActualLhsType;
00206 typedef typename ProductType::ActualRhsType ActualRhsType;
00207 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00208 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00209 typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
00210
00211 const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
00212 const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
00213
00214 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00215 * RhsBlasTraits::extractScalarFactor(prod.rhs());
00216
00217 enum {
00218
00219
00220 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
00221 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
00222 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
00223 };
00224
00225 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
00226
00227 bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
00228 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
00229
00230 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
00231
00232 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
00233 evalToDest ? dest.data() : static_dest.data());
00234
00235 if(!evalToDest)
00236 {
00237 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00238 int size = dest.size();
00239 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00240 #endif
00241 if(!alphaIsCompatible)
00242 {
00243 MappedDest(actualDestPtr, dest.size()).setZero();
00244 compatibleAlpha = RhsScalar(1);
00245 }
00246 else
00247 MappedDest(actualDestPtr, dest.size()) = dest;
00248 }
00249
00250 internal::product_triangular_matrix_vector
00251 <Index,Mode,
00252 LhsScalar, LhsBlasTraits::NeedToConjugate,
00253 RhsScalar, RhsBlasTraits::NeedToConjugate,
00254 ColMajor>
00255 ::run(actualLhs.rows(),actualLhs.cols(),
00256 actualLhs.data(),actualLhs.outerStride(),
00257 actualRhs.data(),actualRhs.innerStride(),
00258 actualDestPtr,1,compatibleAlpha);
00259
00260 if (!evalToDest)
00261 {
00262 if(!alphaIsCompatible)
00263 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
00264 else
00265 dest = MappedDest(actualDestPtr, dest.size());
00266 }
00267 }
00268 };
00269
00270 template<> struct trmv_selector<RowMajor>
00271 {
00272 template<int Mode, typename Lhs, typename Rhs, typename Dest>
00273 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00274 {
00275 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00276 typedef typename ProductType::LhsScalar LhsScalar;
00277 typedef typename ProductType::RhsScalar RhsScalar;
00278 typedef typename ProductType::Scalar ResScalar;
00279 typedef typename ProductType::Index Index;
00280 typedef typename ProductType::ActualLhsType ActualLhsType;
00281 typedef typename ProductType::ActualRhsType ActualRhsType;
00282 typedef typename ProductType::_ActualRhsType _ActualRhsType;
00283 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00284 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00285
00286 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
00287 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
00288
00289 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00290 * RhsBlasTraits::extractScalarFactor(prod.rhs());
00291
00292 enum {
00293 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
00294 };
00295
00296 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
00297
00298 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
00299 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
00300
00301 if(!DirectlyUseRhs)
00302 {
00303 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00304 int size = actualRhs.size();
00305 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00306 #endif
00307 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
00308 }
00309
00310 internal::product_triangular_matrix_vector
00311 <Index,Mode,
00312 LhsScalar, LhsBlasTraits::NeedToConjugate,
00313 RhsScalar, RhsBlasTraits::NeedToConjugate,
00314 RowMajor>
00315 ::run(actualLhs.rows(),actualLhs.cols(),
00316 actualLhs.data(),actualLhs.outerStride(),
00317 actualRhsPtr,1,
00318 dest.data(),dest.innerStride(),
00319 actualAlpha);
00320 }
00321 };
00322
00323 }
00324
00325 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H