00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_BLASUTIL_H
00011 #define EIGEN_BLASUTIL_H
00012
00013
00014
00015
00016 namespace Eigen {
00017
00018 namespace internal {
00019
00020
// Forward declaration only; no definition of this template is visible in this
// header. Both conjugation flags default to false.
// NOTE(review): presumably the block-times-panel product kernel, with mr/nr
// the register blocking sizes -- confirm against the definition.
template<
  typename LhsScalar,
  typename RhsScalar,
  typename Index,
  int mr,
  int nr,
  bool ConjugateLhs = false,
  bool ConjugateRhs = false>
struct gebp_kernel;
00023
// Forward declaration only; no definition of this template is visible in this
// header. Conjugate and PanelMode both default to false.
// NOTE(review): presumably packs a right-hand-side block for the product
// kernel -- confirm against the definition.
template<
  typename Scalar,
  typename Index,
  int nr,
  int StorageOrder,
  bool Conjugate = false,
  bool PanelMode = false>
struct gemm_pack_rhs;
00026
// Forward declaration only; no definition of this template is visible in this
// header. Conjugate and PanelMode both default to false.
// NOTE(review): presumably packs a left-hand-side block for the product
// kernel -- confirm against the definition.
template<
  typename Scalar,
  typename Index,
  int Pack1,
  int Pack2,
  int StorageOrder,
  bool Conjugate = false,
  bool PanelMode = false>
struct gemm_pack_lhs;
00029
// Forward declaration only; no definition of this template is visible in this
// header. Each operand is described by a scalar type, a storage order and a
// conjugation flag; the result only by its storage order.
template<
  typename Index,
  typename LhsScalar,
  int LhsStorageOrder,
  bool ConjugateLhs,
  typename RhsScalar,
  int RhsStorageOrder,
  bool ConjugateRhs,
  int ResStorageOrder>
struct general_matrix_matrix_product;
00036
00037 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
00038 struct general_matrix_vector_product;
00039
00040
// Primary template, declared but never defined: only the explicit
// specializations for true and false are usable.
template<bool Conjugate>
struct conj_if;
00042
00043 template<> struct conj_if<true> {
00044 template<typename T>
00045 inline T operator()(const T& x) { return conj(x); }
00046 template<typename T>
00047 inline T pconj(const T& x) { return internal::pconj(x); }
00048 };
00049
00050 template<> struct conj_if<false> {
00051 template<typename T>
00052 inline const T& operator()(const T& x) { return x; }
00053 template<typename T>
00054 inline const T& pconj(const T& x) { return x; }
00055 };
00056
00057 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00058 {
00059 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00060 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00061 };
00062
00063 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00064 {
00065 typedef std::complex<RealScalar> Scalar;
00066 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00067 { return c + pmul(x,y); }
00068
00069 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00070 { return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); }
00071 };
00072
00073 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00074 {
00075 typedef std::complex<RealScalar> Scalar;
00076 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00077 { return c + pmul(x,y); }
00078
00079 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00080 { return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); }
00081 };
00082
00083 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00084 {
00085 typedef std::complex<RealScalar> Scalar;
00086 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00087 { return c + pmul(x,y); }
00088
00089 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00090 { return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); }
00091 };
00092
00093 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00094 {
00095 typedef std::complex<RealScalar> Scalar;
00096 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00097 { return padd(c, pmul(x,y)); }
00098 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00099 { return conj_if<Conj>()(x)*y; }
00100 };
00101
00102 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00103 {
00104 typedef std::complex<RealScalar> Scalar;
00105 EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00106 { return padd(c, pmul(x,y)); }
00107 EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00108 { return x*conj_if<Conj>()(y); }
00109 };
00110
00111 template<typename From,typename To> struct get_factor {
00112 static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
00113 };
00114
00115 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00116 static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); }
00117 };
00118
00119
00120
00121
00122 template<typename Scalar, typename Index, int StorageOrder>
00123 class blas_data_mapper
00124 {
00125 public:
00126 blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00127 EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
00128 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00129 protected:
00130 Scalar* EIGEN_RESTRICT m_data;
00131 Index m_stride;
00132 };
00133
00134
00135 template<typename Scalar, typename Index, int StorageOrder>
00136 class const_blas_data_mapper
00137 {
00138 public:
00139 const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00140 EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
00141 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00142 protected:
00143 const Scalar* EIGEN_RESTRICT m_data;
00144 Index m_stride;
00145 };
00146
00147
00148
00149
00150
00151 template<typename XprType> struct blas_traits
00152 {
00153 typedef typename traits<XprType>::Scalar Scalar;
00154 typedef const XprType& ExtractType;
00155 typedef XprType _ExtractType;
00156 enum {
00157 IsComplex = NumTraits<Scalar>::IsComplex,
00158 IsTransposed = false,
00159 NeedToConjugate = false,
00160 HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
00161 && ( bool(XprType::IsVectorAtCompileTime)
00162 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00163 ) ? 1 : 0
00164 };
00165 typedef typename conditional<bool(HasUsableDirectAccess),
00166 ExtractType,
00167 typename _ExtractType::PlainObject
00168 >::type DirectLinearAccessType;
00169 static inline ExtractType extract(const XprType& x) { return x; }
00170 static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00171 };
00172
00173
00174 template<typename Scalar, typename NestedXpr>
00175 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00176 : blas_traits<NestedXpr>
00177 {
00178 typedef blas_traits<NestedXpr> Base;
00179 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00180 typedef typename Base::ExtractType ExtractType;
00181
00182 enum {
00183 IsComplex = NumTraits<Scalar>::IsComplex,
00184 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00185 };
00186 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00187 static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00188 };
00189
00190
00191 template<typename Scalar, typename NestedXpr>
00192 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
00193 : blas_traits<NestedXpr>
00194 {
00195 typedef blas_traits<NestedXpr> Base;
00196 typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
00197 typedef typename Base::ExtractType ExtractType;
00198 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00199 static inline Scalar extractScalarFactor(const XprType& x)
00200 { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
00201 };
00202
00203
00204 template<typename Scalar, typename NestedXpr>
00205 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00206 : blas_traits<NestedXpr>
00207 {
00208 typedef blas_traits<NestedXpr> Base;
00209 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00210 typedef typename Base::ExtractType ExtractType;
00211 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00212 static inline Scalar extractScalarFactor(const XprType& x)
00213 { return - Base::extractScalarFactor(x.nestedExpression()); }
00214 };
00215
00216
00217 template<typename NestedXpr>
00218 struct blas_traits<Transpose<NestedXpr> >
00219 : blas_traits<NestedXpr>
00220 {
00221 typedef typename NestedXpr::Scalar Scalar;
00222 typedef blas_traits<NestedXpr> Base;
00223 typedef Transpose<NestedXpr> XprType;
00224 typedef Transpose<const typename Base::_ExtractType> ExtractType;
00225 typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00226 typedef typename conditional<bool(Base::HasUsableDirectAccess),
00227 ExtractType,
00228 typename ExtractType::PlainObject
00229 >::type DirectLinearAccessType;
00230 enum {
00231 IsTransposed = Base::IsTransposed ? 0 : 1
00232 };
00233 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00234 static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00235 };
00236
00237 template<typename T>
00238 struct blas_traits<const T>
00239 : blas_traits<T>
00240 {};
00241
00242 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00243 struct extract_data_selector {
00244 static const typename T::Scalar* run(const T& m)
00245 {
00246 return blas_traits<T>::extract(m).data();
00247 }
00248 };
00249
00250 template<typename T>
00251 struct extract_data_selector<T,false> {
00252 static typename T::Scalar* run(const T&) { return 0; }
00253 };
00254
00255 template<typename T> const typename T::Scalar* extract_data(const T& m)
00256 {
00257 return extract_data_selector<T>::run(m);
00258 }
00259
00260 }
00261
00262 }
00263
00264 #endif // EIGEN_BLASUTIL_H