level3_impl.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 #include <iostream>
10 #include "common.h"
11 
// BLAS GEMM entry point: validates the arguments, applies beta to c, then
// dispatches to a kernel selected by the two transposition flags.
// All arguments are passed by pointer. On the first invalid argument, info is
// set to that argument's 1-based position and the result of xerbla_ is returned.
// Returns 0 on every other path.
// NOTE(review): presumably the kernel accumulates alpha*op(a)*op(b) into c; the
// kernels themselves are not visible in this listing - confirm against the real header.
12 int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
13  const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
14 {
15 // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
16  typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
 // Kernel table indexed by OP(opa) | (OP(opb) << 2).
 // NOTE(review): only the `0` placeholder slots are visible below; the populated
 // entries (embedded listing lines 19, 21, 23, ...) are absent from this listing.
17  static const functype func[12] = {
18  // array index: NOTR | (NOTR << 2)
20  // array index: TR | (NOTR << 2)
22  // array index: ADJ | (NOTR << 2)
24  0,
25  // array index: NOTR | (TR << 2)
27  // array index: TR | (TR << 2)
29  // array index: ADJ | (TR << 2)
31  0,
32  // array index: NOTR | (ADJ << 2)
34  // array index: TR | (ADJ << 2)
36  // array index: ADJ | (ADJ << 2)
38  0
39  };
40 
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
41  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
42  const Scalar* b = reinterpret_cast<const Scalar*>(pb);
43  Scalar* c = reinterpret_cast<Scalar*>(pc);
44  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
45  Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
46 
 // Argument validation: info values are the 1-based positions in the parameter list.
47  int info = 0;
48  if(OP(*opa)==INVALID) info = 1;
49  else if(OP(*opb)==INVALID) info = 2;
50  else if(*m<0) info = 3;
51  else if(*n<0) info = 4;
52  else if(*k<0) info = 5;
53  else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
54  else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
55  else if(*ldc<std::max(1,*m)) info = 13;
56  if(info)
57  return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
58 
 // Empty result matrix: nothing to do.
59  if (*m == 0 || *n == 0)
60  return 0;
61 
 // Apply beta to c first: beta==1 leaves c untouched, beta==0 overwrites c with
 // zeros instead of multiplying.
62  if(beta!=Scalar(1))
63  {
64  if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
65  else matrix(c, *m, *n, *ldc) *= beta;
66  }
67 
 // With k==0 there is no product term; c has already been scaled above.
68  if(*k == 0)
69  return 0;
70 
71  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);
72 
 // Select the kernel from the two operation flags and run it; the last argument
 // (the GemmParallelInfo pointer) is passed as 0.
