#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <gsl/gsl_matrix.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_linalg.h>

#include "egsl.h"
#include "egsl_imp.h"
/* Capacity of the context stack (egsl_contexts has this many entries). */
#define MAX_CONTEXTS 1024
/* Capacity of each context's matrix pool (slots in egsl_context.vars). */
#define MAX_VALS 1024
00013
00014
00015
00016
00017 struct egsl_variable {
00018 gsl_matrix * gsl_m;
00019 };
00020
00021 struct egsl_context {
00022 char name[256];
00023 int nallocated;
00024 int nvars;
00025 struct egsl_variable vars[MAX_VALS];
00026 };
00027
00028
00029 int cid=0;
00030
00031 int max_cid = 0;
00032 static struct egsl_context egsl_contexts[MAX_CONTEXTS];
00033
00034
00035 int egsl_first_time = 1;
00036
00037 int egsl_total_allocations = 0;
00038 int egsl_cache_hits = 0;
00039
/*
 * Fatal-error handler for the egsl layer: dumps allocation statistics and
 * terminates the program. Callers rely on this never returning (they go on
 * to index arrays with the value they just rejected).
 */
void egsl_error(void) {
	egsl_print_stats();

	/* Keeps the original debugger-friendly failure in debug builds... */
	assert(0);

	/* ...but assert() expands to nothing under NDEBUG, so stop explicitly. */
	abort();
}
00047
00048 val assemble_val(int cid, int index, gsl_matrix*m) {
00049 val v;
00050 v.cid=cid;
00051 v.index=index;
00052 v.gslm = m;
00053 return v;
00054 }
00055
00056 int its_context(val v) {
00057 return v.cid;
00058 }
00059
00060 int its_var_index(val v) {
00061 return v.index;
00062 }
00063
00064
00065 #if 0
00066 inline void check_valid_val(val v) { int i = v.cid; v.cid=i;}
00067 #else
00068 void check_valid_val(val v) {
00069 int context = its_context(v);
00070 if(context>cid) {
00071 fprintf(stderr, "Val is from invalid context (%d>%d)\n",context,cid);
00072 egsl_error();
00073 }
00074 int var_index = its_var_index(v);
00075 if(var_index >= egsl_contexts[context].nvars) {
00076 fprintf(stderr, "Val is invalid (%d>%d)\n",var_index,
00077 egsl_contexts[context].nvars);
00078 egsl_error();
00079 }
00080 }
00081 #endif
00082
00083 gsl_matrix * egsl_gslm(val v) {
00084 check_valid_val(v);
00085 return v.gslm;
00086 }
00087
/* Enters a new context labelled "unnamed context". */
void egsl_push() {
	egsl_push_named("unnamed context");
}

