00001 #include "network.h" 00002 #include "utils.h" 00003 #include "parser.h" 00004 00005 #ifdef OPENCV 00006 #include "opencv2/highgui/highgui_c.h" 00007 #endif 00008 00009 void train_tag(char *cfgfile, char *weightfile, int clear) 00010 { 00011 srand(time(0)); 00012 float avg_loss = -1; 00013 char *base = basecfg(cfgfile); 00014 char *backup_directory = "/home/pjreddie/backup/"; 00015 printf("%s\n", base); 00016 network net = parse_network_cfg(cfgfile); 00017 if(weightfile){ 00018 load_weights(&net, weightfile); 00019 } 00020 if(clear) *net.seen = 0; 00021 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); 00022 int imgs = 1024; 00023 list *plist = get_paths("/home/pjreddie/tag/train.list"); 00024 char **paths = (char **)list_to_array(plist); 00025 printf("%d\n", plist->size); 00026 int N = plist->size; 00027 clock_t time; 00028 pthread_t load_thread; 00029 data train; 00030 data buffer; 00031 00032 load_args args = {0}; 00033 args.w = net.w; 00034 args.h = net.h; 00035 00036 args.min = net.w; 00037 args.max = net.max_crop; 00038 args.size = net.w; 00039 00040 args.paths = paths; 00041 args.classes = net.outputs; 00042 args.n = imgs; 00043 args.m = N; 00044 args.d = &buffer; 00045 args.type = TAG_DATA; 00046 00047 args.angle = net.angle; 00048 args.exposure = net.exposure; 00049 args.saturation = net.saturation; 00050 args.hue = net.hue; 00051 00052 fprintf(stderr, "%d classes\n", net.outputs); 00053 00054 load_thread = load_data_in_thread(args); 00055 int epoch = (*net.seen)/N; 00056 while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ 00057 time=clock(); 00058 pthread_join(load_thread, 0); 00059 train = buffer; 00060 00061 load_thread = load_data_in_thread(args); 00062 printf("Loaded: %lf seconds\n", sec(clock()-time)); 00063 time=clock(); 00064 float loss = train_network(net, train); 00065 if(avg_loss == -1) avg_loss = loss; 00066 avg_loss = avg_loss*.9 + loss*.1; 00067 printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); 00068 free_data(train); 00069 if(*net.seen/N > epoch){ 00070 epoch = *net.seen/N; 00071 char buff[256]; 00072 sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); 00073 save_weights(net, buff); 00074 } 00075 if(get_current_batch(net)%100 == 0){ 00076 char buff[256]; 00077 sprintf(buff, "%s/%s.backup",backup_directory,base); 00078 save_weights(net, buff); 00079 } 00080 } 00081 char buff[256]; 00082 sprintf(buff, "%s/%s.weights", backup_directory, base); 00083 save_weights(net, buff); 00084 00085 pthread_join(load_thread, 0); 00086 free_data(buffer); 00087 free_network(net); 00088 free_ptrs((void**)paths, plist->size); 00089 free_list(plist); 00090 free(base); 00091 } 00092 00093 void test_tag(char *cfgfile, char *weightfile, char *filename) 00094 { 00095 network net = parse_network_cfg(cfgfile); 00096 if(weightfile){ 00097 load_weights(&net, weightfile); 00098 } 00099 set_batch_network(&net, 1); 00100 srand(2222222); 00101 int i = 0; 00102 char **names = get_labels("data/tags.txt"); 00103 clock_t time; 00104 int indexes[10]; 00105 char buff[256]; 00106 char *input = buff; 00107 int size = net.w; 00108 while(1){ 00109 if(filename){ 00110 strncpy(input, filename, 256); 00111 }else{ 00112 printf("Enter Image Path: "); 00113 fflush(stdout); 00114 input = fgets(input, 256, stdin); 00115 if(!input) return; 00116 strtok(input, "\n"); 00117 } 00118 image im = load_image_color(input, 0, 0); 00119 image r = resize_min(im, size); 00120 resize_network(&net, r.w, r.h); 00121 printf("%d %d\n", r.w, r.h); 00122 00123 float *X = r.data; 00124 time=clock(); 00125 float *predictions = network_predict(net, X); 00126 top_predictions(net, 10, indexes); 00127 printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); 00128 for(i = 0; i < 10; ++i){ 00129 int index = indexes[i]; 00130 printf("%.1f%%: %s\n", predictions[index]*100, names[index]); 00131 } 00132 if(r.data != im.data) free_image(r); 00133 free_image(im); 00134 if (filename) break; 00135 } 00136 } 00137 00138 00139 void run_tag(int argc, char **argv) 00140 { 00141 if(argc < 4){ 00142 fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); 00143 return; 00144 } 00145 00146 int clear = find_arg(argc, argv, "-clear"); 00147 char *cfg = argv[3]; 00148 char *weights = (argc > 4) ? argv[4] : 0; 00149 char *filename = (argc > 5) ? argv[5] : 0; 00150 if(0==strcmp(argv[2], "train")) train_tag(cfg, weights, clear); 00151 else if(0==strcmp(argv[2], "test")) test_tag(cfg, weights, filename); 00152 } 00153