00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "parser.h"
00005
00006 #ifdef OPENCV
00007 #include "opencv2/highgui/highgui_c.h"
00008 #endif
00009
00010 void train_super(char *cfgfile, char *weightfile)
00011 {
00012 char *train_images = "/data/imagenet/imagenet1k.train.list";
00013 char *backup_directory = "/home/pjreddie/backup/";
00014 srand(time(0));
00015 char *base = basecfg(cfgfile);
00016 printf("%s\n", base);
00017 float avg_loss = -1;
00018 network net = parse_network_cfg(cfgfile);
00019 if(weightfile){
00020 load_weights(&net, weightfile);
00021 }
00022 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00023 int imgs = net.batch*net.subdivisions;
00024 int i = *net.seen/imgs;
00025 data train, buffer;
00026
00027
00028 list *plist = get_paths(train_images);
00029
00030 char **paths = (char **)list_to_array(plist);
00031
00032 load_args args = {0};
00033 args.w = net.w;
00034 args.h = net.h;
00035 args.scale = 4;
00036 args.paths = paths;
00037 args.n = imgs;
00038 args.m = plist->size;
00039 args.d = &buffer;
00040 args.type = SUPER_DATA;
00041
00042 pthread_t load_thread = load_data_in_thread(args);
00043 clock_t time;
00044
00045 while(get_current_batch(net) < net.max_batches){
00046 i += 1;
00047 time=clock();
00048 pthread_join(load_thread, 0);
00049 train = buffer;
00050 load_thread = load_data_in_thread(args);
00051
00052 printf("Loaded: %lf seconds\n", sec(clock()-time));
00053
00054 time=clock();
00055 float loss = train_network(net, train);
00056 if (avg_loss < 0) avg_loss = loss;
00057 avg_loss = avg_loss*.9 + loss*.1;
00058
00059 printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
00060 if(i%1000==0){
00061 char buff[256];
00062 sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00063 save_weights(net, buff);
00064 }
00065 if(i%100==0){
00066 char buff[256];
00067 sprintf(buff, "%s/%s.backup", backup_directory, base);
00068 save_weights(net, buff);
00069 }
00070 free_data(train);
00071 }
00072 char buff[256];
00073 sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00074 save_weights(net, buff);
00075 }
00076
00077 void test_super(char *cfgfile, char *weightfile, char *filename)
00078 {
00079 network net = parse_network_cfg(cfgfile);
00080 if(weightfile){
00081 load_weights(&net, weightfile);
00082 }
00083 set_batch_network(&net, 1);
00084 srand(2222222);
00085
00086 clock_t time;
00087 char buff[256];
00088 char *input = buff;
00089 while(1){
00090 if(filename){
00091 strncpy(input, filename, 256);
00092 }else{
00093 printf("Enter Image Path: ");
00094 fflush(stdout);
00095 input = fgets(input, 256, stdin);
00096 if(!input) return;
00097 strtok(input, "\n");
00098 }
00099 image im = load_image_color(input, 0, 0);
00100 resize_network(&net, im.w, im.h);
00101 printf("%d %d\n", im.w, im.h);
00102
00103 float *X = im.data;
00104 time=clock();
00105 network_predict(net, X);
00106 image out = get_network_image(net);
00107 printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
00108 save_image(out, "out");
00109
00110 free_image(im);
00111 if (filename) break;
00112 }
00113 }
00114
00115
/*
 * Command-line entry point for the "super" sub-commands.
 *
 * Expected layout: argv[2] = mode ("train" or "test"), argv[3] = cfg file,
 * argv[4] = optional weights file, argv[5] = optional image file (used by
 * "test" only).  Prints a usage message and returns when fewer than four
 * arguments are given; reports an unrecognised mode instead of silently doing
 * nothing.
 *
 * Assumes the caller dispatched on argv[1], so argv[0] and argv[1] are valid
 * when the usage message is printed.
 */
void run_super(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)] [image (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;
    if(0==strcmp(argv[2], "train")) train_super(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_super(cfg, weights, filename);
    else fprintf(stderr, "Unknown super command: %s\n", argv[2]);
}