00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 #include <stdio.h>
00050 #include "lapacke.h"
00051 #include "lapacke_utils.h"
00052 #include "test_utils.h"
00053
00054 static void init_scalars_sgetri( lapack_int *n, lapack_int *lda,
00055 lapack_int *lwork );
00056 static void init_a( lapack_int size, float *a );
00057 static void init_ipiv( lapack_int size, lapack_int *ipiv );
00058 static void init_work( lapack_int size, float *work );
00059 static int compare_sgetri( float *a, float *a_i, lapack_int info,
00060 lapack_int info_i, lapack_int lda, lapack_int n );
00061
00062 int main(void)
00063 {
00064
00065 lapack_int n, n_i;
00066 lapack_int lda, lda_i;
00067 lapack_int lda_r;
00068 lapack_int lwork, lwork_i;
00069 lapack_int info, info_i;
00070 lapack_int i;
00071 int failed;
00072
00073
00074 float *a = NULL, *a_i = NULL;
00075 lapack_int *ipiv = NULL, *ipiv_i = NULL;
00076 float *work = NULL, *work_i = NULL;
00077 float *a_save = NULL;
00078 float *a_r = NULL;
00079
00080
00081 init_scalars_sgetri( &n, &lda, &lwork );
00082 lda_r = n+2;
00083 n_i = n;
00084 lda_i = lda;
00085 lwork_i = lwork;
00086
00087
00088 a = (float *)LAPACKE_malloc( lda*n * sizeof(float) );
00089 ipiv = (lapack_int *)LAPACKE_malloc( n * sizeof(lapack_int) );
00090 work = (float *)LAPACKE_malloc( lwork * sizeof(float) );
00091
00092
00093 a_i = (float *)LAPACKE_malloc( lda*n * sizeof(float) );
00094 ipiv_i = (lapack_int *)LAPACKE_malloc( n * sizeof(lapack_int) );
00095 work_i = (float *)LAPACKE_malloc( lwork * sizeof(float) );
00096
00097
00098 a_save = (float *)LAPACKE_malloc( lda*n * sizeof(float) );
00099
00100
00101 a_r = (float *)LAPACKE_malloc( n*(n+2) * sizeof(float) );
00102
00103
00104 init_a( lda*n, a );
00105 init_ipiv( n, ipiv );
00106 init_work( lwork, work );
00107
00108
00109 for( i = 0; i < lda*n; i++ ) {
00110 a_save[i] = a[i];
00111 }
00112
00113
00114 sgetri_( &n, a, &lda, ipiv, work, &lwork, &info );
00115
00116
00117
00118 for( i = 0; i < lda*n; i++ ) {
00119 a_i[i] = a_save[i];
00120 }
00121 for( i = 0; i < n; i++ ) {
00122 ipiv_i[i] = ipiv[i];
00123 }
00124 for( i = 0; i < lwork; i++ ) {
00125 work_i[i] = work[i];
00126 }
00127 info_i = LAPACKE_sgetri_work( LAPACK_COL_MAJOR, n_i, a_i, lda_i, ipiv_i,
00128 work_i, lwork_i );
00129
00130 failed = compare_sgetri( a, a_i, info, info_i, lda, n );
00131 if( failed == 0 ) {
00132 printf( "PASSED: column-major middle-level interface to sgetri\n" );
00133 } else {
00134 printf( "FAILED: column-major middle-level interface to sgetri\n" );
00135 }
00136
00137
00138
00139 for( i = 0; i < lda*n; i++ ) {
00140 a_i[i] = a_save[i];
00141 }
00142 for( i = 0; i < n; i++ ) {
00143 ipiv_i[i] = ipiv[i];
00144 }
00145 for( i = 0; i < lwork; i++ ) {
00146 work_i[i] = work[i];
00147 }
00148 info_i = LAPACKE_sgetri( LAPACK_COL_MAJOR, n_i, a_i, lda_i, ipiv_i );
00149
00150 failed = compare_sgetri( a, a_i, info, info_i, lda, n );
00151 if( failed == 0 ) {
00152 printf( "PASSED: column-major high-level interface to sgetri\n" );
00153 } else {
00154 printf( "FAILED: column-major high-level interface to sgetri\n" );
00155 }
00156
00157
00158
00159 for( i = 0; i < lda*n; i++ ) {
00160 a_i[i] = a_save[i];
00161 }
00162 for( i = 0; i < n; i++ ) {
00163 ipiv_i[i] = ipiv[i];
00164 }
00165 for( i = 0; i < lwork; i++ ) {
00166 work_i[i] = work[i];
00167 }
00168
00169 LAPACKE_sge_trans( LAPACK_COL_MAJOR, n, n, a_i, lda, a_r, n+2 );
00170 info_i = LAPACKE_sgetri_work( LAPACK_ROW_MAJOR, n_i, a_r, lda_r, ipiv_i,
00171 work_i, lwork_i );
00172
00173 LAPACKE_sge_trans( LAPACK_ROW_MAJOR, n, n, a_r, n+2, a_i, lda );
00174
00175 failed = compare_sgetri( a, a_i, info, info_i, lda, n );
00176 if( failed == 0 ) {
00177 printf( "PASSED: row-major middle-level interface to sgetri\n" );
00178 } else {
00179 printf( "FAILED: row-major middle-level interface to sgetri\n" );
00180 }
00181
00182
00183
00184 for( i = 0; i < lda*n; i++ ) {
00185 a_i[i] = a_save[i];
00186 }
00187 for( i = 0; i < n; i++ ) {
00188 ipiv_i[i] = ipiv[i];
00189 }
00190 for( i = 0; i < lwork; i++ ) {
00191 work_i[i] = work[i];
00192 }
00193
00194
00195 LAPACKE_sge_trans( LAPACK_COL_MAJOR, n, n, a_i, lda, a_r, n+2 );
00196 info_i = LAPACKE_sgetri( LAPACK_ROW_MAJOR, n_i, a_r, lda_r, ipiv_i );
00197
00198 LAPACKE_sge_trans( LAPACK_ROW_MAJOR, n, n, a_r, n+2, a_i, lda );
00199
00200 failed = compare_sgetri( a, a_i, info, info_i, lda, n );
00201 if( failed == 0 ) {
00202 printf( "PASSED: row-major high-level interface to sgetri\n" );
00203 } else {
00204 printf( "FAILED: row-major high-level interface to sgetri\n" );
00205 }
00206
00207
00208 if( a != NULL ) {
00209 LAPACKE_free( a );
00210 }
00211 if( a_i != NULL ) {
00212 LAPACKE_free( a_i );
00213 }
00214 if( a_r != NULL ) {
00215 LAPACKE_free( a_r );
00216 }
00217 if( a_save != NULL ) {
00218 LAPACKE_free( a_save );
00219 }
00220 if( ipiv != NULL ) {
00221 LAPACKE_free( ipiv );
00222 }
00223 if( ipiv_i != NULL ) {
00224 LAPACKE_free( ipiv_i );
00225 }
00226 if( work != NULL ) {
00227 LAPACKE_free( work );
00228 }
00229 if( work_i != NULL ) {
00230 LAPACKE_free( work_i );
00231 }
00232
00233 return 0;
00234 }
00235
00236
00237 static void init_scalars_sgetri( lapack_int *n, lapack_int *lda,
00238 lapack_int *lwork )
00239 {
00240 *n = 4;
00241 *lda = 8;
00242 *lwork = 512;
00243
00244 return;
00245 }
00246
00247
00248 static void init_a( lapack_int size, float *a ) {
00249 lapack_int i;
00250 for( i = 0; i < size; i++ ) {
00251 a[i] = 0;
00252 }
00253 a[0] = 5.250000000e+000;
00254 a[8] = -2.950000048e+000;
00255 a[16] = -9.499999881e-001;
00256 a[24] = -3.799999952e+000;
00257 a[1] = 3.428571522e-001;
00258 a[9] = 3.891428709e+000;
00259 a[17] = 2.375714302e+000;
00260 a[25] = 4.128571749e-001;
00261 a[2] = 3.009524047e-001;
00262 a[10] = -4.631179273e-001;
00263 a[18] = -1.513859272e+000;
00264 a[26] = 2.948207855e-001;
00265 a[3] = -2.114285827e-001;
00266 a[11] = -3.298825026e-001;
00267 a[19] = 4.723378923e-003;
00268 a[27] = 1.313732415e-001;
00269 }
00270 static void init_ipiv( lapack_int size, lapack_int *ipiv ) {
00271 lapack_int i;
00272 for( i = 0; i < size; i++ ) {
00273 ipiv[i] = 0;
00274 }
00275 ipiv[0] = 2;
00276 ipiv[1] = 2;
00277 ipiv[2] = 3;
00278 ipiv[3] = 4;
00279 }
00280 static void init_work( lapack_int size, float *work ) {
00281 lapack_int i;
00282 for( i = 0; i < size; i++ ) {
00283 work[i] = 0;
00284 }
00285 }
00286
00287
00288
00289 static int compare_sgetri( float *a, float *a_i, lapack_int info,
00290 lapack_int info_i, lapack_int lda, lapack_int n )
00291 {
00292 lapack_int i;
00293 int failed = 0;
00294 for( i = 0; i < lda*n; i++ ) {
00295 failed += compare_floats(a[i],a_i[i]);
00296 }
00297 failed += (info == info_i) ? 0 : 1;
00298 if( info != 0 || info_i != 0 ) {
00299 printf( "info=%d, info_i=%d\n",(int)info,(int)info_i );
00300 }
00301
00302 return failed;
00303 }