00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
00011 #define EIGEN_TRIANGULARMATRIXVECTOR_H
00012
00013 namespace Eigen {
00014
00015 namespace internal {
00016
/* Low-level triangular matrix * vector kernel: res += alpha * tri(lhs) * rhs,
 * where tri(lhs) is the triangle of lhs selected by Mode.
 * Partially specialized on StorageOrder (ColMajor / RowMajor); those
 * specializations accept any Version (the default is Specialized). */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
struct triangular_matrix_vector_product;
00019
/* Column-major kernel: accumulates res += alpha * tri(lhs) * rhs, where tri(lhs)
 * is the triangle selected by Mode (Lower/Upper, optionally with UnitDiag or
 * ZeroDiag) of the _rows x _cols column-major matrix _lhs.
 *
 * Columns are processed in panels of EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH along the
 * diagonal. Inside a panel, the triangular part is handled column by column with
 * axpy-style updates of res. The rectangular part of the panel columns lying
 * below (Lower) or above (Upper) the diagonal block is delegated to
 * general_matrix_vector_product.
 */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    IsLower = ((Mode&Lower)==Lower),
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
  };
  /* _lhs/lhsStride : column-major lhs data and its outer stride
   * _rhs/rhsIncr   : rhs vector; consecutive entries are rhsIncr apart
   * _res/resIncr   : result vector, accumulated into (not overwritten)
   * alpha          : factor applied to the product before accumulation
   * NOTE(review): the triangular updates map _res contiguously and ignore
   * resIncr. resIncr is only forwarded to general_matrix_vector_product, so
   * callers presumably must pass resIncr==1 (trmv_selector<ColMajor> does). */
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
  {
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Only the leading size x size square contains the triangle.
    // - Extra rows of a Lower matrix are covered by the per-panel gemv calls.
    // - Extra columns of an Upper matrix are handled after the panel loop.
    Index size = (std::min)(_rows,_cols);
    Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
    Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;

    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    // presumably a conjugated view of lhs when ConjLhs is set
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
    const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
    ResMap res(_res,rows);

    for (Index pi=0; pi<size; pi+=PanelWidth)
    {
      Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        Index i = pi + k;
        // Column i's part inside the diagonal block spans rows [s, s+r):
        // from i down to the panel end (Lower), or from the panel start down to i (Upper).
        Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
        Index r = IsLower ? actualPanelWidth-k : k+1;
        // For an implicit (UnitDiag) or zero (ZeroDiag) diagonal, --r drops the
        // diagonal entry. The update is skipped if no entry is left.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
        // Implicit unit diagonal: contributes alpha * rhs(i).
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Rectangular block formed by the panel columns [pi, pi+actualPanelWidth)
      // and the rows below the diagonal block (Lower), or the rows [0, pi) above it (Upper).
      Index r = IsLower ? rows - pi - actualPanelWidth : pi;
      if (r>0)
      {
        Index s = IsLower ? pi+actualPanelWidth : 0;
        general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
            r, actualPanelWidth,
            &lhs.coeffRef(s,pi), lhsStride,
            &rhs.coeffRef(pi), rhsIncr,
            &res.coeffRef(s), resIncr, alpha);
      }
    }
    // Upper case with more columns than rows: the trailing columns form a
    // rectangular block multiplied by the tail of rhs.
    if((!IsLower) && cols>size)
    {
      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
          rows, cols-size,
          &lhs.coeffRef(0,size), lhsStride,
          &rhs.coeffRef(size), rhsIncr,
          _res, resIncr, alpha);
    }
  }
};
00082
/* Row-major kernel: accumulates res += alpha * tri(lhs) * rhs, where tri(lhs)
 * is the triangle selected by Mode (Lower/Upper, optionally with UnitDiag or
 * ZeroDiag) of the _rows x _cols row-major matrix _lhs.
 *
 * Rows are processed in panels of EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH along the
 * diagonal. Inside a panel, each row's triangular part is computed as a dot
 * product with rhs. The rectangular part of the panel rows lying left of (Lower)
 * or right of (Upper) the diagonal block is delegated to
 * general_matrix_vector_product.
 */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    IsLower = ((Mode&Lower)==Lower),
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
  };
  /* _lhs/lhsStride : row-major lhs data and its outer stride
   * _res/resIncr   : result vector, accessed with stride resIncr and accumulated into
   * alpha          : factor applied to the product before accumulation
   * NOTE(review): rhs is mapped contiguously for the per-row dot products, while
   * rhsIncr is only forwarded to general_matrix_vector_product. Callers
   * presumably must pass rhsIncr==1 (trmv_selector<RowMajor> does). */
  static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
  {
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Only the leading diagSize x diagSize square contains the triangle.
    // - Extra columns of an Upper matrix are covered by the per-panel gemv calls.
    // - Extra rows of a Lower matrix are handled after the panel loop.
    Index diagSize = (std::min)(_rows,_cols);
    Index rows = IsLower ? _rows : diagSize;
    Index cols = IsLower ? diagSize : _cols;

    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    // presumably a conjugated view of lhs when ConjLhs is set
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
    const RhsMap rhs(_rhs,cols);
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
    ResMap res(_res,rows,InnerStride<>(resIncr));

    for (Index pi=0; pi<diagSize; pi+=PanelWidth)
    {
      Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        Index i = pi + k;
        // Row i's part inside the diagonal block spans columns [s, s+r):
        // from the panel start up to i (Lower), or from i to the panel end (Upper).
        Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
        Index r = IsLower ? k+1 : actualPanelWidth-k;
        // For an implicit (UnitDiag) or zero (ZeroDiag) diagonal, --r drops the
        // diagonal entry. The dot product is skipped if no entry is left.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
        // Implicit unit diagonal: contributes alpha * rhs(i).
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Rectangular block formed by the panel rows [pi, pi+actualPanelWidth)
      // and the columns [0, pi) left of the diagonal block (Lower), or the
      // columns right of it up to cols (Upper).
      Index r = IsLower ? pi : cols - pi - actualPanelWidth;
      if (r>0)
      {
        Index s = IsLower ? 0 : pi + actualPanelWidth;
        general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
            actualPanelWidth, r,
            &lhs.coeffRef(pi,s), lhsStride,
            &rhs.coeffRef(s), rhsIncr,
            &res.coeffRef(pi), resIncr, alpha);
      }
    }
    // Lower case with more rows than columns: the trailing rows form a
    // rectangular block multiplied by the whole rhs.
    if(IsLower && rows>diagSize)
    {
      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
            rows-diagSize, cols,
            &lhs.coeffRef(diagSize,0), lhsStride,
            &rhs.coeffRef(0), rhsIncr,
            &res.coeffRef(diagSize), resIncr, alpha);
    }
  }
};
00145
00146
00147
00148
00149
/* Traits for the two TriangularProduct shapes handled in this file. Both inherit
 * everything from the generic ProductBase traits of the same product type.
 * The false/true template arguments presumably flag which operand is a vector
 * (TODO confirm against TriangularProduct's declaration). */
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};

