TriangularMatrixMatrix_BLAS.h
Go to the documentation of this file.
1 /*
2  Copyright (c) 2011, Intel Corporation. All rights reserved.
3 
4  Redistribution and use in source and binary forms, with or without modification,
5  are permitted provided that the following conditions are met:
6 
7  * Redistributions of source code must retain the above copyright notice, this
8  list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright notice,
10  this list of conditions and the following disclaimer in the documentation
11  and/or other materials provided with the distribution.
12  * Neither the name of Intel Corporation nor the names of its contributors may
13  be used to endorse or promote products derived from this software without
14  specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27  ********************************************************************************
28  * Content : Eigen bindings to BLAS F77
29  * Triangular matrix * matrix product functionality based on ?TRMM.
30  ********************************************************************************
31 */
32 
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35 
36 namespace Eigen {
37 
38 namespace internal {
39 
40 
// Primary template of the ?TRMM dispatch helper.
// It simply inherits from the BuiltIn (non-BLAS) triangular product with an
// inner result stride of 1, so any parameter combination that has no dedicated
// specialization further down in this file falls back to Eigen's own kernel.
// (The `struct` declarator line, listing line 46, was missing from this
// listing and has been restored.)
41 template <typename Scalar, typename Index,
42  int Mode, bool LhsIsTriangular,
43  int LhsStorageOrder, bool ConjugateLhs,
44  int RhsStorageOrder, bool ConjugateRhs,
45  int ResStorageOrder>
46 struct product_triangular_matrix_matrix_trmm :
47  product_triangular_matrix_matrix<Scalar,Index,Mode,
48  LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49  RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
50 
51 
52 // try to go to BLAS specialization
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
//   Emits a partial specialization of product_triangular_matrix_matrix for the
//   given scalar type and triangular side, restricted to a column-major result
//   with inner stride 1 and the `Specialized` version tag.
//   Its run() checks (in debug builds) that resIncr is 1 and then forwards all
//   remaining arguments, without resIncr, to
//   product_triangular_matrix_matrix_trmm<...>::run.
53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
54 template <typename Index, int Mode, \
55  int LhsStorageOrder, bool ConjugateLhs, \
56  int RhsStorageOrder, bool ConjugateRhs> \
57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
58  LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
59  static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
60  const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
 /* resIncr is only inspected by the assertion below; the trmm helper takes no inner stride. */ \
61  EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
62  eigen_assert(resIncr == 1); \
63  product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
64  LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
65  RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
66  _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
67  } \
68 };
69 
// Register the BLAS-backed specialization for every scalar type handled by
// ?TRMM, once with the triangular factor on the left (true) and once with it
// on the right (false). The (dcomplex, true) and (scomplex, true) entries
// (listing lines 72 and 76) were missing from this listing and are restored so
// that the left-triangular complex kernels instantiated below are reachable.
70 EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
71 EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
72 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
73 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
74 EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
75 EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
76 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
77 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
78 
79 // implements col-major += alpha * op(triangular) * op(general)
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
//   Specializes product_triangular_matrix_matrix_trmm for a triangular
//   left-hand side (LhsIsTriangular == true) and a column-major result.
//     EIGTYPE   - Eigen scalar type of all operands.
//     BLASTYPE  - type the data/alpha pointers are cast to for the BLAS call.
//     EIGPREFIX - suffix pasted onto `MatrixX` to name the column-major
//                 temporary / result map type (d, cd, f, cf).
//     BLASFUNC  - the ?trmm routine to invoke.
//   run() proceeds as follows:
//     1. Clip the triangular factor to its effective extent (diagSize).
//     2. If that factor is not square, ?TRMM cannot be used: either defer to
//        the BuiltIn product (small excess, nthr fixed to 1) or densify the
//        triangular view into a temporary and run the general matrix product.
//     3. Otherwise copy the general operand into b_tmp (conjugated when
//        requested), optionally copy the triangular operand (to apply a
//        conjugation or to overwrite the diagonal), call BLASFUNC, which
//        receives b_tmp's storage through a non-const pointer, and finally
//        add b_tmp to the result.
//   Note: `diag` is always passed as 'N'; Unit/Zero diagonals are realised by
//   writing ones/zeros into the copied matrix. The LowUp enumerator is not
//   referenced inside this macro.
80 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
81 template <typename Index, int Mode, \
82  int LhsStorageOrder, bool ConjugateLhs, \
83  int RhsStorageOrder, bool ConjugateRhs> \
84 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
85  LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
86 { \
 /* Compile-time flags derived from Mode; conjA is set only for a column-major */ \
 /* conjugated lhs (the row-major case is expressed through transa = 'C').     */ \
87  enum { \
88  IsLower = (Mode&Lower) == Lower, \
89  SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
90  IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
91  IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
92  LowUp = IsLower ? Lower : Upper, \
93  conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
94  }; \
95 \
96  static void run( \
97  Index _rows, Index _cols, Index _depth, \
98  const EIGTYPE* _lhs, Index lhsStride, \
99  const EIGTYPE* _rhs, Index rhsStride, \
100  EIGTYPE* res, Index resStride, \
101  EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
102  { \
 /* Effective extent of the triangular lhs: lower keeps all rows, upper keeps all columns. */ \
103  Index diagSize = (std::min)(_rows,_depth); \
104  Index rows = IsLower ? _rows : diagSize; \
105  Index depth = IsLower ? diagSize : _depth; \
106  Index cols = _cols; \
107 \
108  typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
109  typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
110 \
111 /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
112  if (rows != depth) { \
113 \
114  /* FIXME handle mkl_domain_get_max_threads */ \
115  /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
116 \
 /* Heuristic: when the non-square excess is under half of diagSize, use the BuiltIn kernel. */ \
117  if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
118  /* Most likely no benefit to call TRMM or GEMM from BLAS */ \
119  product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
120  LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
121  _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
122  /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
123  } else { \
124  /* Make sense to call GEMM */ \
 /* Densify the triangular view into aa_tmp, then run the general product on it. */ \
125  Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
126  MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
127  BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
128  gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
129  general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
130  rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
131 \
132  /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
133  } \
134  return; \
135  } \
 /* Square triangular factor: set up the ?trmm arguments (triangular matrix on the left). */ \
136  char side = 'L', transa, uplo, diag = 'N'; \
137  EIGTYPE *b; \
138  const EIGTYPE *a; \
139  BlasIndex m, n, lda, ldb; \
140 \
141 /* Set m, n */ \
142  m = convert_index<BlasIndex>(diagSize); \
143  n = convert_index<BlasIndex>(cols); \
144 \
145 /* Set trans */ \
146  transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
147 \
148 /* Set b, ldb */ \
 /* b_tmp is a private column-major copy of the rhs; it is the buffer handed to ?trmm as B. */ \
149  Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
150  MatrixX##EIGPREFIX b_tmp; \
151 \
152  if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
153  b = b_tmp.data(); \
154  ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
155 \
156 /* Set uplo */ \
 /* A row-major lhs is passed transposed, so the stored triangle flips. */ \
157  uplo = IsLower ? 'L' : 'U'; \
158  if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
159 /* Set a, lda */ \
160  Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
161  MatrixLhs a_tmp; \
162 \
 /* Copy the lhs only when it must be conjugated or its diagonal overwritten; otherwise use _lhs directly. */ \
163  if ((conjA!=0) || (SetDiag==0)) { \
164  if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
165  if (IsZeroDiag) \
166  a_tmp.diagonal().setZero(); \
167  else if (IsUnitDiag) \
168  a_tmp.diagonal().setOnes();\
169  a = a_tmp.data(); \
170  lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
171  } else { \
172  a = _lhs; \
173  lda = convert_index<BlasIndex>(lhsStride); \
174  } \
175  /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
176 /* call ?trmm*/ \
177  BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
178 \
179 /* Add op(a_triangular)*b into res*/ \
180  Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
181  res_tmp=res_tmp+b_tmp; \
182  } \
183 };
184 
// Instantiate the left-triangular kernel for the four supported scalar types.
// With EIGEN_USE_MKL the un-suffixed routine names and the MKL complex types
// are used; otherwise the trailing-underscore routine names are used and the
// complex variants pass the underlying real type (double/float) as BLASTYPE,
// i.e. the complex buffers are handed over through pointers to that real type.
185 #ifdef EIGEN_USE_MKL
186 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
187 EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
188 EIGEN_BLAS_TRMM_L(float, float, f, strmm)
189 EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
190 #else
191 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
192 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
193 EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
194 EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
195 #endif
196 
197 // implements col-major += alpha * op(general) * op(triangular)
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
//   Specializes product_triangular_matrix_matrix_trmm for a triangular
//   right-hand side (LhsIsTriangular == false) and a column-major result.
//   Macro parameters have the same meaning as in EIGEN_BLAS_TRMM_L:
//     EIGTYPE   - Eigen scalar type of all operands.
//     BLASTYPE  - type the data/alpha pointers are cast to for the BLAS call.
//     EIGPREFIX - suffix pasted onto `MatrixX` to name the column-major
//                 temporary / result map type (d, cd, f, cf).
//     BLASFUNC  - the ?trmm routine to invoke.
//   run() proceeds as follows:
//     1. Clip the triangular factor (the rhs) to its effective extent.
//     2. If that factor is not square, either defer to the BuiltIn product
//        (small excess, nthr fixed to 1) or densify the triangular view of the
//        rhs into a temporary and run the general matrix product.
//     3. Otherwise copy the general operand (the lhs) into b_tmp, optionally
//        copy the triangular rhs (to conjugate it or overwrite its diagonal),
//        call BLASFUNC with side = 'R', passing b_tmp's storage through a
//        non-const pointer, and add b_tmp to the result.
//   Note: here conjA/transa/uplo are all derived from the *rhs* template
//   parameters, since the rhs plays the role of BLAS matrix A. `diag` is always
//   'N'; the LowUp enumerator is not referenced inside this macro.
198 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
199 template <typename Index, int Mode, \
200  int LhsStorageOrder, bool ConjugateLhs, \
201  int RhsStorageOrder, bool ConjugateRhs> \
202 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
203  LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
204 { \
205  enum { \
206  IsLower = (Mode&Lower) == Lower, \
207  SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
208  IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
209  IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
210  LowUp = IsLower ? Lower : Upper, \
211  conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
212  }; \
213 \
214  static void run( \
215  Index _rows, Index _cols, Index _depth, \
216  const EIGTYPE* _lhs, Index lhsStride, \
217  const EIGTYPE* _rhs, Index rhsStride, \
218  EIGTYPE* res, Index resStride, \
219  EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
220  { \
 /* Effective extent of the triangular rhs: lower keeps all rows (depth), upper keeps all columns. */ \
221  Index diagSize = (std::min)(_cols,_depth); \
222  Index rows = _rows; \
223  Index depth = IsLower ? _depth : diagSize; \
224  Index cols = IsLower ? diagSize : _cols; \
225 \
226  typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
227  typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
228 \
229 /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
230  if (cols != depth) { \
231 \
232  int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
233 \
 /* Heuristic: when the non-square excess is under half of diagSize, use the BuiltIn kernel. */ \
234  if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
235  /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
236  product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
237  LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
238  _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
239  /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
240  } else { \
241  /* Make sense to call GEMM */ \
 /* Densify the triangular view of the rhs into aa_tmp, then run the general product on it. */ \
242  Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
243  MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
244  BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
245  gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
246  general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
247  rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
248 \
249  /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
250  } \
251  return; \
252  } \
 /* Square triangular factor: set up the ?trmm arguments (triangular matrix on the right). */ \
253  char side = 'R', transa, uplo, diag = 'N'; \
254  EIGTYPE *b; \
255  const EIGTYPE *a; \
256  BlasIndex m, n, lda, ldb; \
257 \
258 /* Set m, n */ \
259  m = convert_index<BlasIndex>(rows); \
260  n = convert_index<BlasIndex>(diagSize); \
261 \
262 /* Set trans */ \
263  transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
264 \
265 /* Set b, ldb */ \
 /* b_tmp is a private column-major copy of the general lhs; it is the buffer handed to ?trmm as B. */ \
266  Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
267  MatrixX##EIGPREFIX b_tmp; \
268 \
269  if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
270  b = b_tmp.data(); \
271  ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
272 \
273 /* Set uplo */ \
 /* A row-major rhs is passed transposed, so the stored triangle flips. */ \
274  uplo = IsLower ? 'L' : 'U'; \
275  if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
276 /* Set a, lda */ \
277  Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
278  MatrixRhs a_tmp; \
279 \
 /* Copy the rhs only when it must be conjugated or its diagonal overwritten; otherwise use _rhs directly. */ \
280  if ((conjA!=0) || (SetDiag==0)) { \
281  if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
282  if (IsZeroDiag) \
283  a_tmp.diagonal().setZero(); \
284  else if (IsUnitDiag) \
285  a_tmp.diagonal().setOnes();\
286  a = a_tmp.data(); \
287  lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
288  } else { \
289  a = _rhs; \
290  lda = convert_index<BlasIndex>(rhsStride); \
291  } \
292  /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
293 /* call ?trmm*/ \
294  BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
295 \
296 /* Add op(a_triangular)*b into res*/ \
297  Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
298  res_tmp=res_tmp+b_tmp; \
299  } \
300 };
301 
// Instantiate the right-triangular kernel for the four supported scalar types.
// With EIGEN_USE_MKL the un-suffixed routine names and the MKL complex types
// are used; otherwise the trailing-underscore routine names are used and the
// complex variants pass the underlying real type (double/float) as BLASTYPE.
302 #ifdef EIGEN_USE_MKL
303 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
304 EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
305 EIGEN_BLAS_TRMM_R(float, float, f, strmm)
306 EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
307 #else
308 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
309 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
310 EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
311 EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
312 #endif
313 } // end namespace internal
314 
315 } // end namespace Eigen
316 
317 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
int BLASFUNC() strmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
SCALAR Scalar
Definition: bench_gemm.cpp:46
int BLASFUNC() ztrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
Namespace containing all symbols from the Eigen library.
Definition: jet.h:637
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
std::complex< float > scomplex
Definition: MKL_support.h:126
std::complex< double > dcomplex
Definition: MKL_support.h:125
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
int BLASFUNC() dtrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
int BLASFUNC() ctrmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:40:32