00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifdef __cplusplus
00024 extern "C" {
00025 #endif
00026
00027
00028
00029 #include <stdlib.h>
00030 #include <string.h>
00031 #include <assert.h>
00032 #include <stdio.h>
00033
00034 #include "decision-tree.h"
00035
00036
00037
00038
00039
/* Number of coordinate levels in the tree, in the order dtAdd/dtGet pack
   them: action, current state, next state, observation. */
#define DT_TABLE_DEPTH (4)

/* Coordinate value that, when passed to dtAdd, applies the value to every
   slot at that level. */
#define WILDCARD_SPEC (-1)

/* Tags stored in DTNode.type. */
enum {
  DT_VAL = 0,   /* leaf: data.val holds the value */
  DT_TABLE = 1  /* interior node: data.subTree holds a table of children */
};
00047
00048
00049
00050
00051
00052 struct DTNodeStruct;
00053 struct DTTableStruct;
00054
00055 struct DTTableStruct {
00056 int numEntries;
00057 struct DTNodeStruct** entries;
00058 struct DTNodeStruct* defaultEntry;
00059 };
00060
00061 struct DTNodeStruct {
00062 int type;
00063 union {
00064 struct DTTableStruct subTree;
00065 REAL_VALUE val;
00066 } data;
00067 };
00068
00069 typedef struct DTNodeStruct DTNode;
00070 typedef struct DTTableStruct DTTable;
00071
00072
00073
00074
00075
00076 static DTNode* dtNewNodeVal(REAL_VALUE val);
00077 static DTNode* dtNewNodeTable(int numEntries);
00078 static void dtInitTable(DTTable* t, int numEntries);
00079 static void dtDestroyNode(DTNode* n);
00080 static void dtDestroyTable(DTTable* t);
00081 static DTNode* dtDeepCopyNode(const DTNode* in);
00082 static void dtDeepCopyTable(DTTable* out, const DTTable* in);
00083 static DTNode* dtConvertToTable(DTNode* in, int numEntries);
00084 static void dtSpaces(int indent);
00085 static void dtDebugPrintNode(DTNode* n, int indent);
00086 static void dtDebugPrintTable(DTTable* t, int indent);
00087
00088
00089
00090
00091
00092 static int* gTableSizes = NULL;
00093 static DTNode* gTree = NULL;
00094
00095
00096
00097
00098
00099 static DTNode* dtNewNodeVal(REAL_VALUE val)
00100 {
00101 DTNode* out;
00102
00103 out = (DTNode*) malloc(sizeof(DTNode));
00104 checkAllocatedPointer((void *)out);
00105
00106 out->type = DT_VAL;
00107 out->data.val = val;
00108
00109 return out;
00110 }
00111
00112 static DTNode* dtNewNodeTable(int numEntries)
00113 {
00114 DTNode* out;
00115
00116 out = (DTNode*) malloc(sizeof(DTNode));
00117 checkAllocatedPointer((void *)out);
00118
00119 out->type = DT_TABLE;
00120 dtInitTable(&out->data.subTree, numEntries);
00121
00122 return out;
00123 }
00124
00125 static void dtInitTable(DTTable* t, int numEntries)
00126 {
00127 t->numEntries = numEntries;
00128 t->entries = (DTNode**) malloc(numEntries * sizeof(DTNode*));
00129 checkAllocatedPointer((void *)t->entries );
00130 memset(t->entries, 0, numEntries * sizeof(DTNode*));
00131 t->defaultEntry = NULL;
00132 }
00133
00134 static void dtDestroyNode(DTNode* n)
00135 {
00136 if (NULL == n) return;
00137
00138 switch (n->type) {
00139 case DT_VAL:
00140
00141 break;
00142 case DT_TABLE:
00143 dtDestroyTable(&n->data.subTree);
00144 break;
00145 default:
00146 assert(0 );
00147 }
00148
00149 free(n);
00150 }
00151
00152 static void dtDestroyTable(DTTable* t)
00153 {
00154 int i;
00155
00156 for (i=0; i < t->numEntries; i++) {
00157 dtDestroyNode(t->entries[i]);
00158 }
00159 dtDestroyNode(t->defaultEntry);
00160 free(t->entries);
00161 t->entries = NULL;
00162 }
00163
00164 static DTNode* dtDeepCopyNode(const DTNode* in)
00165 {
00166 DTNode* out;
00167
00168 if (NULL == in) {
00169 out = NULL;
00170 } else {
00171 switch (in->type) {
00172 case DT_VAL:
00173 out = dtNewNodeVal(in->data.val);
00174 break;
00175 case DT_TABLE:
00176 out = dtNewNodeTable(in->data.subTree.numEntries);
00177 dtDeepCopyTable(&out->data.subTree, &in->data.subTree);
00178 break;
00179 default:
00180 assert(0 );
00181 }
00182 }
00183
00184 return out;
00185 }
00186
00187 static void dtDeepCopyTable(DTTable* out, const DTTable* in)
00188 {
00189 int i;
00190
00191 dtInitTable(out, in->numEntries);
00192 out->defaultEntry = dtDeepCopyNode(in->defaultEntry);
00193 for (i=0; i < in->numEntries; i++) {
00194 if (NULL != in->entries[i]) {
00195 out->entries[i] = dtDeepCopyNode(in->entries[i]);
00196 }
00197 }
00198 }
00199
00200 static DTNode* dtConvertToTable(DTNode* in, int numEntries)
00201 {
00202 DTNode* out;
00203
00204 assert(NULL != in);
00205
00206 switch (in->type) {
00207 case DT_VAL:
00208 out = dtNewNodeTable(numEntries);
00209 out->data.subTree.defaultEntry = dtNewNodeVal(in->data.val);
00210 dtDestroyNode(in);
00211 break;
00212 case DT_TABLE:
00213 out = in;
00214 break;
00215 default:
00216 assert(0 );
00217 }
00218
00219 return out;
00220 }
00221
00222 DTNode* dtAddInternal(DTNode* node, int* vec, int index, REAL_VALUE val)
00223 {
00224 int i;
00225 int allWildcards;
00226 DTNode** entryP;
00227
00228
00229
00230
00231 allWildcards = 1;
00232 for (i = index; i < DT_TABLE_DEPTH; i++) {
00233 if (vec[i] != WILDCARD_SPEC) {
00234 allWildcards = 0;
00235 break;
00236 }
00237 }
00238
00239 if (allWildcards) {
00240
00241
00242 dtDestroyNode(node);
00243 node = dtNewNodeVal(val);
00244 } else if (WILDCARD_SPEC == vec[index]) {
00245
00246
00247
00248 node = dtConvertToTable(node, gTableSizes[index]);
00249 node->data.subTree.defaultEntry =
00250 dtAddInternal(node->data.subTree.defaultEntry, vec, index+1, val);
00251 for (i = 0; i < gTableSizes[index]; i++) {
00252 if (NULL != node->data.subTree.entries[i]) {
00253 node->data.subTree.entries[i] =
00254 dtAddInternal(node->data.subTree.entries[i], vec, index+1, val);
00255 }
00256 }
00257 } else {
00258
00259
00260 node = dtConvertToTable(node, gTableSizes[index]);
00261 entryP = &node->data.subTree.entries[vec[index]];
00262 if (NULL == *entryP) {
00263
00264
00265 *entryP = dtDeepCopyNode(node->data.subTree.defaultEntry);
00266 }
00267 *entryP = dtAddInternal(*entryP, vec, index+1, val);
00268 }
00269
00270 return node;
00271 }
00272
00273 static REAL_VALUE dtGetInternal(DTNode* node, int* vec, int index)
00274 {
00275 DTNode* entry;
00276
00277 assert(NULL != node);
00278
00279 switch (node->type) {
00280 case DT_VAL:
00281 return node->data.val;
00282 case DT_TABLE:
00283 entry = node->data.subTree.entries[vec[index]];
00284 if (NULL == entry) {
00285 entry = node->data.subTree.defaultEntry;
00286 }
00287 return dtGetInternal(entry, vec, index+1);
00288 default:
00289 assert(0 );
00290 }
00291 printf("Code bug\n");
00292 exit(EXIT_FAILURE);
00293 }
00294
/* Write `indent` space characters to stdout; nothing when indent <= 0. */
static void dtSpaces(int indent)
{
  int remaining;

  for (remaining = indent; remaining > 0; --remaining) {
    putchar(' ');
  }
}
00303
00304 static void dtDebugPrintNode(DTNode* n, int indent)
00305 {
00306 if (NULL == n) {
00307 dtSpaces(indent);
00308 printf("(NULL)\n");
00309 return;
00310 }
00311
00312 switch (n->type) {
00313 case DT_VAL:
00314 dtSpaces(indent);
00315 printf("val = %lf\n", n->data.val);
00316 break;
00317 case DT_TABLE:
00318 dtDebugPrintTable(&n->data.subTree, indent);
00319 break;
00320 default:
00321 assert(0 );
00322 }
00323 }
00324
00325 static void dtDebugPrintTable(DTTable* t, int indent)
00326 {
00327 int i;
00328
00329 dtSpaces(indent);
00330 printf("table:\n");
00331 dtSpaces(indent+2);
00332 printf("default:\n");
00333 dtDebugPrintNode(t->defaultEntry, indent+4);
00334 for (i=0; i < t->numEntries; i++) {
00335 dtSpaces(indent+2);
00336 if (NULL == t->entries[i]) {
00337 printf("entry %d: (default)\n", i);
00338 } else {
00339 printf("entry %d:\n", i);
00340 dtDebugPrintNode(t->entries[i], indent+4);
00341 }
00342 };
00343 }
00344
00345
00346
00347
00348
00349 void dtInit(int numActions, int numStates, int numObservations)
00350 {
00351
00352 if (NULL != gTree) return;
00353
00354 gTableSizes = (int*) malloc(DT_TABLE_DEPTH*sizeof(int));
00355 checkAllocatedPointer((void *)gTableSizes );
00356
00357 gTableSizes[0] = numActions;
00358 gTableSizes[1] = numStates;
00359 gTableSizes[2] = numStates;
00360 gTableSizes[3] = numObservations;
00361
00362 gTree = dtNewNodeVal(0);
00363 }
00364
00365 void dtAdd(int action, int cur_state, int next_state, int obs, REAL_VALUE val)
00366 {
00367 int vec[DT_TABLE_DEPTH];
00368 vec[0] = action;
00369 vec[1] = cur_state;
00370 vec[2] = next_state;
00371 vec[3] = obs;
00372
00373 gTree = dtAddInternal(gTree, vec, 0, val);
00374 }
00375
00376 REAL_VALUE dtGet(int action, int cur_state, int next_state, int obs)
00377 {
00378 int vec[DT_TABLE_DEPTH];
00379 vec[0] = action;
00380 vec[1] = cur_state;
00381 vec[2] = next_state;
00382 vec[3] = obs;
00383
00384 return dtGetInternal(gTree, vec, 0);
00385 }
00386
00387 void dtDeallocate(void)
00388 {
00389 dtDestroyNode(gTree);
00390 gTree = NULL;
00391 free(gTableSizes);
00392 gTableSizes = NULL;
00393 }
00394
00395 void dtDebugPrint(const char* header)
00396 {
00397 printf("%s\n", header);
00398 dtDebugPrintNode(gTree, 2);
00399 }
00400
00401
00402 #ifdef __cplusplus
00403 }
00404 #endif
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468
00469
00470
00471