super.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "parser.h"
00005 
00006 #ifdef OPENCV
00007 #include "opencv2/highgui/highgui_c.h"
00008 #endif
00009 
00010 void train_super(char *cfgfile, char *weightfile)
00011 {
00012     char *train_images = "/data/imagenet/imagenet1k.train.list";
00013     char *backup_directory = "/home/pjreddie/backup/";
00014     srand(time(0));
00015     char *base = basecfg(cfgfile);
00016     printf("%s\n", base);
00017     float avg_loss = -1;
00018     network net = parse_network_cfg(cfgfile);
00019     if(weightfile){
00020         load_weights(&net, weightfile);
00021     }
00022     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00023     int imgs = net.batch*net.subdivisions;
00024     int i = *net.seen/imgs;
00025     data train, buffer;
00026 
00027 
00028     list *plist = get_paths(train_images);
00029     //int N = plist->size;
00030     char **paths = (char **)list_to_array(plist);
00031 
00032     load_args args = {0};
00033     args.w = net.w;
00034     args.h = net.h;
00035     args.scale = 4;
00036     args.paths = paths;
00037     args.n = imgs;
00038     args.m = plist->size;
00039     args.d = &buffer;
00040     args.type = SUPER_DATA;
00041 
00042     pthread_t load_thread = load_data_in_thread(args);
00043     clock_t time;
00044     //while(i*imgs < N*120){
00045     while(get_current_batch(net) < net.max_batches){
00046         i += 1;
00047         time=clock();
00048         pthread_join(load_thread, 0);
00049         train = buffer;
00050         load_thread = load_data_in_thread(args);
00051 
00052         printf("Loaded: %lf seconds\n", sec(clock()-time));
00053 
00054         time=clock();
00055         float loss = train_network(net, train);
00056         if (avg_loss < 0) avg_loss = loss;
00057         avg_loss = avg_loss*.9 + loss*.1;
00058 
00059         printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
00060         if(i%1000==0){
00061             char buff[256];
00062             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00063             save_weights(net, buff);
00064         }
00065         if(i%100==0){
00066             char buff[256];
00067             sprintf(buff, "%s/%s.backup", backup_directory, base);
00068             save_weights(net, buff);
00069         }
00070         free_data(train);
00071     }
00072     char buff[256];
00073     sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00074     save_weights(net, buff);
00075 }
00076 
00077 void test_super(char *cfgfile, char *weightfile, char *filename)
00078 {
00079     network net = parse_network_cfg(cfgfile);
00080     if(weightfile){
00081         load_weights(&net, weightfile);
00082     }
00083     set_batch_network(&net, 1);
00084     srand(2222222);
00085 
00086     clock_t time;
00087     char buff[256];
00088     char *input = buff;
00089     while(1){
00090         if(filename){
00091             strncpy(input, filename, 256);
00092         }else{
00093             printf("Enter Image Path: ");
00094             fflush(stdout);
00095             input = fgets(input, 256, stdin);
00096             if(!input) return;
00097             strtok(input, "\n");
00098         }
00099         image im = load_image_color(input, 0, 0);
00100         resize_network(&net, im.w, im.h);
00101         printf("%d %d\n", im.w, im.h);
00102 
00103         float *X = im.data;
00104         time=clock();
00105         network_predict(net, X);
00106         image out = get_network_image(net);
00107         printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
00108         save_image(out, "out");
00109 
00110         free_image(im);
00111         if (filename) break;
00112     }
00113 }
00114 
00115 
/*
 * Command-line dispatcher for the "super" subcommand.
 *
 * Reads argv[2] as the mode ("train" or "test"), argv[3] as the cfg path,
 * argv[4] as optional weights and argv[5] as an optional image file name.
 * Prints a usage line to stderr and returns when fewer than four arguments
 * are given.  Any other mode falls through without doing anything.
 */
void run_super(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    const char *mode = argv[2];
    char *cfg_path = argv[3];

    // The optional arguments default to a null pointer when absent.
    char *weight_path = 0;
    char *image_path = 0;
    if(argc > 4) weight_path = argv[4];
    if(argc > 5) image_path = argv[5];

    if(!strcmp(mode, "train")){
        train_super(cfg_path, weight_path);
        return;
    }
    if(!strcmp(mode, "test")){
        test_super(cfg_path, weight_path, image_path);
        return;
    }
    // A "valid" mode is advertised in the usage text but has no handler here.
}


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:31