00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef C_BLAS_PRODUIT_MATRICE_VECTEUR_HH
00021 #define C_BLAS_PRODUIT_MATRICE_VECTEUR_HH
00022
00023 #include "f77_interface.hh"
00024 #include <complex>
00025 extern "C"
00026 {
00027 #include "cblas.h"
00028
00029
00030 #include "blas.h"
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060 void ssytrd_(char *uplo, const int *n, float *a, const int *lda, float *d, float *e, float *tau, float *work, int *lwork, int *info );
00061 void sgehrd_( const int *n, int *ilo, int *ihi, float *a, const int *lda, float *tau, float *work, int *lwork, int *info );
00062
00063
00064
00065
00066
00067 void sgetc2_(const int* n, float *a, const int *lda, int *ipiv, int *jpiv, int*info );
00068 #ifdef HAS_LAPACK
00069 #endif
00070 }
00071
// Two-level stringification: MAKE_STRING first macro-expands its argument,
// then MAKE_STRING2 turns the expanded tokens into a string literal.
#define MAKE_STRING2(tokens) #tokens
#define MAKE_STRING(tokens)  MAKE_STRING2(tokens)
00074
00075 template<class real>
00076 class C_BLAS_interface : public f77_interface_base<real>
00077 {
00078 public :
00079
00080 typedef typename f77_interface_base<real>::gene_matrix gene_matrix;
00081 typedef typename f77_interface_base<real>::gene_vector gene_vector;
00082
00083 static inline std::string name( void )
00084 {
00085 return MAKE_STRING(CBLASNAME);
00086 }
00087
00088 static inline void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N)
00089 {
00090 cblas_dgemv(CblasColMajor,CblasNoTrans,N,N,1.0,A,N,B,1,0.0,X,1);
00091 }
00092
00093 static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N)
00094 {
00095 cblas_dgemv(CblasColMajor,CblasTrans,N,N,1.0,A,N,B,1,0.0,X,1);
00096 }
00097
00098 static inline void symv(gene_matrix & A, gene_vector & B, gene_vector & X, int N)
00099 {
00100 cblas_dsymv(CblasColMajor,CblasLower,CblasTrans,N,N,1.0,A,N,B,1,0.0,X,1);
00101 }
00102
00103 static inline void matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
00104 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,N,N,N,1.0,A,N,B,N,0.0,X,N);
00105 }
00106
00107 static inline void transposed_matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
00108 cblas_dgemm(CblasColMajor,CblasTrans,CblasTrans,N,N,N,1.0,A,N,B,N,0.0,X,N);
00109 }
00110
00111 static inline void ata_product(gene_matrix & A, gene_matrix & X, int N){
00112 cblas_dgemm(CblasColMajor,CblasTrans,CblasNoTrans,N,N,N,1.0,A,N,A,N,0.0,X,N);
00113 }
00114
00115 static inline void aat_product(gene_matrix & A, gene_matrix & X, int N){
00116 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasTrans,N,N,N,1.0,A,N,A,N,0.0,X,N);
00117 }
00118
00119 static inline void axpy(real coef, const gene_vector & X, gene_vector & Y, int N){
00120 cblas_daxpy(N,coef,X,1,Y,1);
00121 }
00122
00123 static inline void axpby(real a, const gene_vector & X, real b, gene_vector & Y, int N){
00124 cblas_dscal(N,b,Y,1);
00125 cblas_daxpy(N,a,X,1,Y,1);
00126 }
00127
00128 };
00129
// File-scope constants handed by address to the Fortran-style BLAS entry
// points (the trailing-underscore routines take every argument as a pointer).
static float fone    = 1.0f;  // alpha = 1
static float fzero   = 0.0f;  // beta  = 0
static char  notrans = 'N';   // op(A) = A
static char  trans   = 'T';   // op(A) = A^T
static char  nonunit = 'N';   // non-unit diagonal
static char  lower   = 'L';   // lower triangle
static char  right   = 'R';   // triangular operand on the right
static char  left    = 'L';   // triangular operand on the left
static int   intone  = 1;     // unit stride
00139
00140 template<>
00141 class C_BLAS_interface<float> : public f77_interface_base<float>
00142 {
00143
00144 public :
00145
00146 static inline std::string name( void )
00147 {
00148 return MAKE_STRING(CBLASNAME);
00149 }
00150
00151 static inline void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
00152 #ifdef PUREBLAS
00153 sgemv_(¬rans,&N,&N,&fone,A,&N,B,&intone,&fzero,X,&intone);
00154 #else
00155 cblas_sgemv(CblasColMajor,CblasNoTrans,N,N,1.0,A,N,B,1,0.0,X,1);
00156 #endif
00157 }
00158
00159 static inline void symv(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
00160 #ifdef PUREBLAS
00161 ssymv_(&lower, &N,&fone,A,&N,B,&intone,&fzero,X,&intone);
00162 #else
00163 cblas_ssymv(CblasColMajor,CblasLower,N,1.0,A,N,B,1,0.0,X,1);
00164 #endif
00165 }
00166
00167 static inline void syr2(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
00168 #ifdef PUREBLAS
00169 ssyr2_(&lower,&N,&fone,B,&intone,X,&intone,A,&N);
00170 #else
00171 cblas_ssyr2(CblasColMajor,CblasLower,N,1.0,B,1,X,1,A,N);
00172 #endif
00173 }
00174
00175 static inline void ger(gene_matrix & A, gene_vector & X, gene_vector & Y, int N){
00176 #ifdef PUREBLAS
00177 sger_(&N,&N,&fone,X,&intone,Y,&intone,A,&N);
00178 #else
00179 cblas_sger(CblasColMajor,N,N,1.0,X,1,Y,1,A,N);
00180 #endif
00181 }
00182
00183 static inline void rot(gene_vector & A, gene_vector & B, float c, float s, int N){
00184 #ifdef PUREBLAS
00185 srot_(&N,A,&intone,B,&intone,&c,&s);
00186 #else
00187 cblas_srot(N,A,1,B,1,c,s);
00188 #endif
00189 }
00190
00191 static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
00192 #ifdef PUREBLAS
00193 sgemv_(&trans,&N,&N,&fone,A,&N,B,&intone,&fzero,X,&intone);
00194 #else
00195 cblas_sgemv(CblasColMajor,CblasTrans,N,N,1.0,A,N,B,1,0.0,X,1);
00196 #endif
00197 }
00198
00199 static inline void matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
00200 #ifdef PUREBLAS
00201 sgemm_(¬rans,¬rans,&N,&N,&N,&fone,A,&N,B,&N,&fzero,X,&N);
00202 #else
00203 cblas_sgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,N,N,N,1.0,A,N,B,N,0.0,X,N);
00204 #endif
00205 }
00206
00207 static inline void transposed_matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
00208 #ifdef PUREBLAS
00209 sgemm_(¬rans,¬rans,&N,&N,&N,&fone,A,&N,B,&N,&fzero,X,&N);
00210 #else
00211 cblas_sgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,N,N,N,1.0,A,N,B,N,0.0,X,N);
00212 #endif
00213 }
00214
00215 static inline void ata_product(gene_matrix & A, gene_matrix & X, int N){
00216 #ifdef PUREBLAS
00217 sgemm_(&trans,¬rans,&N,&N,&N,&fone,A,&N,A,&N,&fzero,X,&N);
00218 #else
00219 cblas_sgemm(CblasColMajor,CblasTrans,CblasNoTrans,N,N,N,1.0,A,N,A,N,0.0,X,N);
00220 #endif
00221 }
00222
00223 static inline void aat_product(gene_matrix & A, gene_matrix & X, int N){
00224 #ifdef PUREBLAS
00225 sgemm_(¬rans,&trans,&N,&N,&N,&fone,A,&N,A,&N,&fzero,X,&N);
00226 #else
00227 cblas_sgemm(CblasColMajor,CblasNoTrans,CblasTrans,N,N,N,1.0,A,N,A,N,0.0,X,N);
00228 #endif
00229 }
00230
00231 static inline void axpy(float coef, const gene_vector & X, gene_vector & Y, int N){
00232 #ifdef PUREBLAS
00233 saxpy_(&N,&coef,X,&intone,Y,&intone);
00234 #else
00235 cblas_saxpy(N,coef,X,1,Y,1);
00236 #endif
00237 }
00238
00239 static inline void axpby(float a, const gene_vector & X, float b, gene_vector & Y, int N){
00240 #ifdef PUREBLAS
00241 sscal_(&N,&b,Y,&intone);
00242 saxpy_(&N,&a,X,&intone,Y,&intone);
00243 #else
00244 cblas_sscal(N,b,Y,1);
00245 cblas_saxpy(N,a,X,1,Y,1);
00246 #endif
00247 }
00248
00249 static inline void cholesky(const gene_matrix & X, gene_matrix & C, int N){
00250 int N2 = N*N;
00251 scopy_(&N2, X, &intone, C, &intone);
00252 char uplo = 'L';
00253 int info = 0;
00254 spotrf_(&uplo, &N, C, &N, &info);
00255 if(info!=0) std::cerr << "spotrf_ error " << info << "\n";
00256 }
00257
00258 static inline void partial_lu_decomp(const gene_matrix & X, gene_matrix & C, int N){
00259 int N2 = N*N;
00260 scopy_(&N2, X, &intone, C, &intone);
00261 char uplo = 'L';
00262 int info = 0;
00263 int * ipiv = (int*)alloca(sizeof(int)*N);
00264 sgetrf_(&N, &N, C, &N, ipiv, &info);
00265 if(info!=0) std::cerr << "sgetrf_ error " << info << "\n";
00266 }
00267
00268 #ifdef HAS_LAPACK
00269
00270 static inline void lu_decomp(const gene_matrix & X, gene_matrix & C, int N){
00271 int N2 = N*N;
00272 scopy_(&N2, X, &intone, C, &intone);
00273 char uplo = 'L';
00274 int info = 0;
00275 int * ipiv = (int*)alloca(sizeof(int)*N);
00276 int * jpiv = (int*)alloca(sizeof(int)*N);
00277 sgetc2_(&N, C, &N, ipiv, jpiv, &info);
00278 }
00279
00280
00281
00282 static inline void hessenberg(const gene_matrix & X, gene_matrix & C, int N){
00283 #ifdef PUREBLAS
00284 {
00285 int N2 = N*N;
00286 int inc = 1;
00287 scopy_(&N2, X, &inc, C, &inc);
00288 }
00289 #else
00290 cblas_scopy(N*N, X, 1, C, 1);
00291 #endif
00292 int info = 0;
00293 int ilo = 1;
00294 int ihi = N;
00295 int bsize = 64;
00296 int worksize = N*bsize;
00297 float* d = new float[N+worksize];
00298 sgehrd_(&N, &ilo, &ihi, C, &N, d, d+N, &worksize, &info);
00299 delete[] d;
00300 }
00301
00302 static inline void tridiagonalization(const gene_matrix & X, gene_matrix & C, int N){
00303 #ifdef PUREBLAS
00304 {
00305 int N2 = N*N;
00306 int inc = 1;
00307 scopy_(&N2, X, &inc, C, &inc);
00308 }
00309 #else
00310 cblas_scopy(N*N, X, 1, C, 1);
00311 #endif
00312 char uplo = 'U';
00313 int info = 0;
00314 int ilo = 1;
00315 int ihi = N;
00316 int bsize = 64;
00317 int worksize = N*bsize;
00318 float* d = new float[3*N+worksize];
00319 ssytrd_(&uplo, &N, C, &N, d, d+N, d+2*N, d+3*N, &worksize, &info);
00320 delete[] d;
00321 }
00322 #endif
00323
00324 static inline void trisolve_lower(const gene_matrix & L, const gene_vector& B, gene_vector & X, int N){
00325 #ifdef PUREBLAS
00326 scopy_(&N, B, &intone, X, &intone);
00327 strsv_(&lower, ¬rans, &nonunit, &N, L, &N, X, &intone);
00328 #else
00329 cblas_scopy(N, B, 1, X, 1);
00330 cblas_strsv(CblasColMajor, CblasLower, CblasNoTrans, CblasNonUnit, N, L, N, X, 1);
00331 #endif
00332 }
00333
00334 static inline void trisolve_lower_matrix(const gene_matrix & L, const gene_matrix& B, gene_matrix & X, int N){
00335 #ifdef PUREBLAS
00336 scopy_(&N, B, &intone, X, &intone);
00337 strsm_(&right, &lower, ¬rans, &nonunit, &N, &N, &fone, L, &N, X, &N);
00338 #else
00339 cblas_scopy(N, B, 1, X, 1);
00340 cblas_strsm(CblasColMajor, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, N, N, 1, L, N, X, N);
00341 #endif
00342 }
00343
00344 static inline void trmm(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){
00345 #ifdef PUREBLAS
00346 strmm_(&left, &lower, ¬rans,&nonunit, &N,&N,&fone,A,&N,B,&N);
00347 #else
00348 cblas_strmm(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans,CblasNonUnit, N,N,1,A,N,B,N);
00349 #endif
00350 }
00351
00352 };
00353
00354
00355 #endif
00356
00357
00358