compare.c
Go to the documentation of this file.
#include <stdio.h>
#include <stdlib.h>

#include "network.h"
#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
#include "box.h"
00009 
00010 void train_compare(char *cfgfile, char *weightfile)
00011 {
00012     srand(time(0));
00013     float avg_loss = -1;
00014     char *base = basecfg(cfgfile);
00015     char *backup_directory = "/home/pjreddie/backup/";
00016     printf("%s\n", base);
00017     network net = parse_network_cfg(cfgfile);
00018     if(weightfile){
00019         load_weights(&net, weightfile);
00020     }
00021     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00022     int imgs = 1024;
00023     list *plist = get_paths("data/compare.train.list");
00024     char **paths = (char **)list_to_array(plist);
00025     int N = plist->size;
00026     printf("%d\n", N);
00027     clock_t time;
00028     pthread_t load_thread;
00029     data train;
00030     data buffer;
00031 
00032     load_args args = {0};
00033     args.w = net.w;
00034     args.h = net.h;
00035     args.paths = paths;
00036     args.classes = 20;
00037     args.n = imgs;
00038     args.m = N;
00039     args.d = &buffer;
00040     args.type = COMPARE_DATA;
00041 
00042     load_thread = load_data_in_thread(args);
00043     int epoch = *net.seen/N;
00044     int i = 0;
00045     while(1){
00046         ++i;
00047         time=clock();
00048         pthread_join(load_thread, 0);
00049         train = buffer;
00050 
00051         load_thread = load_data_in_thread(args);
00052         printf("Loaded: %lf seconds\n", sec(clock()-time));
00053         time=clock();
00054         float loss = train_network(net, train);
00055         if(avg_loss == -1) avg_loss = loss;
00056         avg_loss = avg_loss*.9 + loss*.1;
00057         printf("%.3f: %f, %f avg, %lf seconds, %d images\n", (float)*net.seen/N, loss, avg_loss, sec(clock()-time), *net.seen);
00058         free_data(train);
00059         if(i%100 == 0){
00060             char buff[256];
00061             sprintf(buff, "%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
00062             save_weights(net, buff);
00063         }
00064         if(*net.seen/N > epoch){
00065             epoch = *net.seen/N;
00066             i = 0;
00067             char buff[256];
00068             sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
00069             save_weights(net, buff);
00070             if(epoch%22 == 0) net.learning_rate *= .1;
00071         }
00072     }
00073     pthread_join(load_thread, 0);
00074     free_data(buffer);
00075     free_network(net);
00076     free_ptrs((void**)paths, plist->size);
00077     free_list(plist);
00078     free(base);
00079 }
00080 
00081 void validate_compare(char *filename, char *weightfile)
00082 {
00083     int i = 0;
00084     network net = parse_network_cfg(filename);
00085     if(weightfile){
00086         load_weights(&net, weightfile);
00087     }
00088     srand(time(0));
00089 
00090     list *plist = get_paths("data/compare.val.list");
00091     //list *plist = get_paths("data/compare.val.old");
00092     char **paths = (char **)list_to_array(plist);
00093     int N = plist->size/2;
00094     free_list(plist);
00095 
00096     clock_t time;
00097     int correct = 0;
00098     int total = 0;
00099     int splits = 10;
00100     int num = (i+1)*N/splits - i*N/splits;
00101 
00102     data val, buffer;
00103 
00104     load_args args = {0};
00105     args.w = net.w;
00106     args.h = net.h;
00107     args.paths = paths;
00108     args.classes = 20;
00109     args.n = num;
00110     args.m = 0;
00111     args.d = &buffer;
00112     args.type = COMPARE_DATA;
00113 
00114     pthread_t load_thread = load_data_in_thread(args);
00115     for(i = 1; i <= splits; ++i){
00116         time=clock();
00117 
00118         pthread_join(load_thread, 0);
00119         val = buffer;
00120 
00121         num = (i+1)*N/splits - i*N/splits;
00122         char **part = paths+(i*N/splits);
00123         if(i != splits){
00124             args.paths = part;
00125             load_thread = load_data_in_thread(args);
00126         }
00127         printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
00128 
00129         time=clock();
00130         matrix pred = network_predict_data(net, val);
00131         int j,k;
00132         for(j = 0; j < val.y.rows; ++j){
00133             for(k = 0; k < 20; ++k){
00134                 if(val.y.vals[j][k*2] != val.y.vals[j][k*2+1]){
00135                     ++total;
00136                     if((val.y.vals[j][k*2] < val.y.vals[j][k*2+1]) == (pred.vals[j][k*2] < pred.vals[j][k*2+1])){
00137                         ++correct;
00138                     }
00139                 }
00140             }
00141         }
00142         free_matrix(pred);
00143         printf("%d: Acc: %f, %lf seconds, %d images\n", i, (float)correct/total, sec(clock()-time), val.X.rows);
00144         free_data(val);
00145     }
00146 }
00147 
00148 typedef struct {
00149     network net;
00150     char *filename;
00151     int class;
00152     int classes;
00153     float elo;
00154     float *elos;
00155 } sortable_bbox;
00156 
00157 int total_compares = 0;
00158 int current_class = 0;
00159 
00160 int elo_comparator(const void*a, const void *b)
00161 {
00162     sortable_bbox box1 = *(sortable_bbox*)a;
00163     sortable_bbox box2 = *(sortable_bbox*)b;
00164     if(box1.elos[current_class] == box2.elos[current_class]) return 0;
00165     if(box1.elos[current_class] >  box2.elos[current_class]) return -1;
00166     return 1;
00167 }
00168 
00169 int bbox_comparator(const void *a, const void *b)
00170 {
00171     ++total_compares;
00172     sortable_bbox box1 = *(sortable_bbox*)a;
00173     sortable_bbox box2 = *(sortable_bbox*)b;
00174     network net = box1.net;
00175     int class   = box1.class;
00176 
00177     image im1 = load_image_color(box1.filename, net.w, net.h);
00178     image im2 = load_image_color(box2.filename, net.w, net.h);
00179     float *X  = calloc(net.w*net.h*net.c, sizeof(float));
00180     memcpy(X,                   im1.data, im1.w*im1.h*im1.c*sizeof(float));
00181     memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
00182     float *predictions = network_predict(net, X);
00183     
00184     free_image(im1);
00185     free_image(im2);
00186     free(X);
00187     if (predictions[class*2] > predictions[class*2+1]){
00188         return 1;
00189     }
00190     return -1;
00191 }
00192 
00193 void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result)
00194 {
00195     int k = 32;
00196     float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.));
00197     float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.));
00198     float SA = result ? 1 : 0;
00199     float SB = result ? 0 : 1;
00200     a->elos[class] += k*(SA - EA);
00201     b->elos[class] += k*(SB - EB);
00202 }
00203 
00204 void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class)
00205 {
00206     image im1 = load_image_color(a->filename, net.w, net.h);
00207     image im2 = load_image_color(b->filename, net.w, net.h);
00208     float *X  = calloc(net.w*net.h*net.c, sizeof(float));
00209     memcpy(X,                   im1.data, im1.w*im1.h*im1.c*sizeof(float));
00210     memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
00211     float *predictions = network_predict(net, X);
00212     ++total_compares;
00213 
00214     int i;
00215     for(i = 0; i < classes; ++i){
00216         if(class < 0 || class == i){
00217             int result = predictions[i*2] > predictions[i*2+1];
00218             bbox_update(a, b, i, result);
00219         }
00220     }
00221     
00222     free_image(im1);
00223     free_image(im2);
00224     free(X);
00225 }
00226 
00227 void SortMaster3000(char *filename, char *weightfile)
00228 {
00229     int i = 0;
00230     network net = parse_network_cfg(filename);
00231     if(weightfile){
00232         load_weights(&net, weightfile);
00233     }
00234     srand(time(0));
00235     set_batch_network(&net, 1);
00236 
00237     list *plist = get_paths("data/compare.sort.list");
00238     //list *plist = get_paths("data/compare.val.old");
00239     char **paths = (char **)list_to_array(plist);
00240     int N = plist->size;
00241     free_list(plist);
00242     sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
00243     printf("Sorting %d boxes...\n", N);
00244     for(i = 0; i < N; ++i){
00245         boxes[i].filename = paths[i];
00246         boxes[i].net = net;
00247         boxes[i].class = 7;
00248         boxes[i].elo = 1500;
00249     }
00250     clock_t time=clock();
00251     qsort(boxes, N, sizeof(sortable_bbox), bbox_comparator);
00252     for(i = 0; i < N; ++i){
00253         printf("%s\n", boxes[i].filename);
00254     }
00255     printf("Sorted in %d compares, %f secs\n", total_compares, sec(clock()-time));
00256 }
00257 
00258 void BattleRoyaleWithCheese(char *filename, char *weightfile)
00259 {
00260     int classes = 20;
00261     int i,j;
00262     network net = parse_network_cfg(filename);
00263     if(weightfile){
00264         load_weights(&net, weightfile);
00265     }
00266     srand(time(0));
00267     set_batch_network(&net, 1);
00268 
00269     list *plist = get_paths("data/compare.sort.list");
00270     //list *plist = get_paths("data/compare.small.list");
00271     //list *plist = get_paths("data/compare.cat.list");
00272     //list *plist = get_paths("data/compare.val.old");
00273     char **paths = (char **)list_to_array(plist);
00274     int N = plist->size;
00275     int total = N;
00276     free_list(plist);
00277     sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
00278     printf("Battling %d boxes...\n", N);
00279     for(i = 0; i < N; ++i){
00280         boxes[i].filename = paths[i];
00281         boxes[i].net = net;
00282         boxes[i].classes = classes;
00283         boxes[i].elos = calloc(classes, sizeof(float));;
00284         for(j = 0; j < classes; ++j){
00285             boxes[i].elos[j] = 1500;
00286         }
00287     }
00288     int round;
00289     clock_t time=clock();
00290     for(round = 1; round <= 4; ++round){
00291         clock_t round_time=clock();
00292         printf("Round: %d\n", round);
00293         shuffle(boxes, N, sizeof(sortable_bbox));
00294         for(i = 0; i < N/2; ++i){
00295             bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
00296         }
00297         printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
00298     }
00299 
00300     int class;
00301 
00302     for (class = 0; class < classes; ++class){
00303 
00304         N = total;
00305         current_class = class;
00306         qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
00307         N /= 2;
00308 
00309         for(round = 1; round <= 100; ++round){
00310             clock_t round_time=clock();
00311             printf("Round: %d\n", round);
00312 
00313             sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
00314             for(i = 0; i < N/2; ++i){
00315                 bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class);
00316             }
00317             qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
00318             if(round <= 20) N = (N*9/10)/2*2;
00319 
00320             printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
00321         }
00322         char buff[256];
00323         sprintf(buff, "results/battle_%d.log", class);
00324         FILE *outfp = fopen(buff, "w");
00325         for(i = 0; i < N; ++i){
00326             fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]);
00327         }
00328         fclose(outfp);
00329     }
00330     printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
00331 }
00332 
/*
 * Command-line entry point for the compare subcommands.
 *
 * Expects argv[2] to be one of "train", "valid", "sort" or "battle",
 * argv[3] to be the network cfg file and argv[4] (optional) a weights
 * file. Prints a usage message when too few arguments are given and an
 * error message when the subcommand is not recognised.
 */
void run_compare(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/sort/battle] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_compare(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_compare(cfg, weights);
    else if(0==strcmp(argv[2], "sort")) SortMaster3000(cfg, weights);
    else if(0==strcmp(argv[2], "battle")) BattleRoyaleWithCheese(cfg, weights);
    else fprintf(stderr, "%s %s: unknown option '%s' (expected train, valid, sort or battle)\n", argv[0], argv[1], argv[2]);
}


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:29