/* Leaves the current context, expecting it to be labelled "unnamed context". */
void egsl_pop() {
	egsl_pop_named("unnamed context");
}
00090
00091 void egsl_push_named(const char*name) {
00092 if(egsl_first_time) {
00093 int c;
00094 for(c=0;c<MAX_CONTEXTS;c++) {
00095 egsl_contexts[c].nallocated = 0;
00096 egsl_contexts[c].nvars = 0;
00097 sprintf(egsl_contexts[c].name, "not yet used");
00098 }
00099 egsl_first_time = 0;
00100 }
00101 cid++;
00102
00103 if(cid >= MAX_CONTEXTS) {
00104 fprintf(stderr, "egsl: maximum number of contexts reached \n");
00105 egsl_print_stats();
00106 assert(0);
00107 }
00108
00109 if(max_cid < cid) max_cid = cid;
00110
00111 if(name != 0)
00112 sprintf(egsl_contexts[cid].name, "%s", name);
00113 else
00114 sprintf(egsl_contexts[cid].name, "Unnamed context");
00115 }
00116
00117 void egsl_pop_named(const char*name) {
00118 assert(cid>=0);
00119 if(name != 0) {
00120 if(strcmp(name, egsl_contexts[cid].name)) {
00121 fprintf(stderr, "egsl: context mismatch. You want to pop '%s', you are still at '%s'\n",
00122 name, egsl_contexts[cid].name);
00123 egsl_print_stats();
00124 assert(0);
00125 }
00126 }
00127
00128 egsl_contexts[cid].nvars = 0;
00129 sprintf(egsl_contexts[cid].name, "Popped");
00130 cid--;
00131 }
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147 void egsl_print_stats() {
00148 fprintf(stderr, "egsl: total allocations: %d cache hits: %d\n",
00149 egsl_total_allocations, egsl_cache_hits);
00150
00151 int c; for(c=0;c<=max_cid&&c<MAX_CONTEXTS;c++) {
00152
00153
00154 fprintf(stderr, "egsl: context #%d allocations: %d active: %d name: '%s' \n",
00155 c, egsl_contexts[c].nallocated, egsl_contexts[c].nvars, egsl_contexts[c].name);
00156 }
00157 }
00158
00159 val egsl_alloc(size_t rows, size_t columns) {
00160 struct egsl_context*c = egsl_contexts+cid;
00161
00162
00163
00164
00165 if(c->nvars>=MAX_VALS) {
00166 fprintf(stderr,"Limit reached, in context %d, nvars is %d\n",cid,c->nvars);
00167 egsl_error();
00168 }
00169 int index = c->nvars;
00170 if(index<c->nallocated) {
00171 gsl_matrix*m = c->vars[index].gsl_m;
00172 if(m->size1 == rows && m->size2 == columns) {
00173 egsl_cache_hits++;
00174 c->nvars++;
00175 return assemble_val(cid,index,c->vars[index].gsl_m);
00176 } else {
00177 gsl_matrix_free(m);
00178 egsl_total_allocations++;
00179 c->vars[index].gsl_m = gsl_matrix_alloc((size_t)rows,(size_t)columns);
00180 c->nvars++;
00181 return assemble_val(cid,index,c->vars[index].gsl_m);
00182 }
00183 } else {
00184 egsl_total_allocations++;
00185 c->vars[index].gsl_m = gsl_matrix_alloc((size_t)rows,(size_t)columns);
00186 c->nvars++;
00187 c->nallocated++;
00188 return assemble_val(cid,index,c->vars[index].gsl_m);
00189 }
00190 }
00191
00192 val egsl_alloc_in_context(int context, size_t rows, size_t columns) {
00193 struct egsl_context*c = egsl_contexts+context;
00194
00195 if(c->nvars>=MAX_VALS) {
00196 fprintf(stderr,"Limit reached, in context %d, nvars is %d\n",context,c->nvars);
00197 egsl_error();
00198 }
00199 int index = c->nvars;
00200 if(index<c->nallocated) {
00201 gsl_matrix*m = c->vars[index].gsl_m;
00202 if(m->size1 == rows && m->size2 == columns) {
00203 egsl_cache_hits++;
00204 c->nvars++;
00205 return assemble_val(context,index,c->vars[index].gsl_m);
00206 } else {
00207 gsl_matrix_free(m);
00208 egsl_total_allocations++;
00209 c->vars[index].gsl_m = gsl_matrix_alloc((size_t)rows,(size_t)columns);
00210 c->nvars++;
00211 return assemble_val(context,index,c->vars[index].gsl_m);
00212 }
00213 } else {
00214 egsl_total_allocations++;
00215 c->vars[index].gsl_m = gsl_matrix_alloc((size_t)rows,(size_t)columns);
00216 c->nvars++;
00217 c->nallocated++;
00218 return assemble_val(context,index,c->vars[index].gsl_m);
00219 }
00220 }
00221
00223 val egsl_promote(val v) {
00224 if(cid==0) {
00225 egsl_error();
00226 }
00227
00228 gsl_matrix * m = egsl_gslm(v);
00229 val v2 = egsl_alloc_in_context(cid-1, m->size1, m->size2);
00230 gsl_matrix * m2 = egsl_gslm(v2);
00231 gsl_matrix_memcpy(m2, m);
00232 return v2;
00233 }
00234
00235
00236
00237
00238 void egsl_expect_size(val v, size_t rows, size_t cols) {
00239 gsl_matrix * m = egsl_gslm(v);
00240
00241 int bad = (rows && (m->size1!=rows)) || (cols && (m->size2!=cols));
00242 if(bad) {
00243 fprintf(stderr, "Matrix size is %d,%d while I expect %d,%d",
00244 (int)m->size1,(int)m->size2,(int)rows,(int)cols);
00245
00246 egsl_error();
00247 }
00248 }
00249
00250
00251 void egsl_print(const char*str, val v) {
00252 gsl_matrix * m = egsl_gslm(v);
00253 size_t i,j;
00254 int context = its_context(v);
00255 int var_index = its_var_index(v);
00256 fprintf(stderr, "%s = (%d x %d) context=%d index=%d\n",
00257 str,(int)m->size1,(int)m->size2, context, var_index);
00258
00259 for(i=0;i<m->size1;i++) {
00260 if(i==0)
00261 fprintf(stderr, " [ ");
00262 else
00263 fprintf(stderr, " ");
00264
00265 for(j=0;j<m->size2;j++)
00266 fprintf(stderr, "%f ", gsl_matrix_get(m,i,j));
00267
00268
00269 if(i==m->size1-1)
00270 fprintf(stderr, "] \n");
00271 else
00272 fprintf(stderr, "; \n");
00273 }
00274 }
00275
00276 double* egsl_atmp(val v, size_t i, size_t j) {
00277 gsl_matrix * m = egsl_gslm(v);
00278 return gsl_matrix_ptr(m,(size_t)i,(size_t)j);
00279 }
00280
00281
00282 double egsl_norm(val v1){
00283 egsl_expect_size(v1, 0, 1);
00284 double n=0;
00285 size_t i;
00286 gsl_matrix * m = egsl_gslm(v1);
00287 for(i=0;i<m->size1;i++) {
00288 double v = gsl_matrix_get(m,i,0);
00289 n += v * v;
00290 }
00291 return sqrt(n);
00292 }
00293
00294 double egsl_atv(val v1, size_t i){
00295 return *egsl_atmp(v1, i, 0);
00296 }
00297
00298 double egsl_atm(val v1, size_t i, size_t j){
00299 return *egsl_atmp(v1, i, j);
00300 }
00301
00302 void egsl_free(void){
00303 int c;
00304 for(c=0;c<=max_cid;c++) {
00305 for(int i=egsl_contexts[c].nvars; i<egsl_contexts[c].nallocated; i++){
00306 gsl_matrix_free(egsl_contexts[c].vars[i].gsl_m);
00307 }
00308 egsl_contexts[c].nallocated = egsl_contexts[c].nvars;
00309 }
00310 }
00311