00001 #include "network.h" 00002 #include "utils.h" 00003 #include "parser.h" 00004 #include "option_list.h" 00005 #include "blas.h" 00006 00007 #ifdef OPENCV 00008 #include "opencv2/highgui/highgui_c.h" 00009 #endif 00010 00011 void train_cifar(char *cfgfile, char *weightfile) 00012 { 00013 srand(time(0)); 00014 float avg_loss = -1; 00015 char *base = basecfg(cfgfile); 00016 printf("%s\n", base); 00017 network net = parse_network_cfg(cfgfile); 00018 if(weightfile){ 00019 load_weights(&net, weightfile); 00020 } 00021 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); 00022 00023 char *backup_directory = "/home/pjreddie/backup/"; 00024 int classes = 10; 00025 int N = 50000; 00026 00027 char **labels = get_labels("data/cifar/labels.txt"); 00028 int epoch = (*net.seen)/N; 00029 data train = load_all_cifar10(); 00030 while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ 00031 clock_t time=clock(); 00032 00033 float loss = train_network_sgd(net, train, 1); 00034 if(avg_loss == -1) avg_loss = loss; 00035 avg_loss = avg_loss*.95 + loss*.05; 00036 printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); 00037 if(*net.seen/N > epoch){ 00038 epoch = *net.seen/N; 00039 char buff[256]; 00040 sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); 00041 save_weights(net, buff); 00042 } 00043 if(get_current_batch(net)%100 == 0){ 00044 char buff[256]; 00045 sprintf(buff, "%s/%s.backup",backup_directory,base); 00046 save_weights(net, buff); 00047 } 00048 } 00049 char buff[256]; 00050 sprintf(buff, "%s/%s.weights", backup_directory, base); 00051 save_weights(net, buff); 00052 00053 free_network(net); 00054 free_ptrs((void**)labels, classes); 00055 free(base); 00056 free_data(train); 00057 } 00058 00059 void train_cifar_distill(char *cfgfile, char *weightfile) 00060 { 00061 srand(time(0)); 00062 float avg_loss = -1; 00063 char *base = basecfg(cfgfile); 00064 printf("%s\n", base); 00065 network net = parse_network_cfg(cfgfile); 00066 if(weightfile){ 00067 load_weights(&net, weightfile); 00068 } 00069 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); 00070 00071 char *backup_directory = "/home/pjreddie/backup/"; 00072 int classes = 10; 00073 int N = 50000; 00074 00075 char **labels = get_labels("data/cifar/labels.txt"); 00076 int epoch = (*net.seen)/N; 00077 00078 data train = load_all_cifar10(); 00079 matrix soft = csv_to_matrix("results/ensemble.csv"); 00080 00081 float weight = .9; 00082 scale_matrix(soft, weight); 00083 scale_matrix(train.y, 1. - weight); 00084 matrix_add_matrix(soft, train.y); 00085 00086 while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ 00087 clock_t time=clock(); 00088 00089 float loss = train_network_sgd(net, train, 1); 00090 if(avg_loss == -1) avg_loss = loss; 00091 avg_loss = avg_loss*.95 + loss*.05; 00092 printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); 00093 if(*net.seen/N > epoch){ 00094 epoch = *net.seen/N; 00095 char buff[256]; 00096 sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); 00097 save_weights(net, buff); 00098 } 00099 if(get_current_batch(net)%100 == 0){ 00100 char buff[256]; 00101 sprintf(buff, "%s/%s.backup",backup_directory,base); 00102 save_weights(net, buff); 00103 } 00104 } 00105 char buff[256]; 00106 sprintf(buff, "%s/%s.weights", backup_directory, base); 00107 save_weights(net, buff); 00108 00109 free_network(net); 00110 free_ptrs((void**)labels, classes); 00111 free(base); 00112 free_data(train); 00113 } 00114 00115 void test_cifar_multi(char *filename, char *weightfile) 00116 { 00117 network net = parse_network_cfg(filename); 00118 if(weightfile){ 00119 load_weights(&net, weightfile); 00120 } 00121 set_batch_network(&net, 1); 00122 srand(time(0)); 00123 00124 float avg_acc = 0; 00125 data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); 00126 00127 int i; 00128 for(i = 0; i < test.X.rows; ++i){ 00129 image im = float_to_image(32, 32, 3, test.X.vals[i]); 00130 00131 float pred[10] = {0}; 00132 00133 float *p = network_predict(net, im.data); 00134 axpy_cpu(10, 1, p, 1, pred, 1); 00135 flip_image(im); 00136 p = network_predict(net, im.data); 00137 axpy_cpu(10, 1, p, 1, pred, 1); 00138 00139 int index = max_index(pred, 10); 00140 int class = max_index(test.y.vals[i], 10); 00141 if(index == class) avg_acc += 1; 00142 free_image(im); 00143 printf("%4d: %.2f%%\n", i, 100.*avg_acc/(i+1)); 00144 } 00145 } 00146 00147 void test_cifar(char *filename, char *weightfile) 00148 { 00149 network net = parse_network_cfg(filename); 00150 if(weightfile){ 00151 load_weights(&net, weightfile); 00152 } 00153 srand(time(0)); 00154 00155 clock_t time; 00156 float avg_acc = 0; 00157 float avg_top5 = 0; 00158 data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); 00159 00160 time=clock(); 00161 00162 float *acc = network_accuracies(net, test, 2); 00163 avg_acc += acc[0]; 00164 avg_top5 += acc[1]; 00165 printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows); 00166 free_data(test); 00167 } 00168 00169 void extract_cifar() 00170 { 00171 char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"}; 00172 int i; 00173 data train = load_all_cifar10(); 00174 data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); 00175 for(i = 0; i < train.X.rows; ++i){ 00176 image im = float_to_image(32, 32, 3, train.X.vals[i]); 00177 int class = max_index(train.y.vals[i], 10); 00178 char buff[256]; 00179 sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]); 00180 save_image_png(im, buff); 00181 } 00182 for(i = 0; i < test.X.rows; ++i){ 00183 image im = float_to_image(32, 32, 3, test.X.vals[i]); 00184 int class = max_index(test.y.vals[i], 10); 00185 char buff[256]; 00186 sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]); 00187 save_image_png(im, buff); 00188 } 00189 } 00190 00191 void test_cifar_csv(char *filename, char *weightfile) 00192 { 00193 network net = parse_network_cfg(filename); 00194 if(weightfile){ 00195 load_weights(&net, weightfile); 00196 } 00197 srand(time(0)); 00198 00199 data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); 00200 00201 matrix pred = network_predict_data(net, test); 00202 00203 int i; 00204 for(i = 0; i < test.X.rows; ++i){ 00205 image im = float_to_image(32, 32, 3, test.X.vals[i]); 00206 flip_image(im); 00207 } 00208 matrix pred2 = network_predict_data(net, test); 00209 scale_matrix(pred, .5); 00210 scale_matrix(pred2, .5); 00211 matrix_add_matrix(pred2, pred); 00212 00213 matrix_to_csv(pred); 00214 fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1)); 00215 free_data(test); 00216 } 00217 00218 void test_cifar_csvtrain(char *filename, char *weightfile) 00219 { 00220 network net = parse_network_cfg(filename); 00221 if(weightfile){ 00222 load_weights(&net, weightfile); 00223 } 00224 srand(time(0)); 00225 00226 data test = load_all_cifar10(); 00227 00228 matrix pred = network_predict_data(net, test); 00229 00230 int i; 00231 for(i = 0; i < test.X.rows; ++i){ 00232 image im = float_to_image(32, 32, 3, test.X.vals[i]); 00233 flip_image(im); 00234 } 00235 matrix pred2 = network_predict_data(net, test); 00236 scale_matrix(pred, .5); 00237 scale_matrix(pred2, .5); 00238 matrix_add_matrix(pred2, pred); 00239 00240 matrix_to_csv(pred); 00241 fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1)); 00242 free_data(test); 00243 } 00244 00245 void eval_cifar_csv() 00246 { 00247 data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); 00248 00249 matrix pred = csv_to_matrix("results/combined.csv"); 00250 fprintf(stderr, "%d %d\n", pred.rows, pred.cols); 00251 00252 fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1)); 00253 free_data(test); 00254 free_matrix(pred); 00255 } 00256 00257 00258 void run_cifar(int argc, char **argv) 00259 { 00260 if(argc < 4){ 00261 fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); 00262 return; 00263 } 00264 00265 char *cfg = argv[3]; 00266 char *weights = (argc > 4) ? argv[4] : 0; 00267 if(0==strcmp(argv[2], "train")) train_cifar(cfg, weights); 00268 else if(0==strcmp(argv[2], "extract")) extract_cifar(); 00269 else if(0==strcmp(argv[2], "distill")) train_cifar_distill(cfg, weights); 00270 else if(0==strcmp(argv[2], "test")) test_cifar(cfg, weights); 00271 else if(0==strcmp(argv[2], "multi")) test_cifar_multi(cfg, weights); 00272 else if(0==strcmp(argv[2], "csv")) test_cifar_csv(cfg, weights); 00273 else if(0==strcmp(argv[2], "csvtrain")) test_cifar_csvtrain(cfg, weights); 00274 else if(0==strcmp(argv[2], "eval")) eval_cifar_csv(); 00275 } 00276 00277