gemm.hpp
Go to the documentation of this file.
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12 
13 
16 
17 
18 
//! Size-dispatched emulation for tiny square matrices. The cross-reference
//! for this file describes the class defined at source line 21 as
//! "for tiny square matrices, size <= 4x4".
//! Template flags: do_trans_A (operate on trans(A)), use_alpha, use_beta.
//! There is no do_trans_B flag on this class.
//! NOTE(review): this rendering drops source lines 21, 40, 45, 48, 51 and 54.
//! They include the class-name line and every case body, so the per-size
//! kernels (and whether the cases return, break or fall through) are not
//! visible here. Presumably each case computes C = op(A)*B with the optional
//! alpha/beta scaling. TODO confirm.
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
22  {
23  public:
24 
25 
26  template<typename eT>
27  arma_hot
28  inline
29  static
30  void
31  apply
32  (
33  Mat<eT>& C,
34  const Mat<eT>& A,
35  const Mat<eT>& B,
36  const eT alpha = eT(1),
37  const eT beta = eT(0)
38  )
39  {
41 
// Dispatch on the matrix size. Only A.n_rows is inspected, so callers
// presumably guarantee that A and B are square and of equal size.
// TODO confirm at the call sites.
42  switch(A.n_rows)
43  {
44  case 4:
46 
47  case 3:
49 
50  case 2:
52 
53  case 1:
55 
// Sizes other than 1..4: the default branch itself performs no work.
56  default:
57  ;
58  }
59  }
60 
61  };
62 
63 
64 
//! Element-by-element emulation of C = op(A)*op(B).
//! op is an optional plain (non-conjugating) transpose, selected by
//! do_trans_A / do_trans_B. Each result element is a dot product "acc"
//! computed by op_dot::direct_dot_arma().
//! How acc is stored, chosen by compile-time flags:
//!   use_alpha=false, use_beta=false : C(i,j) = acc
//!   use_alpha=true,  use_beta=false : C(i,j) = alpha*acc
//!   use_alpha=false, use_beta=true  : C(i,j) = acc + beta*C(i,j)
//!   use_alpha=true,  use_beta=true  : C(i,j) = alpha*acc + beta*C(i,j)
//! When a flag is false, the corresponding argument is ignored.
//! NOTE(review): C is written only through C.at() and no resize is visible.
//! C therefore presumably must already have the result's dimensions, and must
//! hold the values to accumulate into when use_beta is set. TODO confirm.
//! NOTE(review): this rendering drops source lines 66, 84, 168, 170 and 182,
//! the class-name line among them, so parts of this class are not visible.
65 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
67  {
68  public:
69 
70  template<typename eT>
71  arma_hot
72  inline
73  static
74  void
75  apply
76  (
77  Mat<eT>& C,
78  const Mat<eT>& A,
79  const Mat<eT>& B,
80  const eT alpha = eT(1),
81  const eT beta = eT(0)
82  )
83  {
85 
86  const uword A_n_rows = A.n_rows;
87  const uword A_n_cols = A.n_cols;
88 
89  const uword B_n_rows = B.n_rows;
90  const uword B_n_cols = B.n_cols;
91 
// Case A*B.
92  if( (do_trans_A == false) && (do_trans_B == false) )
93  {
// Columns of A are contiguous (colptr) but its rows are not. Copy each row of
// A once into the buffer tmp, so every dot product with a column of B reads
// two contiguous arrays.
94  arma_aligned podarray<eT> tmp(A_n_cols);
95  eT* A_rowdata = tmp.memptr();
96 
97  for(uword row_A=0; row_A < A_n_rows; ++row_A)
98  {
99  tmp.copy_row(A, row_A);
100 
101  for(uword col_B=0; col_B < B_n_cols; ++col_B)
102  {
103  const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
104 
105  if( (use_alpha == false) && (use_beta == false) )
106  {
107  C.at(row_A,col_B) = acc;
108  }
109  else
110  if( (use_alpha == true) && (use_beta == false) )
111  {
112  C.at(row_A,col_B) = alpha * acc;
113  }
114  else
115  if( (use_alpha == false) && (use_beta == true) )
116  {
117  C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
118  }
119  else
120  if( (use_alpha == true) && (use_beta == true) )
121  {
122  C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
123  }
124 
125  }
126  }
127  }
128  else
// Case trans(A)*B. Row i of trans(A) is column i of A, which is already
// contiguous, so no copy is needed.
129  if( (do_trans_A == true) && (do_trans_B == false) )
130  {
131  for(uword col_A=0; col_A < A_n_cols; ++col_A)
132  {
133  // col_A is interpreted as row_A when storing the results in matrix C
134 
135  const eT* A_coldata = A.colptr(col_A);
136 
137  for(uword col_B=0; col_B < B_n_cols; ++col_B)
138  {
139  const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
140 
141  if( (use_alpha == false) && (use_beta == false) )
142  {
143  C.at(col_A,col_B) = acc;
144  }
145  else
146  if( (use_alpha == true) && (use_beta == false) )
147  {
148  C.at(col_A,col_B) = alpha * acc;
149  }
150  else
151  if( (use_alpha == false) && (use_beta == true) )
152  {
153  C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
154  }
155  else
156  if( (use_alpha == true) && (use_beta == true) )
157  {
158  C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
159  }
160 
161  }
162  }
163  }
164  else
// Case A*trans(B).
// NOTE(review): source lines 168 and 170 are missing from this rendering.
// Only the temporary BB is visible here.
165  if( (do_trans_A == false) && (do_trans_B == true) )
166  {
167  Mat<eT> BB;
169 
171  }
172  else
// Case trans(A)*trans(B).
173  if( (do_trans_A == true) && (do_trans_B == true) )
174  {
175  // mat B_tmp = trans(B);
176  // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
177 
178 
179  // By using the trans(A)*trans(B) = trans(B*A) equivalency,
180  // transpose operations are not needed
181 
// NOTE(review): the declaration of tmp (source line 182) is missing. It
// presumably holds one row of B. TODO confirm.
183  eT* B_rowdata = tmp.memptr();
184 
185  for(uword row_B=0; row_B < B_n_rows; ++row_B)
186  {
187  tmp.copy_row(B, row_B);
188 
189  for(uword col_A=0; col_A < A_n_cols; ++col_A)
190  {
// acc = (B*A)(row_B,col_A), which is trans(B*A)(col_A,row_B).
191  const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
192 
193  if( (use_alpha == false) && (use_beta == false) )
194  {
195  C.at(col_A,row_B) = acc;
196  }
197  else
198  if( (use_alpha == true) && (use_beta == false) )
199  {
200  C.at(col_A,row_B) = alpha * acc;
201  }
202  else
203  if( (use_alpha == false) && (use_beta == true) )
204  {
205  C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
206  }
207  else
208  if( (use_alpha == true) && (use_beta == true) )
209  {
210  C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
211  }
212 
213  }
214  }
215 
216  }
217  }
218 
219  };
220 
221 
222 
//! Emulated gemm dispatcher. It chooses between a path for square A and B of
//! equal size n <= 4 and a path for all other shapes.
//! There are two apply() overloads: one for non-complex element types
//! (arma_not_cx) and one for complex element types (arma_cx_only). The ignored
//! `junk` parameter exists only to pick the overload by element type.
//! NOTE(review): the calls made on each path (source lines 257, 262, 264, 269,
//! 319 and 323) and the class-name line (224) are missing from this rendering.
223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
225  {
226  public:
227 
228 
229  template<typename eT>
230  arma_hot
231  inline
232  static
233  void
234  apply
235  (
236  Mat<eT>& C,
237  const Mat<eT>& A,
238  const Mat<eT>& B,
239  const eT alpha = eT(1),
240  const eT beta = eT(0),
241  const typename arma_not_cx<eT>::result* junk = 0
242  )
243  {
245  arma_ignore(junk);
246 
247  const uword A_n_rows = A.n_rows;
248  const uword A_n_cols = A.n_cols;
249 
250  const uword B_n_rows = B.n_rows;
251  const uword B_n_cols = B.n_cols;
252 
// Tiny case: A and B both square, of the same size, at most 4x4.
253  if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
254  {
255  if(do_trans_B == false)
256  {
258  }
259  else
260  {
// NOTE(review): BB (same size as A) presumably receives trans(B) on a missing
// line. op_strans::apply_noalias_tinysq appears in this page's
// cross-reference. TODO confirm.
261  Mat<eT> BB(A_n_rows, A_n_rows);
263 
265  }
266  }
267  else
268  {
270  }
271  }
272 
273 
274 
275  template<typename eT>
276  arma_hot
277  inline
278  static
279  void
280  apply
281  (
282  Mat<eT>& C,
283  const Mat<eT>& A,
284  const Mat<eT>& B,
285  const eT alpha = eT(1),
286  const eT beta = eT(0),
287  const typename arma_cx_only<eT>::result* junk = 0
288  )
289  {
291  arma_ignore(junk);
292 
293  // "better than nothing" handling of hermitian transposes for complex number matrices
294 
// The requested hermitian transposes are materialised into temporaries,
// which costs an extra copy of A and/or B. AA and BB then refer to the
// operands actually used.
295  Mat<eT> tmp_A;
296  Mat<eT> tmp_B;
297 
298  if(do_trans_A)
299  {
300  op_htrans::apply_noalias(tmp_A, A);
301  }
302 
303  if(do_trans_B)
304  {
305  op_htrans::apply_noalias(tmp_B, B);
306  }
307 
308  const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
309  const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
310 
// Dimensions are taken from the (possibly transposed) operands AA and BB.
311  const uword A_n_rows = AA.n_rows;
312  const uword A_n_cols = AA.n_cols;
313 
314  const uword B_n_rows = BB.n_rows;
315  const uword B_n_cols = BB.n_cols;
316 
317  if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
318  {
320  }
321  else
322  {
324  }
325  }
326 
327  };
328 
329 
330 
334 
//! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control
//! the arguments passed to dgemm.
//!   do_trans_A / do_trans_B : use trans(A) / trans(B); this is the conjugate
//!                             transpose when eT is complex ('C' /
//!                             CblasConjTrans below)
//!   use_alpha / use_beta    : pass the given alpha / beta; otherwise 1 / 0
//!                             are passed
335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
336 class gemm
337  {
338  public:
339 
340  template<typename eT>
341  inline
342  static
343  void
344  apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
345  {
347 
// Both operands have at most 48 elements: handled by source line 350, which
// is missing from this rendering.
// NOTE(review): 48 is an unexplained size threshold; consider naming it.
348  if( (A.n_elem <= 48u) && (B.n_elem <= 48u) )
349  {
351  }
352  else
353  {
// In both library paths m and n are taken from C, and C is the output with
// leading dimension C.n_rows. C must therefore already have the result's
// size; no resize is visible in this function.
354  #if defined(ARMA_USE_ATLAS)
355  {
356  arma_extra_debug_print("atlas::cblas_gemm()");
357 
// NOTE(review): CblasConjTrans below is unqualified, while every other
// enumerator is atlas::-qualified. This compiles only if an unqualified
// CblasConjTrans is in scope. atlas::CblasConjTrans was probably intended;
// confirm.
// lda / ldb reduce to A.n_rows / B.n_rows for conformant operands.
358  atlas::cblas_gemm<eT>
359  (
360  atlas::CblasColMajor,
361  (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
362  (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
363  C.n_rows,
364  C.n_cols,
365  (do_trans_A) ? A.n_rows : A.n_cols,
366  (use_alpha) ? alpha : eT(1),
367  A.mem,
368  (do_trans_A) ? A.n_rows : C.n_rows,
369  B.mem,
370  (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
371  (use_beta) ? beta : eT(0),
372  C.memptr(),
373  C.n_rows
374  );
375  }
376  #elif defined(ARMA_USE_BLAS)
377  {
378  arma_extra_debug_print("blas::gemm()");
379 
380  const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
381  const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
382 
// NOTE(review): blas_int is int and uword is u32 (per this page's
// cross-reference). Dimensions above INT_MAX would not survive these
// conversions.
383  const blas_int m = C.n_rows;
384  const blas_int n = C.n_cols;
385  const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
386 
387  const eT local_alpha = (use_alpha) ? alpha : eT(1);
388 
389  const blas_int lda = (do_trans_A) ? k : m;
390  const blas_int ldb = (do_trans_B) ? n : k;
391 
392  const eT local_beta = (use_beta) ? beta : eT(0);
393 
394  arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
395  arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
396 
// The final argument (ldc) is m: C is passed as a dense m x n matrix.
397  blas::gemm<eT>
398  (
399  &trans_A,
400  &trans_B,
401  &m,
402  &n,
403  &k,
404  &local_alpha,
405  A.mem,
406  &lda,
407  B.mem,
408  &ldb,
409  &local_beta,
410  C.memptr(),
411  &m
412  );
413  }
414  #else
415  {
// Neither ATLAS nor BLAS is enabled: fallback on source line 416, which is
// missing from this rendering.
417  }
418  #endif
419  }
420  }
421 
422 
423 
//! immediate multiplication of matrices A and B, storing the result in C
// NOTE(review): the body (source line 431) is not visible in this rendering.
425  template<typename eT>
426  inline
427  static
428  void
429  apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
430  {
432  }
433 
434 
435 
// Non-template overload for float. Its body (source line 448) is not
// visible; it presumably forwards to apply_blas_type(). TODO confirm.
437  static
438  void
439  apply
440  (
441  Mat<float>& C,
442  const Mat<float>& A,
443  const Mat<float>& B,
444  const float alpha = float(1),
445  const float beta = float(0)
446  )
447  {
449  }
450 
451 
452 
// Non-template overload for double. Its body (source line 465) is not
// visible; it presumably forwards to apply_blas_type(). TODO confirm.
454  static
455  void
456  apply
457  (
458  Mat<double>& C,
459  const Mat<double>& A,
460  const Mat<double>& B,
461  const double alpha = double(1),
462  const double beta = double(0)
463  )
464  {
466  }
467 
468 
469 
// Non-template overload for std::complex<float>. Its body (source line 482)
// is not visible; it presumably forwards to apply_blas_type(). TODO confirm.
471  static
472  void
473  apply
474  (
475  Mat< std::complex<float> >& C,
476  const Mat< std::complex<float> >& A,
477  const Mat< std::complex<float> >& B,
478  const std::complex<float> alpha = std::complex<float>(1),
479  const std::complex<float> beta = std::complex<float>(0)
480  )
481  {
483  }
484 
485 
486 
// Non-template overload for std::complex<double>. Its body (source line 499)
// is not visible; it presumably forwards to apply_blas_type(). TODO confirm.
488  static
489  void
490  apply
491  (
492  Mat< std::complex<double> >& C,
493  const Mat< std::complex<double> >& A,
494  const Mat< std::complex<double> >& B,
495  const std::complex<double> alpha = std::complex<double>(1),
496  const std::complex<double> beta = std::complex<double>(0)
497  )
498  {
500  }
501 
502  };
503 
504 
505 
arma_inline arma_warn_unused eT * memptr()
returns a pointer to array of eTs used by the matrix
Definition: Mat_meat.hpp:4024
int blas_int
A lightweight array for POD types. If the amount of memory requested is small, the stack is used...
for tiny square matrices, size <= 4x4
Definition: gemm.hpp:21
static arma_hot void apply(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemm.hpp:76
const uword n_cols
number of columns in the matrix (read-only)
Definition: Mat_bones.hpp:30
static arma_hot void apply(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0), const typename arma_not_cx< eT >::result *junk=0)
Definition: gemm.hpp:235
static void apply_noalias_tinysq(Mat< eT > &out, const Mat< eT > &A)
for tiny square matrices (size <= 4x4)
const uword n_elem
number of elements in the matrix (read-only)
Definition: Mat_bones.hpp:31
const uword n_rows
number of rows in the matrix (read-only)
Definition: Mat_bones.hpp:29
static arma_hot void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemv.hpp:62
static arma_hot void apply(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemm.hpp:32
#define arma_extra_debug_print
Definition: debug.hpp:1118
arma_inline arma_warn_unused eT * colptr(const uword in_col)
returns a pointer to array of eTs for a specified column; no bounds check
Definition: Mat_meat.hpp:4000
u32 uword
Definition: typedef.hpp:85
Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
arma_inline arma_warn_unused eT & at(const uword i)
linear element accessor (treats the matrix as a vector); no bounds check.
Definition: Mat_meat.hpp:3692
#define arma_ignore(variable)
static void apply(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))
immediate multiplication of matrices A and B, storing the result in C
Definition: gemm.hpp:429
static void apply_blas_type(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemm.hpp:344
static arma_inline void apply_noalias(Mat< eT > &out, const Mat< eT > &A, const typename arma_not_cx< eT >::result *junk=0)
#define arma_aligned
#define arma_extra_debug_sigprint
Definition: debug.hpp:1116
Dense matrix class.
#define arma_inline
arma_aligned const eT *const mem
pointer to the memory used by the matrix (memory is read-only)
Definition: Mat_bones.hpp:40
#define arma_hot
static void apply_noalias(Mat< eT > &out, const Mat< eT > &A)
Immediate transpose of a dense matrix.
arma_hot static arma_pure eT direct_dot_arma(const uword n_elem, const eT *const A, const eT *const B)


armadillo_matrix
Author(s):
autogenerated on Fri Apr 16 2021 02:31:57