template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
00159
00160
/* Dispatches a triangular-matrix * vector product to the kernel that matches
 * the storage order of the triangular operand. Specialized for ColMajor and
 * RowMajor. */
template<int StorageOrder>
struct trmv_selector;
00163
00164 }
00165
00166 template<int Mode, typename Lhs, typename Rhs>
00167 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
00168 : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
00169 {
00170 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00171
00172 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00173
00174 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00175 {
00176 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00177
00178 internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
00179 }
00180 };
00181
00182 template<int Mode, typename Lhs, typename Rhs>
00183 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
00184 : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
00185 {
00186 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00187
00188 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00189
00190 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00191 {
00192 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00193
00194 typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
00195 Transpose<Dest> dstT(dst);
00196 internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
00197 TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
00198 }
00199 };
00200
00201 namespace internal {
00202
00203
00204
00205 template<> struct trmv_selector<ColMajor>
00206 {
00207 template<int Mode, typename Lhs, typename Rhs, typename Dest>
00208 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00209 {
00210 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00211 typedef typename ProductType::Index Index;
00212 typedef typename ProductType::LhsScalar LhsScalar;
00213 typedef typename ProductType::RhsScalar RhsScalar;
00214 typedef typename ProductType::Scalar ResScalar;
00215 typedef typename ProductType::RealScalar RealScalar;
00216 typedef typename ProductType::ActualLhsType ActualLhsType;
00217 typedef typename ProductType::ActualRhsType ActualRhsType;
00218 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00219 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00220 typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
00221
00222 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
00223 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
00224
00225 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00226 * RhsBlasTraits::extractScalarFactor(prod.rhs());
00227
00228 enum {
00229
00230
00231 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
00232 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
00233 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
00234 };
00235
00236 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
00237
00238 bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
00239 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
00240
00241 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
00242
00243 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
00244 evalToDest ? dest.data() : static_dest.data());
00245
00246 if(!evalToDest)
00247 {
00248 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00249 int size = dest.size();
00250 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00251 #endif
00252 if(!alphaIsCompatible)
00253 {
00254 MappedDest(actualDestPtr, dest.size()).setZero();
00255 compatibleAlpha = RhsScalar(1);
00256 }
00257 else
00258 MappedDest(actualDestPtr, dest.size()) = dest;
00259 }
00260
00261 internal::triangular_matrix_vector_product
00262 <Index,Mode,
00263 LhsScalar, LhsBlasTraits::NeedToConjugate,
00264 RhsScalar, RhsBlasTraits::NeedToConjugate,
00265 ColMajor>
00266 ::run(actualLhs.rows(),actualLhs.cols(),
00267 actualLhs.data(),actualLhs.outerStride(),
00268 actualRhs.data(),actualRhs.innerStride(),
00269 actualDestPtr,1,compatibleAlpha);
00270
00271 if (!evalToDest)
00272 {
00273 if(!alphaIsCompatible)
00274 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
00275 else
00276 dest = MappedDest(actualDestPtr, dest.size());
00277 }
00278 }
00279 };
00280
00281 template<> struct trmv_selector<RowMajor>
00282 {
00283 template<int Mode, typename Lhs, typename Rhs, typename Dest>
00284 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00285 {
00286 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00287 typedef typename ProductType::LhsScalar LhsScalar;
00288 typedef typename ProductType::RhsScalar RhsScalar;
00289 typedef typename ProductType::Scalar ResScalar;
00290 typedef typename ProductType::Index Index;
00291 typedef typename ProductType::ActualLhsType ActualLhsType;
00292 typedef typename ProductType::ActualRhsType ActualRhsType;
00293 typedef typename ProductType::_ActualRhsType _ActualRhsType;
00294 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00295 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00296
00297 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
00298 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
00299
00300 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00301 * RhsBlasTraits::extractScalarFactor(prod.rhs());
00302
00303 enum {
00304 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
00305 };
00306
00307 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
00308
00309 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
00310 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
00311
00312 if(!DirectlyUseRhs)
00313 {
00314 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00315 int size = actualRhs.size();
00316 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00317 #endif
00318 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
00319 }
00320
00321 internal::triangular_matrix_vector_product
00322 <Index,Mode,
00323 LhsScalar, LhsBlasTraits::NeedToConjugate,
00324 RhsScalar, RhsBlasTraits::NeedToConjugate,
00325 RowMajor>
00326 ::run(actualLhs.rows(),actualLhs.cols(),
00327 actualLhs.data(),actualLhs.outerStride(),
00328 actualRhsPtr,1,
00329 dest.data(),dest.innerStride(),
00330 actualAlpha);
00331 }
00332 };
00333
00334 }
00335
00336 }
00337
00338 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H