73  int code = OP(*opa) | (OP(*opb) << 2);
74  func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
75  return 0;
76 }
77 
// BLAS TRSM entry point: validates the arguments, runs a kernel selected by
// (op, side, uplo, diag) on b in place, then multiplies b by alpha.
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned. Returns 0 on every other path.
// NOTE(review): presumably the kernel performs the triangular solve; the kernels
// are not visible in this listing - confirm against the real header.
78 int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
79  const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
80 {
81 // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
82  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
 // Kernel table indexed by OP | (SIDE << 2) | (UPLO << 3) | (DIAG << 4).
 // NOTE(review): only the `0` placeholder slots are visible below; the populated
 // entries are absent from this listing.
83  static const functype func[32] = {
84  // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
86  // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
88  // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
90  0,
91  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
93  // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
95  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
97  0,
98  // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
100  // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
102  // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
104  0,
105  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
107  // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
109  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
111  0,
112  // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
114  // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
116  // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
118  0,
119  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
121  // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
123  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
125  0,
126  // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
128  // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
130  // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
132  0,
133  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
135  // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
137  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
139  0
140  };
141 
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
142  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
143  Scalar* b = reinterpret_cast<Scalar*>(pb);
144  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
145 
 // Argument validation: info values are the 1-based positions in the parameter list.
 // The required lda depends on the side: m for LEFT, n otherwise.
146  int info = 0;
147  if(SIDE(*side)==INVALID) info = 1;
148  else if(UPLO(*uplo)==INVALID) info = 2;
149  else if(OP(*opa)==INVALID) info = 3;
150  else if(DIAG(*diag)==INVALID) info = 4;
151  else if(*m<0) info = 5;
152  else if(*n<0) info = 6;
153  else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
154  else if(*ldb<std::max(1,*m)) info = 11;
155  if(info)
156  return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
157 
 // Empty b: nothing to do.
158  if(*m==0 || *n==0)
159  return 0;
160 
161  int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
162 
 // The kernel receives (m, n) for the LEFT side and (n, m) otherwise; the third
 // blocking dimension is m for LEFT and n otherwise.
163  if(SIDE(*side)==LEFT)
164  {
165  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
166  func[code](*m, *n, a, *lda, b, *ldb, blocking);
167  }
168  else
169  {
170  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
171  func[code](*n, *m, a, *lda, b, *ldb, blocking);
172  }
173 
 // alpha is applied after the kernel, as a separate scaling of b (skipped when alpha==1).
174  if(alpha!=Scalar(1))
175  matrix(b,*m,*n,*ldb) *= alpha;
176 
177  return 0;
178 }
179 
180 
// BLAS TRMM entry point: validates the arguments, zeroes b, then calls a kernel
// selected by (op, side, uplo, diag) with `tmp` as the non-triangular operand
// and b as the destination.
181 // b = alpha*op(a)*b for side = 'L'or'l'
182 // b = alpha*b*op(a) for side = 'R'or'r'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned.
// NOTE(review): the non-error paths return 1, whereas gemm and trsm in this file
// return 0 - confirm whether any caller inspects the value.
183 int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
184  const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
185 {
186 // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
187  typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
 // Kernel table indexed by OP | (SIDE << 2) | (UPLO << 3) | (DIAG << 4).
 // NOTE(review): only the `0` placeholder slots are visible below; the populated
 // entries are absent from this listing.
188  static const functype func[32] = {
189  // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
191  // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
193  // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
195  0,
196  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
198  // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
200  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
202  0,
203  // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
205  // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
207  // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
209  0,
210  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
212  // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
214  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
216  0,
217  // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
219  // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
221  // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
223  0,
224  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
226  // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
228  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
230  0,
231  // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
233  // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
235  // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
237  0,
238  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
240  // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
242  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
244  0
245  };
246 
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
247  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
248  Scalar* b = reinterpret_cast<Scalar*>(pb);
249  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
250 
 // Argument validation: info values are the 1-based positions in the parameter list.
251  int info = 0;
252  if(SIDE(*side)==INVALID) info = 1;
253  else if(UPLO(*uplo)==INVALID) info = 2;
254  else if(OP(*opa)==INVALID) info = 3;
255  else if(DIAG(*diag)==INVALID) info = 4;
256  else if(*m<0) info = 5;
257  else if(*n<0) info = 6;
258  else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
259  else if(*ldb<std::max(1,*m)) info = 11;
260  if(info)
261  return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
262 
263  int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
264 
 // Empty b: nothing to do (returns 1, see the note in the header comment).
265  if(*m==0 || *n==0)
266  return 1;
267 
268  // FIXME find a way to avoid this copy
 // NOTE(review): `tmp` is used below but its declaration (embedded listing line 269)
 // is absent from this listing; presumably it holds a copy of b taken before b is
 // zeroed - confirm against the real header.
270  matrix(b,*m,*n,*ldb).setZero();
271 
 // LEFT passes the triangular matrix a as the first operand and tmp as the second;
 // the other side passes tmp first and a second. b is the destination in both cases.
272  if(SIDE(*side)==LEFT)
273  {
274  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
275  func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
276  }
277  else
278  {
279  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
280  func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
281  }
282  return 1;
283 }
284 
// BLAS SYMM entry point: validates the arguments, applies beta to c, then adds
// the product term with the symmetric matrix a on the side selected by `side`.
285 // c = alpha*a*b + beta*c for side = 'L'or'l'
286 // c = alpha*b*a + beta*c for side = 'R'or'r'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned.
// NOTE(review): the empty-matrix early exit returns 1 while the other non-error
// paths return 0 - confirm whether any caller inspects the value.
287 int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
288  const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
289 {
290 // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
291  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
292  const Scalar* b = reinterpret_cast<const Scalar*>(pb);
293  Scalar* c = reinterpret_cast<Scalar*>(pc);
294  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
295  Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
296 
 // Argument validation: info values are the 1-based positions in the parameter list.
297  int info = 0;
298  if(SIDE(*side)==INVALID) info = 1;
299  else if(UPLO(*uplo)==INVALID) info = 2;
300  else if(*m<0) info = 3;
301  else if(*n<0) info = 4;
302  else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
303  else if(*ldb<std::max(1,*m)) info = 9;
304  else if(*ldc<std::max(1,*m)) info = 12;
305  if(info)
306  return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
307 
 // Apply beta to c: beta==1 leaves c untouched, beta==0 overwrites c with zeros.
 // This runs before the empty-matrix check below.
308  if(beta!=Scalar(1))
309  {
310  if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
311  else matrix(c, *m, *n, *ldc) *= beta;
312  }
313 
314  if(*m==0 || *n==0)
315  {
316  return 1;
317  }
318 
 // Order of the square matrix a: m when a is on the left, n otherwise.
319  int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
320  #if ISCOMPLEX
321  // FIXME add support for symmetric complex matrix
 // Complex path: fill both triangles of matA from the stored triangle of a (the
 // other triangle comes from the transpose), then use a plain matrix product.
 // NOTE(review): the declaration of `matA` (embedded listing line 322) is absent
 // from this listing.
323  if(UPLO(*uplo)==UP)
324  {
325  matA.triangularView<Upper>() = matrix(a,size,size,*lda);
326  matA.triangularView<Lower>() = matrix(a,size,size,*lda).transpose();
327  }
328  else if(UPLO(*uplo)==LO)
329  {
330  matA.triangularView<Lower>() = matrix(a,size,size,*lda);
331  matA.triangularView<Upper>() = matrix(a,size,size,*lda).transpose();
332  }
333  if(SIDE(*side)==LEFT)
334  matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb);
335  else if(SIDE(*side)==RIGHT)
336  matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA;
337  #else
338  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false);
339 
 // NOTE(review): the statements that belong between the branches below (embedded
 // listing lines 341-342 and 345-346) are absent from this listing, so the if/else
 // chain is incomplete as shown; consult the real header for the kernel calls.
