00001 #include "mini_blas.h" 00002 00003 void cpu_gemm_nn(int TA, int TB, int M, int N, int K, float ALPHA, 00004 float *A, int lda, 00005 float *B, int ldb, 00006 float BETA, 00007 float *C, int ldc) 00008 { 00009 int i,j,k; 00010 for(i = 0; i < M; ++i){ 00011 for(k = 0; k < K; ++k){ 00012 register float A_PART = ALPHA*A[i*lda+k]; 00013 for(j = 0; j < N; ++j){ 00014 C[i*ldc+j] += A_PART*B[k*ldb+j]; 00015 } 00016 } 00017 } 00018 } 00019 00020 void cpu_gemm_nt(int TA, int TB, int M, int N, int K, float ALPHA, 00021 float *A, int lda, 00022 float *B, int ldb, 00023 float BETA, 00024 float *C, int ldc) 00025 { 00026 int i,j,k; 00027 for(i = 0; i < M; ++i){ 00028 for(j = 0; j < N; ++j){ 00029 register float sum = 0; 00030 for(k = 0; k < K; ++k){ 00031 sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; 00032 } 00033 C[i*ldc+j] += sum; 00034 } 00035 } 00036 } 00037 00038 void cpu_gemm_tn(int TA, int TB, int M, int N, int K, float ALPHA, 00039 float *A, int lda, 00040 float *B, int ldb, 00041 float BETA, 00042 float *C, int ldc) 00043 { 00044 int i,j,k; 00045 for(i = 0; i < M; ++i){ 00046 for(k = 0; k < K; ++k){ 00047 register float A_PART = ALPHA*A[k*lda+i]; 00048 for(j = 0; j < N; ++j){ 00049 C[i*ldc+j] += A_PART*B[k*ldb+j]; 00050 } 00051 } 00052 } 00053 } 00054 void cpu_gemm_tt(int TA, int TB, int M, int N, int K, float ALPHA, 00055 float *A, int lda, 00056 float *B, int ldb, 00057 float BETA, 00058 float *C, int ldc) 00059 { 00060 int i,j,k; 00061 for(i = 0; i < M; ++i){ 00062 for(j = 0; j < N; ++j){ 00063 for(k = 0; k < K; ++k){ 00064 C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb]; 00065 } 00066 } 00067 } 00068 } 00069 00070 00071 void cpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA, 00072 float *A, int lda, 00073 float *B, int ldb, 00074 float BETA, 00075 float *C, int ldc) 00076 { 00077 int i, j; 00078 for(i = 0; i < M; ++i){ 00079 for(j = 0; j < N; ++j){ 00080 C[i*ldc + j] *= BETA; 00081 } 00082 } 00083 if(!TA && !TB) 00084 cpu_gemm_nn( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); 00085 else if(TA && !TB) 00086 cpu_gemm_tn( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); 00087 else if(!TA && TB) 00088 cpu_gemm_nt( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); 00089 else 00090 cpu_gemm_tt( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); 00091 }