cifar.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004 #include "option_list.h"
00005 #include "blas.h"
00006 
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 #endif
00010 
00011 void train_cifar(char *cfgfile, char *weightfile)
00012 {
00013     srand(time(0));
00014     float avg_loss = -1;
00015     char *base = basecfg(cfgfile);
00016     printf("%s\n", base);
00017     network net = parse_network_cfg(cfgfile);
00018     if(weightfile){
00019         load_weights(&net, weightfile);
00020     }
00021     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00022 
00023     char *backup_directory = "/home/pjreddie/backup/";
00024     int classes = 10;
00025     int N = 50000;
00026 
00027     char **labels = get_labels("data/cifar/labels.txt");
00028     int epoch = (*net.seen)/N;
00029     data train = load_all_cifar10();
00030     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
00031         clock_t time=clock();
00032 
00033         float loss = train_network_sgd(net, train, 1);
00034         if(avg_loss == -1) avg_loss = loss;
00035         avg_loss = avg_loss*.95 + loss*.05;
00036         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
00037         if(*net.seen/N > epoch){
00038             epoch = *net.seen/N;
00039             char buff[256];
00040             sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
00041             save_weights(net, buff);
00042         }
00043         if(get_current_batch(net)%100 == 0){
00044             char buff[256];
00045             sprintf(buff, "%s/%s.backup",backup_directory,base);
00046             save_weights(net, buff);
00047         }
00048     }
00049     char buff[256];
00050     sprintf(buff, "%s/%s.weights", backup_directory, base);
00051     save_weights(net, buff);
00052 
00053     free_network(net);
00054     free_ptrs((void**)labels, classes);
00055     free(base);
00056     free_data(train);
00057 }
00058 
00059 void train_cifar_distill(char *cfgfile, char *weightfile)
00060 {
00061     srand(time(0));
00062     float avg_loss = -1;
00063     char *base = basecfg(cfgfile);
00064     printf("%s\n", base);
00065     network net = parse_network_cfg(cfgfile);
00066     if(weightfile){
00067         load_weights(&net, weightfile);
00068     }
00069     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00070 
00071     char *backup_directory = "/home/pjreddie/backup/";
00072     int classes = 10;
00073     int N = 50000;
00074 
00075     char **labels = get_labels("data/cifar/labels.txt");
00076     int epoch = (*net.seen)/N;
00077 
00078     data train = load_all_cifar10();
00079     matrix soft = csv_to_matrix("results/ensemble.csv");
00080 
00081     float weight = .9;
00082     scale_matrix(soft, weight);
00083     scale_matrix(train.y, 1. - weight);
00084     matrix_add_matrix(soft, train.y);
00085 
00086     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
00087         clock_t time=clock();
00088 
00089         float loss = train_network_sgd(net, train, 1);
00090         if(avg_loss == -1) avg_loss = loss;
00091         avg_loss = avg_loss*.95 + loss*.05;
00092         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
00093         if(*net.seen/N > epoch){
00094             epoch = *net.seen/N;
00095             char buff[256];
00096             sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
00097             save_weights(net, buff);
00098         }
00099         if(get_current_batch(net)%100 == 0){
00100             char buff[256];
00101             sprintf(buff, "%s/%s.backup",backup_directory,base);
00102             save_weights(net, buff);
00103         }
00104     }
00105     char buff[256];
00106     sprintf(buff, "%s/%s.weights", backup_directory, base);
00107     save_weights(net, buff);
00108 
00109     free_network(net);
00110     free_ptrs((void**)labels, classes);
00111     free(base);
00112     free_data(train);
00113 }
00114 
00115 void test_cifar_multi(char *filename, char *weightfile)
00116 {
00117     network net = parse_network_cfg(filename);
00118     if(weightfile){
00119         load_weights(&net, weightfile);
00120     }
00121     set_batch_network(&net, 1);
00122     srand(time(0));
00123 
00124     float avg_acc = 0;
00125     data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
00126 
00127     int i;
00128     for(i = 0; i < test.X.rows; ++i){
00129         image im = float_to_image(32, 32, 3, test.X.vals[i]);
00130 
00131         float pred[10] = {0};
00132 
00133         float *p = network_predict(net, im.data);
00134         axpy_cpu(10, 1, p, 1, pred, 1);
00135         flip_image(im);
00136         p = network_predict(net, im.data);
00137         axpy_cpu(10, 1, p, 1, pred, 1);
00138 
00139         int index = max_index(pred, 10);
00140         int class = max_index(test.y.vals[i], 10);
00141         if(index == class) avg_acc += 1;
00142         free_image(im);
00143         printf("%4d: %.2f%%\n", i, 100.*avg_acc/(i+1));
00144     }
00145 }
00146 
00147 void test_cifar(char *filename, char *weightfile)
00148 {
00149     network net = parse_network_cfg(filename);
00150     if(weightfile){
00151         load_weights(&net, weightfile);
00152     }
00153     srand(time(0));
00154 
00155     clock_t time;
00156     float avg_acc = 0;
00157     float avg_top5 = 0;
00158     data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
00159 
00160     time=clock();
00161 
00162     float *acc = network_accuracies(net, test, 2);
00163     avg_acc += acc[0];
00164     avg_top5 += acc[1];
00165     printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows);
00166     free_data(test);
00167 }
00168 
00169 void extract_cifar()
00170 {
00171 char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"};
00172     int i;
00173     data train = load_all_cifar10();
00174     data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
00175     for(i = 0; i < train.X.rows; ++i){
00176         image im = float_to_image(32, 32, 3, train.X.vals[i]);
00177         int class = max_index(train.y.vals[i], 10);
00178         char buff[256];
00179         sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]);
00180         save_image_png(im, buff);
00181     }
00182     for(i = 0; i < test.X.rows; ++i){
00183         image im = float_to_image(32, 32, 3, test.X.vals[i]);
00184         int class = max_index(test.y.vals[i], 10);
00185         char buff[256];
00186         sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]);
00187         save_image_png(im, buff);
00188     }
00189 }
00190 
00191 void test_cifar_csv(char *filename, char *weightfile)
00192 {
00193     network net = parse_network_cfg(filename);
00194     if(weightfile){
00195         load_weights(&net, weightfile);
00196     }
00197     srand(time(0));
00198 
00199     data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
00200 
00201     matrix pred = network_predict_data(net, test);
00202 
00203     int i;
00204     for(i = 0; i < test.X.rows; ++i){
00205         image im = float_to_image(32, 32, 3, test.X.vals[i]);
00206         flip_image(im);
00207     }
00208     matrix pred2 = network_predict_data(net, test);
00209     scale_matrix(pred, .5);
00210     scale_matrix(pred2, .5);
00211     matrix_add_matrix(pred2, pred);
00212 
00213     matrix_to_csv(pred);
00214     fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
00215     free_data(test);
00216 }
00217 
00218 void test_cifar_csvtrain(char *filename, char *weightfile)
00219 {
00220     network net = parse_network_cfg(filename);
00221     if(weightfile){
00222         load_weights(&net, weightfile);
00223     }
00224     srand(time(0));
00225 
00226     data test = load_all_cifar10();
00227 
00228     matrix pred = network_predict_data(net, test);
00229 
00230     int i;
00231     for(i = 0; i < test.X.rows; ++i){
00232         image im = float_to_image(32, 32, 3, test.X.vals[i]);
00233         flip_image(im);
00234     }
00235     matrix pred2 = network_predict_data(net, test);
00236     scale_matrix(pred, .5);
00237     scale_matrix(pred2, .5);
00238     matrix_add_matrix(pred2, pred);
00239 
00240     matrix_to_csv(pred);
00241     fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
00242     free_data(test);
00243 }
00244 
00245 void eval_cifar_csv()
00246 {
00247     data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
00248 
00249     matrix pred = csv_to_matrix("results/combined.csv");
00250     fprintf(stderr, "%d %d\n", pred.rows, pred.cols);
00251 
00252     fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
00253     free_data(test);
00254     free_matrix(pred);
00255 }
00256 
00257 
/*
 * Command-line dispatcher for the CIFAR-10 tools.
 *
 * argv[2] selects the mode: train, extract, distill, test, multi, csv,
 * csvtrain or eval. argv[3] is the network cfg, and argv[4] is optional
 * weights. A cfg argument is required even by modes that ignore it
 * (extract, eval). Unknown modes are reported on stderr.
 */
void run_cifar(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/extract/distill/test/multi/csv/csvtrain/eval] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *mode = argv[2];
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : NULL;
    if(0==strcmp(mode, "train")) train_cifar(cfg, weights);
    else if(0==strcmp(mode, "extract")) extract_cifar();
    else if(0==strcmp(mode, "distill")) train_cifar_distill(cfg, weights);
    else if(0==strcmp(mode, "test")) test_cifar(cfg, weights);
    else if(0==strcmp(mode, "multi")) test_cifar_multi(cfg, weights);
    else if(0==strcmp(mode, "csv")) test_cifar_csv(cfg, weights);
    else if(0==strcmp(mode, "csvtrain")) test_cifar_csvtrain(cfg, weights);
    else if(0==strcmp(mode, "eval")) eval_cifar_csv();
    else fprintf(stderr, "Unknown cifar mode: %s\n", mode);
}
00276 
00277 


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:29