340  if(SIDE(*side)==LEFT)
343  else return 0;
344  else if(SIDE(*side)==RIGHT)
347  else return 0;
348  else
349  return 0;
350  #endif
351 
352  return 0;
353 }
354 
// BLAS SYRK entry point: validates the arguments, applies beta to the triangle
// of c selected by `uplo`, then adds the rank-k term to that same triangle.
355 // c = alpha*a*a' + beta*c for op = 'N'or'n'
356 // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned. Returns 0 on every other path.
357 int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k,
358  const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
359 {
360 // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
 // The kernel table is only declared for non-complex scalars; it is indexed by
 // OP | (UPLO << 2).
 // NOTE(review): only the `0` placeholder slots are visible below; the populated
 // entries are absent from this listing.
361  #if !ISCOMPLEX
362  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
363  static const functype func[8] = {
364  // array index: NOTR | (UP << 2)
366  // array index: TR | (UP << 2)
368  // array index: ADJ | (UP << 2)
370  0,
371  // array index: NOTR | (LO << 2)
373  // array index: TR | (LO << 2)
375  // array index: ADJ | (LO << 2)
377  0
378  };
379  #endif
380 
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
381  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
382  Scalar* c = reinterpret_cast<Scalar*>(pc);
383  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
384  Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
385 
 // Argument validation: info values are the 1-based positions in the parameter list.
 // When ISCOMPLEX is set, op == ADJ is rejected as invalid.
386  int info = 0;
387  if(UPLO(*uplo)==INVALID) info = 1;
388  else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2;
389  else if(*n<0) info = 3;
390  else if(*k<0) info = 4;
391  else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
392  else if(*ldc<std::max(1,*n)) info = 10;
393  if(info)
394  return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
395 
 // Apply beta to the selected triangle only: beta==1 leaves it untouched,
 // beta==0 overwrites it with zeros.
396  if(beta!=Scalar(1))
397  {
398  if(UPLO(*uplo)==UP)
399  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
400  else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
401  else
402  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
403  else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
404  }
405 
 // No product term to add when either dimension is zero.
406  if(*n==0 || *k==0)
407  return 0;
408 
409  #if ISCOMPLEX
410  // FIXME add support for symmetric complex matrix
 // Complex path: expression-based update using transpose() (not adjoint), written
 // into the selected triangle. a is viewed as n-by-k for NOTR and k-by-n otherwise.
411  if(UPLO(*uplo)==UP)
412  {
413  if(OP(*op)==NOTR)
414  matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
415  else
416  matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
417  }
418  else
419  {
420  if(OP(*op)==NOTR)
421  matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
422  else
423  matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
424  }
425  #else
 // Real path: dispatch through the kernel table, passing a as both operands.
426  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
427 
428  int code = OP(*op) | (UPLO(*uplo) << 2);
429  func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
430  #endif
431 
432  return 0;
433 }
434 
// BLAS SYR2K entry point: validates the arguments, applies beta to the triangle
// of c selected by `uplo`, then adds the symmetric rank-2k term to that triangle
// using Eigen expressions (no kernel table is used here).
435 // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
436 // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned.
// NOTE(review): the k==0 early exit returns 1 while the final return is 0 -
// confirm whether any caller inspects the value.
437 int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha,
438  const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
439 {
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
440  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
441  const Scalar* b = reinterpret_cast<const Scalar*>(pb);
442  Scalar* c = reinterpret_cast<Scalar*>(pc);
443  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
444  Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
445 
446 // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
447 
 // Argument validation: info values are the 1-based positions in the parameter list.
 // When ISCOMPLEX is set, op == ADJ is rejected as invalid.
448  int info = 0;
449  if(UPLO(*uplo)==INVALID) info = 1;
450  else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2;
451  else if(*n<0) info = 3;
452  else if(*k<0) info = 4;
453  else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
454  else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
455  else if(*ldc<std::max(1,*n)) info = 12;
456  if(info)
457  return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
458 
 // Apply beta to the selected triangle only: beta==1 leaves it untouched,
 // beta==0 overwrites it with zeros.
459  if(beta!=Scalar(1))
460  {
461  if(UPLO(*uplo)==UP)
462  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
463  else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
464  else
465  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
466  else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
467  }
468 
 // With k==0 there is no product term; c has already been scaled above.
469  if(*k==0)
470  return 1;
471 
 // NOTR views a and b as n-by-k; TR/ADJ view them as k-by-n. Both cases use
 // transpose() and write only into the triangle selected by uplo.
472  if(OP(*op)==NOTR)
473  {
474  if(UPLO(*uplo)==UP)
475  {
476  matrix(c, *n, *n, *ldc).triangularView<Upper>()
477  += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
478  + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
479  }
480  else if(UPLO(*uplo)==LO)
481  matrix(c, *n, *n, *ldc).triangularView<Lower>()
482  += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
483  + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
484  }
485  else if(OP(*op)==TR || OP(*op)==ADJ)
486  {
487  if(UPLO(*uplo)==UP)
488  matrix(c, *n, *n, *ldc).triangularView<Upper>()
489  += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
490  + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
491  else if(UPLO(*uplo)==LO)
492  matrix(c, *n, *n, *ldc).triangularView<Lower>()
493  += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
494  + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
495  }
496 
497  return 0;
498 }
499 
500 
501 #if ISCOMPLEX
502 
// BLAS HEMM entry point (compiled only when ISCOMPLEX is set): validates the
// arguments, applies beta to c, then adds the product term with the self-adjoint
// matrix a on the side selected by `side`.
503 // c = alpha*a*b + beta*c for side = 'L'or'l'
504 // c = alpha*b*a + beta*c for side = 'R'or'r'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned.
// NOTE(review): the empty-matrix early exit returns 1 while the other non-error
// paths return 0 - confirm whether any caller inspects the value.
505 int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
506  const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
507 {
 // Reinterpret the RealScalar-typed interface pointers as Scalar.
508  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
509  const Scalar* b = reinterpret_cast<const Scalar*>(pb);
510  Scalar* c = reinterpret_cast<Scalar*>(pc);
511  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
512  Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
513 
514 // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
515 
 // Argument validation: info values are the 1-based positions in the parameter list.
516  int info = 0;
517  if(SIDE(*side)==INVALID) info = 1;
518  else if(UPLO(*uplo)==INVALID) info = 2;
519  else if(*m<0) info = 3;
520  else if(*n<0) info = 4;
521  else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
522  else if(*ldb<std::max(1,*m)) info = 9;
523  else if(*ldc<std::max(1,*m)) info = 12;
524  if(info)
525  return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
526 
 // Apply beta to c: beta==0 overwrites c with zeros, beta==1 leaves it untouched.
 // This runs before the empty-matrix check below.
527  if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
528  else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta;
529 
530  if(*m==0 || *n==0)
531  {
532  return 1;
533  }
534 
 // Order of the square matrix a: m when a is on the left, n otherwise.
535  int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
536  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false);
537 
538  if(SIDE(*side)==LEFT)
539  {
 // NOTE(review): the lines that introduce the two ::run calls below (embedded
 // listing lines 540 and 542, presumably the uplo tests and kernel template names)
 // are absent from this listing; consult the real header.
541  ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
543  ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
544  else return 0;
545  }
546  else if(SIDE(*side)==RIGHT)
547  {
 // RIGHT/UP uses an Eigen expression with selfadjointView<Upper>; the kernel call
 // for that case is commented out. RIGHT/LO calls the kernel directly.
548  if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
549  ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);*/
550  else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor>
551  ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
552  else return 0;
553  }
554  else
555  {
556  return 0;
557  }
558 
559  return 0;
560 }
561 
// BLAS HERK entry point (compiled only when ISCOMPLEX is set): validates the
// arguments, applies the real factor beta to the triangle of c selected by `uplo`,
// then runs a kernel selected by (op, uplo) and clears the imaginary part of the
// diagonal of c.
562 // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
563 // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned. Returns 0 on every other path.
564 int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k,
565  const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
566 {
567 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
568 
569  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
 // Kernel table indexed by OP | (UPLO << 2).
 // NOTE(review): only the `0` placeholder slots are visible below; the populated
 // entries are absent from this listing.
570  static const functype func[8] = {
571  // array index: NOTR | (UP << 2)
573  0,
574  // array index: ADJ | (UP << 2)
576  0,
577  // array index: NOTR | (LO << 2)
579  0,
580  // array index: ADJ | (LO << 2)
582  0
583  };
584 
585  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
586  Scalar* c = reinterpret_cast<Scalar*>(pc);
 // NOTE(review): `alpha` is used further down but its declaration (embedded listing
 // line 587) is absent from this listing. beta is read directly as a RealScalar.
588  RealScalar beta = *pbeta;
589 
590 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
591 
 // Argument validation: info values are the 1-based positions in the parameter list.
 // op == TR is rejected as invalid here.
592  int info = 0;
593  if(UPLO(*uplo)==INVALID) info = 1;
594  else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
595  else if(*n<0) info = 3;
596  else if(*k<0) info = 4;
597  else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
598  else if(*ldc<std::max(1,*n)) info = 10;
599  if(info)
600  return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
601 
602  int code = OP(*op) | (UPLO(*uplo) << 2);
603 
 // Apply beta. beta==0 zeroes the whole selected triangle (diagonal included).
 // Otherwise the strict triangle is scaled, and the diagonal is handled separately:
 // its real part is scaled and its imaginary part is set to zero.
604  if(beta!=RealScalar(1))
605  {
606  if(UPLO(*uplo)==UP)
607  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
608  else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
609  else
610  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
611  else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
612 
613  if(beta!=Scalar(0))
614  {
615  matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
616  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
617  }
618  }
619 
 // The kernel runs only when there is a non-trivial product term. It receives a as
 // both operands; afterwards the imaginary part of the diagonal is zeroed again.
620  if(*k>0 && alpha!=RealScalar(0))
621  {
622  internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
623  func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
624  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
625  }
626  return 0;
627 }
628 
// BLAS HER2K entry point (compiled only when ISCOMPLEX is set): validates the
// arguments, applies the real factor beta to the triangle of c selected by `uplo`,
// then adds the Hermitian rank-2k term to that triangle using Eigen expressions.
629 // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
630 // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
// On the first invalid argument, info is set to that argument's 1-based position
// and the result of xerbla_ is returned.
// NOTE(review): the non-error paths return 1, whereas herk in this file returns 0 -
// confirm whether any caller inspects the value.
631 int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k,
632  const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
633 {
 // alpha is complex (Scalar); beta is read directly as a RealScalar.
634  const Scalar* a = reinterpret_cast<const Scalar*>(pa);
635  const Scalar* b = reinterpret_cast<const Scalar*>(pb);
636  Scalar* c = reinterpret_cast<Scalar*>(pc);
637  Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
638  RealScalar beta = *pbeta;
639 
640 // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
641 
 // Argument validation: info values are the 1-based positions in the parameter list.
 // op == TR is rejected as invalid here.
642  int info = 0;
643  if(UPLO(*uplo)==INVALID) info = 1;
644  else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
645  else if(*n<0) info = 3;
646  else if(*k<0) info = 4;
647  else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
648  else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
649  else if(*ldc<std::max(1,*n)) info = 12;
650  if(info)
651  return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
652 
 // Apply beta. beta==0 zeroes the whole selected triangle (diagonal included).
 // Otherwise the strict triangle is scaled, and the diagonal is handled separately:
 // its real part is scaled and its imaginary part is set to zero.
653  if(beta!=RealScalar(1))
654  {
655  if(UPLO(*uplo)==UP)
656  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
657  else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
658  else
659  if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
660  else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
661 
662  if(beta!=Scalar(0))
663  {
664  matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
665  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
666  }
667  }
 // beta==1 and a non-trivial product term: c is not scaled, but the imaginary part
 // of its diagonal is still cleared before the update.
668  else if(*k>0 && alpha!=Scalar(0))
669  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
670 
 // With k==0 there is no product term to add.
671  if(*k==0)
672  return 1;
673 
 // NOTR views a and b as n-by-k; ADJ views them as k-by-n. Both cases use adjoint()
 // and conj(alpha) for the second term, and write only into the triangle selected
 // by uplo.
674  if(OP(*op)==NOTR)
675  {
676  if(UPLO(*uplo)==UP)
677  {
678  matrix(c, *n, *n, *ldc).triangularView<Upper>()
679  += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
680  + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
681  }
682  else if(UPLO(*uplo)==LO)
683  matrix(c, *n, *n, *ldc).triangularView<Lower>()
684  += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
685  + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
686  }
687  else if(OP(*op)==ADJ)
688  {
689  if(UPLO(*uplo)==UP)
690  matrix(c, *n, *n, *ldc).triangularView<Upper>()
691  += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
692  + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
693  else if(UPLO(*uplo)==LO)
694  matrix(c, *n, *n, *ldc).triangularView<Lower>()
695  += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
696  + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
697  }
698 
699  return 1;
700 }
701 
702 #endif // ISCOMPLEX
#define SCALAR_SUFFIX_UP
Matrix3f m
SCALAR Scalar
Definition: bench_gemm.cpp:33
#define max(a, b)
Definition: datatypes.h:20
Scalar * b
Definition: benchVecAdd.cpp:17
Matrix diag(const std::vector< Matrix > &Hs)
Definition: Matrix.cpp:206
int RealScalar int RealScalar int RealScalar * pc
int EIGEN_BLAS_FUNC() trmm(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
Definition: level3_impl.h:183
int n
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar * data() const
#define ISCOMPLEX
int EIGEN_BLAS_FUNC() gemm(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
Definition: level3_impl.h:12
#define EIGEN_BLAS_FUNC(X)
Array33i a
else if n * info
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
EIGEN_DEVICE_FUNC Index outerStride() const
EIGEN_WEAK_LINKING int xerbla_(const char *msg, int *info, int)
Definition: xerbla.cpp:15
int RealScalar * palpha
RealScalar alpha
* lda
Definition: eigenvalues.cpp:59
MatrixXf matA(2, 2)
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:34
int EIGEN_BLAS_FUNC() trsm(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
Definition: level3_impl.h:78
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
Definition: Meta.h:25
Map< Matrix< T, Dynamic, Dynamic, ColMajor >, 0, OuterStride<> > matrix(T *data, int rows, int cols, int stride)
The matrix class, also used for vectors and row-vectors.
void run(Expr &expr, Dev &dev)
Definition: TensorSyclRun.h:33
int EIGEN_BLAS_FUNC() symm(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
Definition: level3_impl.h:287
int EIGEN_BLAS_FUNC() syrk(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
Definition: level3_impl.h:357
ScalarWithExceptions conj(const ScalarWithExceptions &x)
Definition: exceptions.cpp:74
int EIGEN_BLAS_FUNC() syr2k(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
Definition: level3_impl.h:437
v setZero(3)


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:42:30