TriangularMatrixMatrix_MKL.h
Go to the documentation of this file.
00001 /*
00002  Copyright (c) 2011, Intel Corporation. All rights reserved.
00003 
00004  Redistribution and use in source and binary forms, with or without modification,
00005  are permitted provided that the following conditions are met:
00006 
00007  * Redistributions of source code must retain the above copyright notice, this
00008    list of conditions and the following disclaimer.
00009  * Redistributions in binary form must reproduce the above copyright notice,
00010    this list of conditions and the following disclaimer in the documentation
00011    and/or other materials provided with the distribution.
00012  * Neither the name of Intel Corporation nor the names of its contributors may
00013    be used to endorse or promote products derived from this software without
00014    specific prior written permission.
00015 
00016  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
00017  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00018  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00019  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
00020  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00021  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00022  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00023  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00024  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00025  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00026 
00027  ********************************************************************************
00028  *   Content : Eigen bindings to Intel(R) MKL
00029  *   Triangular matrix * matrix product functionality based on ?TRMM.
00030  ********************************************************************************
00031 */
00032 
00033 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
00034 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
00035 
00036 namespace Eigen { 
00037 
00038 namespace internal {
00039 
00040 
00041 template <typename Scalar, typename Index,
00042           int Mode, bool LhsIsTriangular,
00043           int LhsStorageOrder, bool ConjugateLhs,
00044           int RhsStorageOrder, bool ConjugateRhs,
00045           int ResStorageOrder>
00046 struct product_triangular_matrix_matrix_trmm :
00047        product_triangular_matrix_matrix<Scalar,Index,Mode,
00048           LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
00049           RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
00050 
00051 
00052 // try to go to BLAS specialization
00053 #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
00054 template <typename Index, int Mode, \
00055           int LhsStorageOrder, bool ConjugateLhs, \
00056           int RhsStorageOrder, bool ConjugateRhs> \
00057 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
00058            LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \
00059   static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
00060     const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
00061       product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
00062         LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
00063         RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
00064         _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
00065   } \
00066 };
00067 
00068 EIGEN_MKL_TRMM_SPECIALIZE(double, true)
00069 EIGEN_MKL_TRMM_SPECIALIZE(double, false)
00070 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
00071 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
00072 EIGEN_MKL_TRMM_SPECIALIZE(float, true)
00073 EIGEN_MKL_TRMM_SPECIALIZE(float, false)
00074 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
00075 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
00076 
00077 // implements col-major += alpha * op(triangular) * op(general)
00078 #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
00079 template <typename Index, int Mode, \
00080           int LhsStorageOrder, bool ConjugateLhs, \
00081           int RhsStorageOrder, bool ConjugateRhs> \
00082 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
00083          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
00084 { \
00085   enum { \
00086     IsLower = (Mode&Lower) == Lower, \
00087     SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
00088     IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
00089     IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
00090     LowUp = IsLower ? Lower : Upper, \
00091     conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
00092   }; \
00093 \
00094   static void run( \
00095     Index _rows, Index _cols, Index _depth, \
00096     const EIGTYPE* _lhs, Index lhsStride, \
00097     const EIGTYPE* _rhs, Index rhsStride, \
00098     EIGTYPE* res,        Index resStride, \
00099     EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
00100   { \
00101    Index diagSize  = (std::min)(_rows,_depth); \
00102    Index rows      = IsLower ? _rows : diagSize; \
00103    Index depth     = IsLower ? diagSize : _depth; \
00104    Index cols      = _cols; \
00105 \
00106    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
00107    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
00108 \
00109 /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
00110    if (rows != depth) { \
00111 \
00112      int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
00113 \
00114      if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
00115      /* Most likely no benefit to call TRMM or GEMM from MKL*/ \
00116        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
00117        LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
00118            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
00119      /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
00120      } else { \
00121      /* Make sense to call GEMM */ \
00122        Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
00123        MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
00124        MKL_INT aStride = aa_tmp.outerStride(); \
00125        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
00126        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
00127        rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
00128 \
00129      /*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
00130      } \
00131      return; \
00132    } \
00133    char side = 'L', transa, uplo, diag = 'N'; \
00134    EIGTYPE *b; \
00135    const EIGTYPE *a; \
00136    MKL_INT m, n, lda, ldb; \
00137    MKLTYPE alpha_; \
00138 \
00139 /* Set alpha_*/ \
00140    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
00141 \
00142 /* Set m, n */ \
00143    m = (MKL_INT)diagSize; \
00144    n = (MKL_INT)cols; \
00145 \
00146 /* Set trans */ \
00147    transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
00148 \
00149 /* Set b, ldb */ \
00150    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
00151    MatrixX##EIGPREFIX b_tmp; \
00152 \
00153    if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
00154    b = b_tmp.data(); \
00155    ldb = b_tmp.outerStride(); \
00156 \
00157 /* Set uplo */ \
00158    uplo = IsLower ? 'L' : 'U'; \
00159    if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
00160 /* Set a, lda */ \
00161    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
00162    MatrixLhs a_tmp; \
00163 \
00164    if ((conjA!=0) || (SetDiag==0)) { \
00165      if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
00166      if (IsZeroDiag) \
00167        a_tmp.diagonal().setZero(); \
00168      else if (IsUnitDiag) \
00169        a_tmp.diagonal().setOnes();\
00170      a = a_tmp.data(); \
00171      lda = a_tmp.outerStride(); \
00172    } else { \
00173      a = _lhs; \
00174      lda = lhsStride; \
00175    } \
00176    /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
00177 /* call ?trmm*/ \
00178    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
00179 \
00180 /* Add op(a_triangular)*b into res*/ \
00181    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
00182    res_tmp=res_tmp+b_tmp; \
00183   } \
00184 };
00185 
00186 EIGEN_MKL_TRMM_L(double, double, d, d)
00187 EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
00188 EIGEN_MKL_TRMM_L(float, float, f, s)
00189 EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
00190 
00191 // implements col-major += alpha * op(general) * op(triangular)
00192 #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
00193 template <typename Index, int Mode, \
00194           int LhsStorageOrder, bool ConjugateLhs, \
00195           int RhsStorageOrder, bool ConjugateRhs> \
00196 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
00197          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
00198 { \
00199   enum { \
00200     IsLower = (Mode&Lower) == Lower, \
00201     SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
00202     IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
00203     IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
00204     LowUp = IsLower ? Lower : Upper, \
00205     conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
00206   }; \
00207 \
00208   static void run( \
00209     Index _rows, Index _cols, Index _depth, \
00210     const EIGTYPE* _lhs, Index lhsStride, \
00211     const EIGTYPE* _rhs, Index rhsStride, \
00212     EIGTYPE* res,        Index resStride, \
00213     EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
00214   { \
00215    Index diagSize  = (std::min)(_cols,_depth); \
00216    Index rows      = _rows; \
00217    Index depth     = IsLower ? _depth : diagSize; \
00218    Index cols      = IsLower ? diagSize : _cols; \
00219 \
00220    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
00221    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
00222 \
00223 /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
00224    if (cols != depth) { \
00225 \
00226      int nthr = mkl_domain_get_max_threads(MKL_BLAS); \
00227 \
00228      if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
00229      /* Most likely no benefit to call TRMM or GEMM from MKL*/ \
00230        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
00231        LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
00232            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
00233        /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
00234      } else { \
00235      /* Make sense to call GEMM */ \
00236        Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
00237        MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
00238        MKL_INT aStride = aa_tmp.outerStride(); \
00239        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
00240        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
00241        rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
00242 \
00243      /*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
00244      } \
00245      return; \
00246    } \
00247    char side = 'R', transa, uplo, diag = 'N'; \
00248    EIGTYPE *b; \
00249    const EIGTYPE *a; \
00250    MKL_INT m, n, lda, ldb; \
00251    MKLTYPE alpha_; \
00252 \
00253 /* Set alpha_*/ \
00254    assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
00255 \
00256 /* Set m, n */ \
00257    m = (MKL_INT)rows; \
00258    n = (MKL_INT)diagSize; \
00259 \
00260 /* Set trans */ \
00261    transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
00262 \
00263 /* Set b, ldb */ \
00264    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
00265    MatrixX##EIGPREFIX b_tmp; \
00266 \
00267    if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
00268    b = b_tmp.data(); \
00269    ldb = b_tmp.outerStride(); \
00270 \
00271 /* Set uplo */ \
00272    uplo = IsLower ? 'L' : 'U'; \
00273    if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
00274 /* Set a, lda */ \
00275    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
00276    MatrixRhs a_tmp; \
00277 \
00278    if ((conjA!=0) || (SetDiag==0)) { \
00279      if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
00280      if (IsZeroDiag) \
00281        a_tmp.diagonal().setZero(); \
00282      else if (IsUnitDiag) \
00283        a_tmp.diagonal().setOnes();\
00284      a = a_tmp.data(); \
00285      lda = a_tmp.outerStride(); \
00286    } else { \
00287      a = _rhs; \
00288      lda = rhsStride; \
00289    } \
00290    /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
00291 /* call ?trmm*/ \
00292    MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
00293 \
00294 /* Add op(a_triangular)*b into res*/ \
00295    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
00296    res_tmp=res_tmp+b_tmp; \
00297   } \
00298 };
00299 
00300 EIGEN_MKL_TRMM_R(double, double, d, d)
00301 EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
00302 EIGEN_MKL_TRMM_R(float, float, f, s)
00303 EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
00304 
00305 } // end namespace internal
00306 
00307 } // end namespace Eigen
00308 
00309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Thu Aug 27 2015 12:01:17