00001 #include "network.h" 00002 #include "utils.h" 00003 #include "parser.h" 00004 00005 char *dice_labels[] = {"face1","face2","face3","face4","face5","face6"}; 00006 00007 void train_dice(char *cfgfile, char *weightfile) 00008 { 00009 srand(time(0)); 00010 float avg_loss = -1; 00011 char *base = basecfg(cfgfile); 00012 char *backup_directory = "/home/pjreddie/backup/"; 00013 printf("%s\n", base); 00014 network net = parse_network_cfg(cfgfile); 00015 if(weightfile){ 00016 load_weights(&net, weightfile); 00017 } 00018 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); 00019 int imgs = 1024; 00020 int i = *net.seen/imgs; 00021 char **labels = dice_labels; 00022 list *plist = get_paths("data/dice/dice.train.list"); 00023 char **paths = (char **)list_to_array(plist); 00024 printf("%d\n", plist->size); 00025 clock_t time; 00026 while(1){ 00027 ++i; 00028 time=clock(); 00029 data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h); 00030 printf("Loaded: %lf seconds\n", sec(clock()-time)); 00031 00032 time=clock(); 00033 float loss = train_network(net, train); 00034 if(avg_loss == -1) avg_loss = loss; 00035 avg_loss = avg_loss*.9 + loss*.1; 00036 printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen); 00037 free_data(train); 00038 if((i % 100) == 0) net.learning_rate *= .1; 00039 if(i%100==0){ 00040 char buff[256]; 00041 sprintf(buff, "%s/%s_%d.weights",backup_directory,base, i); 00042 save_weights(net, buff); 00043 } 00044 } 00045 } 00046 00047 void validate_dice(char *filename, char *weightfile) 00048 { 00049 network net = parse_network_cfg(filename); 00050 if(weightfile){ 00051 load_weights(&net, weightfile); 00052 } 00053 srand(time(0)); 00054 00055 char **labels = dice_labels; 00056 list *plist = get_paths("data/dice/dice.val.list"); 00057 00058 char **paths = (char **)list_to_array(plist); 00059 int m = plist->size; 00060 free_list(plist); 00061 00062 data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h); 00063 float *acc = network_accuracies(net, val, 2); 00064 printf("Validation Accuracy: %f, %d images\n", acc[0], m); 00065 free_data(val); 00066 } 00067 00068 void test_dice(char *cfgfile, char *weightfile, char *filename) 00069 { 00070 network net = parse_network_cfg(cfgfile); 00071 if(weightfile){ 00072 load_weights(&net, weightfile); 00073 } 00074 set_batch_network(&net, 1); 00075 srand(2222222); 00076 int i = 0; 00077 char **names = dice_labels; 00078 char buff[256]; 00079 char *input = buff; 00080 int indexes[6]; 00081 while(1){ 00082 if(filename){ 00083 strncpy(input, filename, 256); 00084 }else{ 00085 printf("Enter Image Path: "); 00086 fflush(stdout); 00087 input = fgets(input, 256, stdin); 00088 if(!input) return; 00089 strtok(input, "\n"); 00090 } 00091 image im = load_image_color(input, net.w, net.h); 00092 float *X = im.data; 00093 float *predictions = network_predict(net, X); 00094 top_predictions(net, 6, indexes); 00095 for(i = 0; i < 6; ++i){ 00096 int index = indexes[i]; 00097 printf("%s: %f\n", names[index], predictions[index]); 00098 } 00099 free_image(im); 00100 if (filename) break; 00101 } 00102 } 00103 00104 void run_dice(int argc, char **argv) 00105 { 00106 if(argc < 4){ 00107 fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); 00108 return; 00109 } 00110 00111 char *cfg = argv[3]; 00112 char *weights = (argc > 4) ? argv[4] : 0; 00113 char *filename = (argc > 5) ? argv[5]: 0; 00114 if(0==strcmp(argv[2], "test")) test_dice(cfg, weights, filename); 00115 else if(0==strcmp(argv[2], "train")) train_dice(cfg, weights); 00116 else if(0==strcmp(argv[2], "valid")) validate_dice(cfg, weights); 00117 } 00118