00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef EIGEN_DIAGONALPRODUCT_H
00012 #define EIGEN_DIAGONALPRODUCT_H
00013
00014 namespace Eigen {
00015
00016 namespace internal {
00017 template<typename MatrixType, typename DiagonalType, int ProductOrder>
00018 struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
00019 : traits<MatrixType>
00020 {
00021 typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
00022 enum {
00023 RowsAtCompileTime = MatrixType::RowsAtCompileTime,
00024 ColsAtCompileTime = MatrixType::ColsAtCompileTime,
00025 MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
00026 MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
00027
00028 _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
00029 _PacketOnDiag = !((int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
00030 ||(int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)),
00031 _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
00032
00033
00034 _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && ((!_PacketOnDiag) || (bool(int(DiagonalType::Flags)&PacketAccessBit))),
00035
00036 Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),
00037 CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
00038 };
00039 };
00040 }
00041
00042 template<typename MatrixType, typename DiagonalType, int ProductOrder>
00043 class DiagonalProduct : internal::no_assignment_operator,
00044 public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
00045 {
00046 public:
00047
00048 typedef MatrixBase<DiagonalProduct> Base;
00049 EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
00050
00051 inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
00052 : m_matrix(matrix), m_diagonal(diagonal)
00053 {
00054 eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
00055 }
00056
00057 inline Index rows() const { return m_matrix.rows(); }
00058 inline Index cols() const { return m_matrix.cols(); }
00059
00060 const Scalar coeff(Index row, Index col) const
00061 {
00062 return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
00063 }
00064
00065 template<int LoadMode>
00066 EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
00067 {
00068 enum {
00069 StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
00070 };
00071 const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
00072
00073 return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
00074 ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
00075 ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
00076 }
00077
00078 protected:
00079 template<int LoadMode>
00080 EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
00081 {
00082 return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
00083 internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
00084 }
00085
00086 template<int LoadMode>
00087 EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
00088 {
00089 enum {
00090 InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
00091 DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned
00092 };
00093 return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
00094 m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
00095 }
00096
00097 typename MatrixType::Nested m_matrix;
00098 typename DiagonalType::Nested m_diagonal;
00099 };
00100
00103 template<typename Derived>
00104 template<typename DiagonalDerived>
00105 inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
00106 MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &diagonal) const
00107 {
00108 return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), diagonal.derived());
00109 }
00110
00113 template<typename DiagonalDerived>
00114 template<typename MatrixDerived>
00115 inline const DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>
00116 DiagonalBase<DiagonalDerived>::operator*(const MatrixBase<MatrixDerived> &matrix) const
00117 {
00118 return DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>(matrix.derived(), derived());
00119 }
00120
00121 }
00122
00123 #endif // EIGEN_DIAGONALPRODUCT_H