gemv.hpp
Go to the documentation of this file.
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12 
13 
16 
17 
18 
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
22  {
23  public:
24 
25 
26  template<const uword row, const uword col>
27  struct pos
28  {
29  static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
30  static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
31  static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
32  };
33 
34 
35 
36  template<typename eT, const uword i>
37  arma_hot
39  static
40  void
41  assign(eT* y, const eT acc, const eT alpha, const eT beta)
42  {
43  if(use_beta == false)
44  {
45  y[i] = (use_alpha == false) ? acc : alpha*acc;
46  }
47  else
48  {
49  const eT tmp = y[i];
50 
51  y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
52  }
53  }
54 
55 
56 
57  template<typename eT>
58  arma_hot
59  inline
60  static
61  void
62  apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
63  {
65 
66  const eT* Am = A.memptr();
67 
68  switch(A.n_rows)
69  {
70  case 1:
71  {
72  const eT acc = Am[0] * x[0];
73 
74  assign<eT, 0>(y, acc, alpha, beta);
75  }
76  break;
77 
78 
79  case 2:
80  {
81  const eT x0 = x[0];
82  const eT x1 = x[1];
83 
84  const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
85  const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
86 
87  assign<eT, 0>(y, acc0, alpha, beta);
88  assign<eT, 1>(y, acc1, alpha, beta);
89  }
90  break;
91 
92 
93  case 3:
94  {
95  const eT x0 = x[0];
96  const eT x1 = x[1];
97  const eT x2 = x[2];
98 
99  const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
100  const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
101  const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
102 
103  assign<eT, 0>(y, acc0, alpha, beta);
104  assign<eT, 1>(y, acc1, alpha, beta);
105  assign<eT, 2>(y, acc2, alpha, beta);
106  }
107  break;
108 
109 
110  case 4:
111  {
112  const eT x0 = x[0];
113  const eT x1 = x[1];
114  const eT x2 = x[2];
115  const eT x3 = x[3];
116 
117  const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
118  const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
119  const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
120  const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
121 
122  assign<eT, 0>(y, acc0, alpha, beta);
123  assign<eT, 1>(y, acc1, alpha, beta);
124  assign<eT, 2>(y, acc2, alpha, beta);
125  assign<eT, 3>(y, acc3, alpha, beta);
126  }
127  break;
128 
129 
130  default:
131  ;
132  }
133  }
134 
135  };
136 
137 
138 
142 
143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
145  {
146  public:
147 
148  template<typename eT>
149  arma_hot
150  inline
151  static
152  void
153  apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
154  {
156 
157  const uword A_n_rows = A.n_rows;
158  const uword A_n_cols = A.n_cols;
159 
160  if(do_trans_A == false)
161  {
162  if(A_n_rows == 1)
163  {
164  const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
165 
166  if( (use_alpha == false) && (use_beta == false) )
167  {
168  y[0] = acc;
169  }
170  else
171  if( (use_alpha == true) && (use_beta == false) )
172  {
173  y[0] = alpha * acc;
174  }
175  else
176  if( (use_alpha == false) && (use_beta == true) )
177  {
178  y[0] = acc + beta*y[0];
179  }
180  else
181  if( (use_alpha == true) && (use_beta == true) )
182  {
183  y[0] = alpha*acc + beta*y[0];
184  }
185  }
186  else
187  for(uword row=0; row < A_n_rows; ++row)
188  {
189  eT acc = eT(0);
190 
191  for(uword i=0; i < A_n_cols; ++i)
192  {
193  acc += A.at(row,i) * x[i];
194  }
195 
196  if( (use_alpha == false) && (use_beta == false) )
197  {
198  y[row] = acc;
199  }
200  else
201  if( (use_alpha == true) && (use_beta == false) )
202  {
203  y[row] = alpha * acc;
204  }
205  else
206  if( (use_alpha == false) && (use_beta == true) )
207  {
208  y[row] = acc + beta*y[row];
209  }
210  else
211  if( (use_alpha == true) && (use_beta == true) )
212  {
213  y[row] = alpha*acc + beta*y[row];
214  }
215  }
216  }
217  else
218  if(do_trans_A == true)
219  {
220  for(uword col=0; col < A_n_cols; ++col)
221  {
222  // col is interpreted as row when storing the results in 'y'
223 
224 
225  // const eT* A_coldata = A.colptr(col);
226  //
227  // eT acc = eT(0);
228  // for(uword row=0; row < A_n_rows; ++row)
229  // {
230  // acc += A_coldata[row] * x[row];
231  // }
232 
233  const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
234 
235  if( (use_alpha == false) && (use_beta == false) )
236  {
237  y[col] = acc;
238  }
239  else
240  if( (use_alpha == true) && (use_beta == false) )
241  {
242  y[col] = alpha * acc;
243  }
244  else
245  if( (use_alpha == false) && (use_beta == true) )
246  {
247  y[col] = acc + beta*y[col];
248  }
249  else
250  if( (use_alpha == true) && (use_beta == true) )
251  {
252  y[col] = alpha*acc + beta*y[col];
253  }
254 
255  }
256  }
257  }
258 
259  };
260 
261 
262 
263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
265  {
266  public:
267 
268  template<typename eT>
269  arma_hot
270  inline
271  static
272  void
273  apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
274  {
276  arma_ignore(junk);
277 
278  const uword A_n_rows = A.n_rows;
279  const uword A_n_cols = A.n_cols;
280 
281  if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
282  {
284  }
285  else
286  {
288  }
289  }
290 
291 
292 
293  template<typename eT>
294  arma_hot
295  inline
296  static
297  void
298  apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
299  {
301 
302  Mat<eT> tmp_A;
303 
304  if(do_trans_A)
305  {
306  op_htrans::apply_noalias(tmp_A, A);
307  }
308 
309  const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
310 
311  const uword AA_n_rows = AA.n_rows;
312  const uword AA_n_cols = AA.n_cols;
313 
314  if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
315  {
317  }
318  else
319  {
321  }
322  }
323  };
324 
325 
326 
330 
331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
332 class gemv
333  {
334  public:
335 
336  template<typename eT>
337  inline
338  static
339  void
340  apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
341  {
343 
344  if(A.n_elem <= 64u)
345  {
347  }
348  else
349  {
350  #if defined(ARMA_USE_ATLAS)
351  {
352  arma_extra_debug_print("atlas::cblas_gemv()");
353 
354  atlas::cblas_gemv<eT>
355  (
356  atlas::CblasColMajor,
357  (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
358  A.n_rows,
359  A.n_cols,
360  (use_alpha) ? alpha : eT(1),
361  A.mem,
362  A.n_rows,
363  x,
364  1,
365  (use_beta) ? beta : eT(0),
366  y,
367  1
368  );
369  }
370  #elif defined(ARMA_USE_BLAS)
371  {
372  arma_extra_debug_print("blas::gemv()");
373 
374  const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
375  const blas_int m = A.n_rows;
376  const blas_int n = A.n_cols;
377  const eT local_alpha = (use_alpha) ? alpha : eT(1);
378  //const blas_int lda = A.n_rows;
379  const blas_int inc = 1;
380  const eT local_beta = (use_beta) ? beta : eT(0);
381 
382  arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
383 
384  blas::gemv<eT>
385  (
386  &trans_A,
387  &m,
388  &n,
389  &local_alpha,
390  A.mem,
391  &m, // lda
392  x,
393  &inc,
394  &local_beta,
395  y,
396  &inc
397  );
398  }
399  #else
400  {
402  }
403  #endif
404  }
405 
406  }
407 
408 
409 
410  template<typename eT>
412  static
413  void
414  apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
415  {
417  }
418 
419 
420 
422  static
423  void
424  apply
425  (
426  float* y,
427  const Mat<float>& A,
428  const float* x,
429  const float alpha = float(1),
430  const float beta = float(0)
431  )
432  {
434  }
435 
436 
437 
439  static
440  void
441  apply
442  (
443  double* y,
444  const Mat<double>& A,
445  const double* x,
446  const double alpha = double(1),
447  const double beta = double(0)
448  )
449  {
451  }
452 
453 
454 
456  static
457  void
458  apply
459  (
460  std::complex<float>* y,
461  const Mat< std::complex<float > >& A,
462  const std::complex<float>* x,
463  const std::complex<float> alpha = std::complex<float>(1),
464  const std::complex<float> beta = std::complex<float>(0)
465  )
466  {
468  }
469 
470 
471 
473  static
474  void
475  apply
476  (
477  std::complex<double>* y,
478  const Mat< std::complex<double> >& A,
479  const std::complex<double>* x,
480  const std::complex<double> alpha = std::complex<double>(1),
481  const std::complex<double> beta = std::complex<double>(0)
482  )
483  {
485  }
486 
487 
488 
489  };
490 
491 
static arma_hot void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemv.hpp:153
arma_inline arma_warn_unused eT * memptr()
returns a pointer to array of eTs used by the matrix
Definition: Mat_meat.hpp:4024
int blas_int
static arma_hot void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0), const typename arma_not_cx< eT >::result *junk=0)
Definition: gemv.hpp:273
Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gem...
const uword n_cols
number of columns in the matrix (read-only)
Definition: Mat_bones.hpp:30
const uword n_elem
number of elements in the matrix (read-only)
Definition: Mat_bones.hpp:31
const uword n_rows
number of rows in the matrix (read-only)
Definition: Mat_bones.hpp:29
static arma_hot void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemv.hpp:62
#define arma_extra_debug_print
Definition: debug.hpp:1118
arma_inline arma_warn_unused eT * colptr(const uword in_col)
returns a pointer to array of eTs for a specified column; no bounds check
Definition: Mat_meat.hpp:4000
u32 uword
Definition: typedef.hpp:85
for tiny square matrices, size <= 4x4
Definition: gemv.hpp:21
arma_inline arma_warn_unused eT & at(const uword i)
linear element accessor (treats the matrix as a vector); no bounds check.
Definition: Mat_meat.hpp:3692
#define arma_ignore(variable)
static arma_inline void apply_noalias(Mat< eT > &out, const Mat< eT > &A, const typename arma_not_cx< eT >::result *junk=0)
static arma_inline void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemv.hpp:414
static const uword n4
Definition: gemv.hpp:31
#define arma_extra_debug_sigprint
Definition: debug.hpp:1116
static arma_hot void apply(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0), const typename arma_cx_only< eT >::result *junk=0)
Definition: gemv.hpp:298
arma_hot static arma_inline void assign(eT *y, const eT acc, const eT alpha, const eT beta)
Definition: gemv.hpp:41
Dense matrix class.
#define arma_inline
static void apply_blas_type(eT *y, const Mat< eT > &A, const eT *x, const eT alpha=eT(1), const eT beta=eT(0))
Definition: gemv.hpp:340
static const uword n3
Definition: gemv.hpp:30
Partial emulation of ATLAS/BLAS gemv(). &#39;y&#39; is assumed to have been set to the correct size (i...
Definition: gemv.hpp:144
arma_aligned const eT *const mem
pointer to the memory used by the matrix (memory is read-only)
Definition: Mat_bones.hpp:40
static const uword n2
Definition: gemv.hpp:29
#define arma_hot
arma_hot static arma_pure eT direct_dot_arma(const uword n_elem, const eT *const A, const eT *const B)


armadillo_matrix
Author(s):
autogenerated on Fri Apr 16 2021 02:31:57