00001 #include <stdio.h>
00002
00003 #include "network.h"
00004 #include "detection_layer.h"
00005 #include "cost_layer.h"
00006 #include "utils.h"
00007 #include "parser.h"
00008 #include "box.h"
00009
00010 void train_compare(char *cfgfile, char *weightfile)
00011 {
00012 srand(time(0));
00013 float avg_loss = -1;
00014 char *base = basecfg(cfgfile);
00015 char *backup_directory = "/home/pjreddie/backup/";
00016 printf("%s\n", base);
00017 network net = parse_network_cfg(cfgfile);
00018 if(weightfile){
00019 load_weights(&net, weightfile);
00020 }
00021 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00022 int imgs = 1024;
00023 list *plist = get_paths("data/compare.train.list");
00024 char **paths = (char **)list_to_array(plist);
00025 int N = plist->size;
00026 printf("%d\n", N);
00027 clock_t time;
00028 pthread_t load_thread;
00029 data train;
00030 data buffer;
00031
00032 load_args args = {0};
00033 args.w = net.w;
00034 args.h = net.h;
00035 args.paths = paths;
00036 args.classes = 20;
00037 args.n = imgs;
00038 args.m = N;
00039 args.d = &buffer;
00040 args.type = COMPARE_DATA;
00041
00042 load_thread = load_data_in_thread(args);
00043 int epoch = *net.seen/N;
00044 int i = 0;
00045 while(1){
00046 ++i;
00047 time=clock();
00048 pthread_join(load_thread, 0);
00049 train = buffer;
00050
00051 load_thread = load_data_in_thread(args);
00052 printf("Loaded: %lf seconds\n", sec(clock()-time));
00053 time=clock();
00054 float loss = train_network(net, train);
00055 if(avg_loss == -1) avg_loss = loss;
00056 avg_loss = avg_loss*.9 + loss*.1;
00057 printf("%.3f: %f, %f avg, %lf seconds, %d images\n", (float)*net.seen/N, loss, avg_loss, sec(clock()-time), *net.seen);
00058 free_data(train);
00059 if(i%100 == 0){
00060 char buff[256];
00061 sprintf(buff, "%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
00062 save_weights(net, buff);
00063 }
00064 if(*net.seen/N > epoch){
00065 epoch = *net.seen/N;
00066 i = 0;
00067 char buff[256];
00068 sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
00069 save_weights(net, buff);
00070 if(epoch%22 == 0) net.learning_rate *= .1;
00071 }
00072 }
00073 pthread_join(load_thread, 0);
00074 free_data(buffer);
00075 free_network(net);
00076 free_ptrs((void**)paths, plist->size);
00077 free_list(plist);
00078 free(base);
00079 }
00080
00081 void validate_compare(char *filename, char *weightfile)
00082 {
00083 int i = 0;
00084 network net = parse_network_cfg(filename);
00085 if(weightfile){
00086 load_weights(&net, weightfile);
00087 }
00088 srand(time(0));
00089
00090 list *plist = get_paths("data/compare.val.list");
00091
00092 char **paths = (char **)list_to_array(plist);
00093 int N = plist->size/2;
00094 free_list(plist);
00095
00096 clock_t time;
00097 int correct = 0;
00098 int total = 0;
00099 int splits = 10;
00100 int num = (i+1)*N/splits - i*N/splits;
00101
00102 data val, buffer;
00103
00104 load_args args = {0};
00105 args.w = net.w;
00106 args.h = net.h;
00107 args.paths = paths;
00108 args.classes = 20;
00109 args.n = num;
00110 args.m = 0;
00111 args.d = &buffer;
00112 args.type = COMPARE_DATA;
00113
00114 pthread_t load_thread = load_data_in_thread(args);
00115 for(i = 1; i <= splits; ++i){
00116 time=clock();
00117
00118 pthread_join(load_thread, 0);
00119 val = buffer;
00120
00121 num = (i+1)*N/splits - i*N/splits;
00122 char **part = paths+(i*N/splits);
00123 if(i != splits){
00124 args.paths = part;
00125 load_thread = load_data_in_thread(args);
00126 }
00127 printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
00128
00129 time=clock();
00130 matrix pred = network_predict_data(net, val);
00131 int j,k;
00132 for(j = 0; j < val.y.rows; ++j){
00133 for(k = 0; k < 20; ++k){
00134 if(val.y.vals[j][k*2] != val.y.vals[j][k*2+1]){
00135 ++total;
00136 if((val.y.vals[j][k*2] < val.y.vals[j][k*2+1]) == (pred.vals[j][k*2] < pred.vals[j][k*2+1])){
00137 ++correct;
00138 }
00139 }
00140 }
00141 }
00142 free_matrix(pred);
00143 printf("%d: Acc: %f, %lf seconds, %d images\n", i, (float)correct/total, sec(clock()-time), val.X.rows);
00144 free_data(val);
00145 }
00146 }
00147
00148 typedef struct {
00149 network net;
00150 char *filename;
00151 int class;
00152 int classes;
00153 float elo;
00154 float *elos;
00155 } sortable_bbox;
00156
00157 int total_compares = 0;
00158 int current_class = 0;
00159
00160 int elo_comparator(const void*a, const void *b)
00161 {
00162 sortable_bbox box1 = *(sortable_bbox*)a;
00163 sortable_bbox box2 = *(sortable_bbox*)b;
00164 if(box1.elos[current_class] == box2.elos[current_class]) return 0;
00165 if(box1.elos[current_class] > box2.elos[current_class]) return -1;
00166 return 1;
00167 }
00168
00169 int bbox_comparator(const void *a, const void *b)
00170 {
00171 ++total_compares;
00172 sortable_bbox box1 = *(sortable_bbox*)a;
00173 sortable_bbox box2 = *(sortable_bbox*)b;
00174 network net = box1.net;
00175 int class = box1.class;
00176
00177 image im1 = load_image_color(box1.filename, net.w, net.h);
00178 image im2 = load_image_color(box2.filename, net.w, net.h);
00179 float *X = calloc(net.w*net.h*net.c, sizeof(float));
00180 memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
00181 memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
00182 float *predictions = network_predict(net, X);
00183
00184 free_image(im1);
00185 free_image(im2);
00186 free(X);
00187 if (predictions[class*2] > predictions[class*2+1]){
00188 return 1;
00189 }
00190 return -1;
00191 }
00192
00193 void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result)
00194 {
00195 int k = 32;
00196 float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.));
00197 float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.));
00198 float SA = result ? 1 : 0;
00199 float SB = result ? 0 : 1;
00200 a->elos[class] += k*(SA - EA);
00201 b->elos[class] += k*(SB - EB);
00202 }
00203
00204 void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class)
00205 {
00206 image im1 = load_image_color(a->filename, net.w, net.h);
00207 image im2 = load_image_color(b->filename, net.w, net.h);
00208 float *X = calloc(net.w*net.h*net.c, sizeof(float));
00209 memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
00210 memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
00211 float *predictions = network_predict(net, X);
00212 ++total_compares;
00213
00214 int i;
00215 for(i = 0; i < classes; ++i){
00216 if(class < 0 || class == i){
00217 int result = predictions[i*2] > predictions[i*2+1];
00218 bbox_update(a, b, i, result);
00219 }
00220 }
00221
00222 free_image(im1);
00223 free_image(im2);
00224 free(X);
00225 }
00226
00227 void SortMaster3000(char *filename, char *weightfile)
00228 {
00229 int i = 0;
00230 network net = parse_network_cfg(filename);
00231 if(weightfile){
00232 load_weights(&net, weightfile);
00233 }
00234 srand(time(0));
00235 set_batch_network(&net, 1);
00236
00237 list *plist = get_paths("data/compare.sort.list");
00238
00239 char **paths = (char **)list_to_array(plist);
00240 int N = plist->size;
00241 free_list(plist);
00242 sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
00243 printf("Sorting %d boxes...\n", N);
00244 for(i = 0; i < N; ++i){
00245 boxes[i].filename = paths[i];
00246 boxes[i].net = net;
00247 boxes[i].class = 7;
00248 boxes[i].elo = 1500;
00249 }
00250 clock_t time=clock();
00251 qsort(boxes, N, sizeof(sortable_bbox), bbox_comparator);
00252 for(i = 0; i < N; ++i){
00253 printf("%s\n", boxes[i].filename);
00254 }
00255 printf("Sorted in %d compares, %f secs\n", total_compares, sec(clock()-time));
00256 }
00257
00258 void BattleRoyaleWithCheese(char *filename, char *weightfile)
00259 {
00260 int classes = 20;
00261 int i,j;
00262 network net = parse_network_cfg(filename);
00263 if(weightfile){
00264 load_weights(&net, weightfile);
00265 }
00266 srand(time(0));
00267 set_batch_network(&net, 1);
00268
00269 list *plist = get_paths("data/compare.sort.list");
00270
00271
00272
00273 char **paths = (char **)list_to_array(plist);
00274 int N = plist->size;
00275 int total = N;
00276 free_list(plist);
00277 sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
00278 printf("Battling %d boxes...\n", N);
00279 for(i = 0; i < N; ++i){
00280 boxes[i].filename = paths[i];
00281 boxes[i].net = net;
00282 boxes[i].classes = classes;
00283 boxes[i].elos = calloc(classes, sizeof(float));;
00284 for(j = 0; j < classes; ++j){
00285 boxes[i].elos[j] = 1500;
00286 }
00287 }
00288 int round;
00289 clock_t time=clock();
00290 for(round = 1; round <= 4; ++round){
00291 clock_t round_time=clock();
00292 printf("Round: %d\n", round);
00293 shuffle(boxes, N, sizeof(sortable_bbox));
00294 for(i = 0; i < N/2; ++i){
00295 bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
00296 }
00297 printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
00298 }
00299
00300 int class;
00301
00302 for (class = 0; class < classes; ++class){
00303
00304 N = total;
00305 current_class = class;
00306 qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
00307 N /= 2;
00308
00309 for(round = 1; round <= 100; ++round){
00310 clock_t round_time=clock();
00311 printf("Round: %d\n", round);
00312
00313 sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
00314 for(i = 0; i < N/2; ++i){
00315 bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class);
00316 }
00317 qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
00318 if(round <= 20) N = (N*9/10)/2*2;
00319
00320 printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
00321 }
00322 char buff[256];
00323 sprintf(buff, "results/battle_%d.log", class);
00324 FILE *outfp = fopen(buff, "w");
00325 for(i = 0; i < N; ++i){
00326 fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]);
00327 }
00328 fclose(outfp);
00329 }
00330 printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
00331 }
00332
/*
 * Command-line entry point for the compare tool:
 *   <prog> <mode> {train|valid|sort|battle} <cfg> [weights]
 * It dispatches to the matching routine and reports unknown subcommands on
 * stderr.
 */
void run_compare(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/sort/battle] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;

    if(0==strcmp(argv[2], "train")) train_compare(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_compare(cfg, weights);
    else if(0==strcmp(argv[2], "sort")) SortMaster3000(cfg, weights);
    else if(0==strcmp(argv[2], "battle")) BattleRoyaleWithCheese(cfg, weights);
    else fprintf(stderr, "Unknown compare command: %s\n", argv[2]);
}