blas_interface_impl.hh
Go to the documentation of this file.
1 
2 #define BLAS_FUNC(NAME) CAT(CAT(SCALAR_PREFIX,NAME),_)
3 
4 template<> class blas_interface<SCALAR> : public c_interface_base<SCALAR>
5 {
6 
7 public :
8 
9  static SCALAR fone;
10  static SCALAR fzero;
11 
12  static inline std::string name()
13  {
14  return MAKE_STRING(CBLASNAME);
15  }
16 
17  static inline void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
18  BLAS_FUNC(gemv)(&notrans,&N,&N,&fone,A,&N,B,&intone,&fzero,X,&intone);
19  }
20 
21  static inline void symv(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
22  BLAS_FUNC(symv)(&lower, &N,&fone,A,&N,B,&intone,&fzero,X,&intone);
23  }
24 
25  static inline void syr2(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
26  BLAS_FUNC(syr2)(&lower,&N,&fone,B,&intone,X,&intone,A,&N);
27  }
28 
29  static inline void ger(gene_matrix & A, gene_vector & X, gene_vector & Y, int N){
30  BLAS_FUNC(ger)(&N,&N,&fone,X,&intone,Y,&intone,A,&N);
31  }
32 
33  static inline void rot(gene_vector & A, gene_vector & B, SCALAR c, SCALAR s, int N){
34  BLAS_FUNC(rot)(&N,A,&intone,B,&intone,&c,&s);
35  }
36 
37  static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
38  BLAS_FUNC(gemv)(&trans,&N,&N,&fone,A,&N,B,&intone,&fzero,X,&intone);
39  }
40 
41  static inline void matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
42  BLAS_FUNC(gemm)(&notrans,&notrans,&N,&N,&N,&fone,A,&N,B,&N,&fzero,X,&N);
43  }
44 
46  BLAS_FUNC(gemm)(&notrans,&notrans,&N,&N,&N,&fone,A,&N,B,&N,&fzero,X,&N);
47  }
48 
49  static inline void ata_product(gene_matrix & A, gene_matrix & X, int N){
50  BLAS_FUNC(syrk)(&lower,&trans,&N,&N,&fone,A,&N,&fzero,X,&N);
51  }
52 
53  static inline void aat_product(gene_matrix & A, gene_matrix & X, int N){
54  BLAS_FUNC(syrk)(&lower,&notrans,&N,&N,&fone,A,&N,&fzero,X,&N);
55  }
56 
57  static inline void axpy(SCALAR coef, const gene_vector & X, gene_vector & Y, int N){
58  BLAS_FUNC(axpy)(&N,&coef,X,&intone,Y,&intone);
59  }
60 
61  static inline void axpby(SCALAR a, const gene_vector & X, SCALAR b, gene_vector & Y, int N){
62  BLAS_FUNC(scal)(&N,&b,Y,&intone);
63  BLAS_FUNC(axpy)(&N,&a,X,&intone,Y,&intone);
64  }
65 
66  static inline void cholesky(const gene_matrix & X, gene_matrix & C, int N){
67  int N2 = N*N;
68  BLAS_FUNC(copy)(&N2, X, &intone, C, &intone);
69  char uplo = 'L';
70  int info = 0;
71  BLAS_FUNC(potrf)(&uplo, &N, C, &N, &info);
72  if(info!=0) std::cerr << "potrf_ error " << info << "\n";
73  }
74 
75  static inline void partial_lu_decomp(const gene_matrix & X, gene_matrix & C, int N){
76  int N2 = N*N;
77  BLAS_FUNC(copy)(&N2, X, &intone, C, &intone);
78  int info = 0;
79  int * ipiv = (int*)alloca(sizeof(int)*N);
80  BLAS_FUNC(getrf)(&N, &N, C, &N, ipiv, &info);
81  if(info!=0) std::cerr << "getrf_ error " << info << "\n";
82  }
83 
84  static inline void trisolve_lower(const gene_matrix & L, const gene_vector& B, gene_vector & X, int N){
85  BLAS_FUNC(copy)(&N, B, &intone, X, &intone);
86  BLAS_FUNC(trsv)(&lower, &notrans, &nonunit, &N, L, &N, X, &intone);
87  }
88 
89  static inline void trisolve_lower_matrix(const gene_matrix & L, const gene_matrix& B, gene_matrix & X, int N){
90  BLAS_FUNC(copy)(&N, B, &intone, X, &intone);
91  BLAS_FUNC(trsm)(&right, &lower, &notrans, &nonunit, &N, &N, &fone, L, &N, X, &N);
92  }
93 
94  static inline void trmm(gene_matrix & A, gene_matrix & B, gene_matrix & /*X*/, int N){
95  BLAS_FUNC(trmm)(&left, &lower, &notrans,&nonunit, &N,&N,&fone,A,&N,B,&N);
96  }
97 
98  #ifdef HAS_LAPACK
99 
100  static inline void lu_decomp(const gene_matrix & X, gene_matrix & C, int N){
101  int N2 = N*N;
102  BLAS_FUNC(copy)(&N2, X, &intone, C, &intone);
103  int info = 0;
104  int * ipiv = (int*)alloca(sizeof(int)*N);
105  int * jpiv = (int*)alloca(sizeof(int)*N);
106  BLAS_FUNC(getc2)(&N, C, &N, ipiv, jpiv, &info);
107  }
108 
109 
110 
111  static inline void hessenberg(const gene_matrix & X, gene_matrix & C, int N){
112  {
113  int N2 = N*N;
114  int inc = 1;
115  BLAS_FUNC(copy)(&N2, X, &inc, C, &inc);
116  }
117  int info = 0;
118  int ilo = 1;
119  int ihi = N;
120  int bsize = 64;
121  int worksize = N*bsize;
122  SCALAR* d = new SCALAR[N+worksize];
123  BLAS_FUNC(gehrd)(&N, &ilo, &ihi, C, &N, d, d+N, &worksize, &info);
124  delete[] d;
125  }
126 
127  static inline void tridiagonalization(const gene_matrix & X, gene_matrix & C, int N){
128  {
129  int N2 = N*N;
130  int inc = 1;
131  BLAS_FUNC(copy)(&N2, X, &inc, C, &inc);
132  }
133  char uplo = 'U';
134  int info = 0;
135  int bsize = 64;
136  int worksize = N*bsize;
137  SCALAR* d = new SCALAR[3*N+worksize];
138  BLAS_FUNC(sytrd)(&uplo, &N, C, &N, d, d+N, d+2*N, d+3*N, &worksize, &info);
139  delete[] d;
140  }
141 
142  #endif // HAS_LAPACK
143 
144 };
145 
int EIGEN_BLAS_FUNC() rot(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, RealScalar *ps)
#define BLAS_FUNC(NAME)
int EIGEN_BLAS_FUNC() syr2(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, const RealScalar *py, const int *incy, RealScalar *pc, const int *ldc)
const char Y
static void matrix_vector_product(gene_matrix &A, gene_vector &B, gene_vector &X, int N)
static char lower
Scalar * b
Definition: benchVecAdd.cpp:17
static void trmm(gene_matrix &A, gene_matrix &B, gene_matrix &, int N)
int EIGEN_BLAS_FUNC() trmm(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
Definition: level3_impl.h:183
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
static int intone
static void trisolve_lower_matrix(const gene_matrix &L, const gene_matrix &B, gene_matrix &X, int N)
static void aat_product(gene_matrix &A, gene_matrix &X, int N)
MatrixXd L
Definition: LLT_example.cpp:6
static void transposed_matrix_matrix_product(gene_matrix &A, gene_matrix &B, gene_matrix &X, int N)
static char trans
#define N
Definition: gksort.c:12
static void ger(gene_matrix &A, gene_vector &X, gene_vector &Y, int N)
static void symv(gene_matrix &A, gene_vector &B, gene_vector &X, int N)
EIGEN_DONT_INLINE void gemv(const Mat &A, const Vec &B, Vec &C)
Definition: gemv.cpp:4
static void syr2(gene_matrix &A, gene_vector &B, gene_vector &X, int N)
else if n * info
static void ata_product(gene_matrix &A, gene_matrix &X, int N)
static char left
static void axpy(SCALAR coef, const gene_vector &X, gene_vector &Y, int N)
static std::string name()
int EIGEN_BLAS_FUNC() trsv(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
Definition: level2_impl.h:86
static char notrans
static void axpby(SCALAR a, const gene_vector &X, SCALAR b, gene_vector &Y, int N)
void hessenberg(int size=Size)
Definition: hessenberg.cpp:14
int EIGEN_BLAS_FUNC() ger(int *m, int *n, Scalar *palpha, Scalar *px, int *incx, Scalar *py, int *incy, Scalar *pa, int *lda)
RealScalar s
static char right
EIGEN_DONT_INLINE void gemm(const A &a, const B &b, C &c)
Definition: bench_gemm.cpp:162
int EIGEN_BLAS_FUNC() trsm(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
Definition: level3_impl.h:78
#define MAKE_STRING(S)
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
static void partial_lu_decomp(const gene_matrix &X, gene_matrix &C, int N)
static char nonunit
static void trisolve_lower(const gene_matrix &L, const gene_vector &B, gene_vector &X, int N)
static void cholesky(const gene_matrix &X, gene_matrix &C, int N)
int EIGEN_BLAS_FUNC() symv(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
#define SCALAR
Definition: bench_gemm.cpp:23
static void atv_product(gene_matrix &A, gene_vector &B, gene_vector &X, int N)
static void matrix_matrix_product(gene_matrix &A, gene_matrix &B, gene_matrix &X, int N)
static void rot(gene_vector &A, gene_vector &B, SCALAR c, SCALAR s, int N)
int EIGEN_BLAS_FUNC() axpy(const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *py, const int *incy)
Definition: level1_impl.h:12
#define X
Definition: icosphere.cpp:20
int EIGEN_BLAS_FUNC() syrk(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
Definition: level3_impl.h:357
int EIGEN_BLAS_FUNC() scal(int *n, RealScalar *palpha, RealScalar *px, int *incx)
Definition: level1_impl.h:117


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:33:59