gemm_mixed.hpp
Go to the documentation of this file.
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12 
13 
16 
17 
18 
23 
24 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
26  {
27  public:
28 
// Mixed-element-type matrix multiply: fills C with op(A)*op(B), where op()
// is an optional transpose chosen by the enclosing template's do_trans_A /
// do_trans_B flags.
//
// Depending on use_alpha / use_beta each output element is stored as
//   acc, alpha*acc, acc + beta*C(r,c), or alpha*acc + beta*C(r,c).
// alpha and beta are ignored unless the corresponding flag is true.
//
// Every input element goes through upgrade_val<in_eT1,in_eT2>::apply()
// before the multiply, and the running sum is held in out_eT.
//
// No size checks are visible here: C is indexed with the unchecked at()
// accessor and A/B through colptr(), so the caller must supply compatible
// dimensions and a correctly sized C.
//
// NOTE(review): this text comes from a rendered source listing; listing
// lines that were hyperlinks (e.g. 43 and 142 below) are absent.
29  template<typename out_eT, typename in_eT1, typename in_eT2>
30  arma_hot
31  inline
32  static
33  void
34  apply
35  (
36  Mat<out_eT>& C,
37  const Mat<in_eT1>& A,
38  const Mat<in_eT2>& B,
39  const out_eT alpha = out_eT(1),
40  const out_eT beta = out_eT(0)
41  )
42  {
    // NOTE(review): listing line 43 is missing here -- presumably the
    // arma_extra_debug_sigprint debug hook; confirm against upstream.
44 
45  const uword A_n_rows = A.n_rows;
46  const uword A_n_cols = A.n_cols;
47 
48  const uword B_n_rows = B.n_rows;
49  const uword B_n_cols = B.n_cols;
50 
    // Case 1: A * B
51  if( (do_trans_A == false) && (do_trans_B == false) )
52  {
    // colptr() exposes one column as a plain array, so each row of A is
    // first gathered into the scratch buffer tmp via copy_row(); the dot
    // product below then walks two plain arrays.
53  podarray<in_eT1> tmp(A_n_cols);
54  in_eT1* A_rowdata = tmp.memptr();
55 
56  for(uword row_A=0; row_A < A_n_rows; ++row_A)
57  {
58  tmp.copy_row(A, row_A);
59 
60  for(uword col_B=0; col_B < B_n_cols; ++col_B)
61  {
62  const in_eT2* B_coldata = B.colptr(col_B);
63 
    // Dot product of row row_A of A with column col_B of B. The loop
    // bound is B_n_rows while tmp holds A_n_cols elements; the two are
    // assumed equal (not checked here).
64  out_eT acc = out_eT(0);
65  for(uword i=0; i < B_n_rows; ++i)
66  {
67  acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
68  }
69 
    // The conditions below involve only template constants, so exactly
    // one of the four store forms applies for a given instantiation.
70  if( (use_alpha == false) && (use_beta == false) )
71  {
72  C.at(row_A,col_B) = acc;
73  }
74  else
75  if( (use_alpha == true) && (use_beta == false) )
76  {
77  C.at(row_A,col_B) = alpha * acc;
78  }
79  else
80  if( (use_alpha == false) && (use_beta == true) )
81  {
82  C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
83  }
84  else
85  if( (use_alpha == true) && (use_beta == true) )
86  {
87  C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
88  }
89 
90  }
91  }
92  }
93  else
    // Case 2: trans(A) * B -- both operands are read column by column, so
    // no scratch copy is needed.
94  if( (do_trans_A == true) && (do_trans_B == false) )
95  {
96  for(uword col_A=0; col_A < A_n_cols; ++col_A)
97  {
98  // col_A is interpreted as row_A when storing the results in matrix C
99 
100  const in_eT1* A_coldata = A.colptr(col_A);
101 
102  for(uword col_B=0; col_B < B_n_cols; ++col_B)
103  {
104  const in_eT2* B_coldata = B.colptr(col_B);
105 
    // Column col_A of A dotted with column col_B of B; the length is
    // taken from B_n_rows (assumed equal to A_n_rows, not checked here).
106  out_eT acc = out_eT(0);
107  for(uword i=0; i < B_n_rows; ++i)
108  {
109  acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
110  }
111 
112  if( (use_alpha == false) && (use_beta == false) )
113  {
114  C.at(col_A,col_B) = acc;
115  }
116  else
117  if( (use_alpha == true) && (use_beta == false) )
118  {
119  C.at(col_A,col_B) = alpha * acc;
120  }
121  else
122  if( (use_alpha == false) && (use_beta == true) )
123  {
124  C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
125  }
126  else
127  if( (use_alpha == true) && (use_beta == true) )
128  {
129  C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
130  }
131 
132  }
133  }
134  }
135  else
    // Case 3: A * trans(B) -- B is transposed into a temporary matrix.
136  if( (do_trans_A == false) && (do_trans_B == true) )
137  {
138  Mat<in_eT2> B_tmp;
139 
140  op_strans::apply_noalias(B_tmp, B);
141 
    // NOTE(review): listing line 142 is missing here. As shown, B_tmp is
    // built and never used and C is left untouched. Presumably the lost
    // line forwards to the <false,false,use_alpha,use_beta> variant with
    // B_tmp in place of B -- confirm against upstream.
143  }
144  else
    // Case 4: trans(A) * trans(B)
145  if( (do_trans_A == true) && (do_trans_B == true) )
146  {
147  // mat B_tmp = trans(B);
148  // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
149 
150 
151  // By using the trans(A)*trans(B) = trans(B*A) equivalency,
152  // transpose operations are not needed
153 
    // Rows of B are gathered into tmp and dotted with columns of A; each
    // result is written to the transposed position C(col_A,row_B).
154  podarray<in_eT2> tmp(B_n_cols);
155  in_eT2* B_rowdata = tmp.memptr();
156 
157  for(uword row_B=0; row_B < B_n_rows; ++row_B)
158  {
159  tmp.copy_row(B, row_B);
160 
161  for(uword col_A=0; col_A < A_n_cols; ++col_A)
162  {
163  const in_eT1* A_coldata = A.colptr(col_A);
164 
    // The loop bound is A_n_rows while tmp holds B_n_cols elements; the
    // two are assumed equal (not checked here).
165  out_eT acc = out_eT(0);
166  for(uword i=0; i < A_n_rows; ++i)
167  {
168  acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
169  }
170 
171  if( (use_alpha == false) && (use_beta == false) )
172  {
173  C.at(col_A,row_B) = acc;
174  }
175  else
176  if( (use_alpha == true) && (use_beta == false) )
177  {
178  C.at(col_A,row_B) = alpha * acc;
179  }
180  else
181  if( (use_alpha == false) && (use_beta == true) )
182  {
183  C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
184  }
185  else
186  if( (use_alpha == true) && (use_beta == true) )
187  {
188  C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
189  }
190 
191  }
192  }
193 
194  }
195  }
196 
197  };
198 
199 
200 
204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
206  {
207  public:
208 
// Mixed-element-type matrix multiply without scratch buffers: fills C with
// op(A)*op(B), where op() is an optional transpose chosen by the enclosing
// template's do_trans_A / do_trans_B flags.
//
// Depending on use_alpha / use_beta each output element is stored as
//   acc, alpha*acc, acc + beta*C(r,c), or alpha*acc + beta*C(r,c).
// alpha and beta are ignored unless the corresponding flag is true.
//
// No size checks are visible here: elements are reached through the
// unchecked at() / colptr() accessors, so the caller must supply compatible
// dimensions and a correctly sized C.
//
// NOTE(review): this text comes from a rendered source listing; listing
// lines that were hyperlinks (223, 322 and 360 below) are absent.
209  template<typename out_eT, typename in_eT1, typename in_eT2>
210  arma_hot
211  inline
212  static
213  void
214  apply
215  (
216  Mat<out_eT>& C,
217  const Mat<in_eT1>& A,
218  const Mat<in_eT2>& B,
219  const out_eT alpha = out_eT(1),
220  const out_eT beta = out_eT(0)
221  )
222  {
    // NOTE(review): listing line 223 is missing here -- presumably the
    // arma_extra_debug_sigprint debug hook; confirm against upstream.
224 
225  const uword A_n_rows = A.n_rows;
226  const uword A_n_cols = A.n_cols;
227 
228  const uword B_n_rows = B.n_rows;
229  const uword B_n_cols = B.n_cols;
230 
    // Case 1: A * B -- A is read element-wise along a row with at(),
    // B through its column pointer.
231  if( (do_trans_A == false) && (do_trans_B == false) )
232  {
233  for(uword row_A = 0; row_A < A_n_rows; ++row_A)
234  {
235  for(uword col_B = 0; col_B < B_n_cols; ++col_B)
236  {
237  const in_eT2* B_coldata = B.colptr(col_B);
238 
    // The loop bound is B_n_rows; it is assumed to equal A_n_cols
    // (not checked here).
239  out_eT acc = out_eT(0);
240  for(uword i = 0; i < B_n_rows; ++i)
241  {
    // Unlike the other branches, both operands are converted to
    // out_eT before the multiplication.
242  const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
243  const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
244  acc += val1 * val2;
245  //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
246  }
247 
    // The conditions below involve only template constants, so exactly
    // one of the four store forms applies for a given instantiation.
248  if( (use_alpha == false) && (use_beta == false) )
249  {
250  C.at(row_A,col_B) = acc;
251  }
252  else
253  if( (use_alpha == true) && (use_beta == false) )
254  {
255  C.at(row_A,col_B) = alpha * acc;
256  }
257  else
258  if( (use_alpha == false) && (use_beta == true) )
259  {
260  C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
261  }
262  else
263  if( (use_alpha == true) && (use_beta == true) )
264  {
265  C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
266  }
267  }
268  }
269  }
270  else
    // Case 2: trans(A) * B -- both operands are read through column pointers.
271  if( (do_trans_A == true) && (do_trans_B == false) )
272  {
273  for(uword col_A=0; col_A < A_n_cols; ++col_A)
274  {
275  // col_A is interpreted as row_A when storing the results in matrix C
276 
277  const in_eT1* A_coldata = A.colptr(col_A);
278 
279  for(uword col_B=0; col_B < B_n_cols; ++col_B)
280  {
281  const in_eT2* B_coldata = B.colptr(col_B);
282 
    // Column col_A of A dotted with column col_B of B; the length is
    // taken from B_n_rows (assumed equal to A_n_rows, not checked here).
283  out_eT acc = out_eT(0);
284  for(uword i=0; i < B_n_rows; ++i)
285  {
286  acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
287  }
288 
289  if( (use_alpha == false) && (use_beta == false) )
290  {
291  C.at(col_A,col_B) = acc;
292  }
293  else
294  if( (use_alpha == true) && (use_beta == false) )
295  {
296  C.at(col_A,col_B) = alpha * acc;
297  }
298  else
299  if( (use_alpha == false) && (use_beta == true) )
300  {
301  C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
302  }
303  else
304  if( (use_alpha == true) && (use_beta == true) )
305  {
306  C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
307  }
308 
309  }
310  }
311  }
312  else
    // Case 3: A * trans(B) -- results are stored at C(row_A,row_B).
313  if( (do_trans_A == false) && (do_trans_B == true) )
314  {
315  for(uword row_A = 0; row_A < A_n_rows; ++row_A)
316  {
317  for(uword row_B = 0; row_B < B_n_rows; ++row_B)
318  {
319  out_eT acc = out_eT(0);
320  for(uword i = 0; i < B_n_cols; ++i)
321  {
    // NOTE(review): listing line 322 is missing here, so as shown the
    // loop body is empty and acc stays 0. Presumably the lost line
    // accumulates the product of A(row_A,i) and B(row_B,i) -- confirm
    // against upstream.
323  }
324 
325  if( (use_alpha == false) && (use_beta == false) )
326  {
327  C.at(row_A,row_B) = acc;
328  }
329  else
330  if( (use_alpha == true) && (use_beta == false) )
331  {
332  C.at(row_A,row_B) = alpha * acc;
333  }
334  else
335  if( (use_alpha == false) && (use_beta == true) )
336  {
337  C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
338  }
339  else
340  if( (use_alpha == true) && (use_beta == true) )
341  {
342  C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
343  }
344  }
345  }
346  }
347  else
    // Case 4: trans(A) * trans(B) -- results are stored at C(col_A,row_B).
348  if( (do_trans_A == true) && (do_trans_B == true) )
349  {
350  for(uword row_B=0; row_B < B_n_rows; ++row_B)
351  {
352 
353  for(uword col_A=0; col_A < A_n_cols; ++col_A)
354  {
355  const in_eT1* A_coldata = A.colptr(col_A);
356 
357  out_eT acc = out_eT(0);
358  for(uword i=0; i < A_n_rows; ++i)
359  {
    // NOTE(review): listing line 360 is missing here, so as shown the
    // loop body is empty and acc stays 0. Presumably the lost line
    // accumulates the product of B(row_B,i) and A_coldata[i] --
    // confirm against upstream.
361  }
362 
363  if( (use_alpha == false) && (use_beta == false) )
364  {
365  C.at(col_A,row_B) = acc;
366  }
367  else
368  if( (use_alpha == true) && (use_beta == false) )
369  {
370  C.at(col_A,row_B) = alpha * acc;
371  }
372  else
373  if( (use_alpha == false) && (use_beta == true) )
374  {
375  C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
376  }
377  else
378  if( (use_alpha == true) && (use_beta == true) )
379  {
380  C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
381  }
382 
383  }
384  }
385 
386  }
387  }
388 
389  };
390 
391 
392 
393 
394 
397 
398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
400  {
401  public:
402 
// Front end for mixed-element-type matrix multiplication.
//
// Decides whether a requested transpose must be materialised up front
// (predo_trans_*: the transpose is requested AND the element type is
// complex), picks the operands AA / BB accordingly, and then branches on
// operand size: one branch when both AA and BB have at most 64 elements,
// the other branch otherwise.
//
// alpha and beta are accepted here, but how they are used is not visible in
// this listing because both branch bodies are missing (see notes below).
//
// NOTE(review): this text comes from a rendered source listing; listing
// lines that were hyperlinks (417, 440 and 444 below) are absent.
404  template<typename out_eT, typename in_eT1, typename in_eT2>
405  inline
406  static
407  void
408  apply
409  (
410  Mat<out_eT>& C,
411  const Mat<in_eT1>& A,
412  const Mat<in_eT2>& B,
413  const out_eT alpha = out_eT(1),
414  const out_eT beta = out_eT(0)
415  )
416  {
    // NOTE(review): listing line 417 is missing here -- presumably the
    // arma_extra_debug_sigprint debug hook; confirm against upstream.
418 
419  Mat<in_eT1> tmp_A;
420  Mat<in_eT2> tmp_B;
421 
    // A transpose is materialised in advance only for complex element types.
422  const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
423  const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
424 
    // NOTE(review): the two transposes below are guarded by do_trans_*,
    // yet tmp_A / tmp_B are only selected further down when predo_trans_*
    // is true. For non-complex element types the transposed copy is
    // therefore computed and discarded; guarding on predo_trans_* instead
    // looks like it would avoid that work -- confirm before changing.
425  if(do_trans_A)
426  {
427  op_htrans::apply_noalias(tmp_A, A);
428  }
429 
430  if(do_trans_B)
431  {
432  op_htrans::apply_noalias(tmp_B, B);
433  }
434 
    // Use the pre-transposed copy only when it was actually required.
435  const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
436  const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
437 
    // Size-based dispatch: both operands at most 64 elements vs. anything larger.
438  if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
439  {
    // NOTE(review): listing line 440 is missing here; presumably the
    // call into the small-matrix implementation -- confirm against
    // upstream.
441  }
442  else
443  {
    // NOTE(review): listing line 444 is missing here; presumably the
    // call into the large-matrix implementation -- confirm against
    // upstream.
445  }
446  }
447 
448 
449  };
450 
451 
452 
arma_hot void copy_row(const Mat< eT > &A, const uword row)
A lightweight array for POD types. If the amount of memory requested is small, the stack is used...
Matrix multiplication where the matrices have differing element types. Uses caching for speedup...
Definition: gemm_mixed.hpp:25
const uword n_cols
number of columns in the matrix (read-only)
Definition: Mat_bones.hpp:30
const uword n_elem
number of elements in the matrix (read-only)
Definition: Mat_bones.hpp:31
const uword n_rows
number of rows in the matrix (read-only)
Definition: Mat_bones.hpp:29
arma_inline arma_warn_unused eT * colptr(const uword in_col)
returns a pointer to array of eTs for a specified column; no bounds check
Definition: Mat_meat.hpp:4000
u32 uword
Definition: typedef.hpp:85
static arma_hot void apply(Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))
Definition: gemm_mixed.hpp:35
arma_inline arma_warn_unused eT & at(const uword i)
linear element accessor (treats the matrix as a vector); no bounds check.
Definition: Mat_meat.hpp:3692
static arma_inline void apply_noalias(Mat< eT > &out, const Mat< eT > &A, const typename arma_not_cx< eT >::result *junk=0)
Matrix multiplication where the matrices have differing element types.
Definition: gemm_mixed.hpp:399
static arma_inline promote_type< T1, T2 >::result apply(const T1 x)
Definition: upgrade_val.hpp:31
#define arma_extra_debug_sigprint
Definition: debug.hpp:1116
arma_inline eT * memptr()
#define arma_hot
static void apply_noalias(Mat< eT > &out, const Mat< eT > &A)
Immediate transpose of a dense matrix.


armadillo_matrix
Author(s):
autogenerated on Fri Apr 16 2021 02:31:57