00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "parser.h"
00005
00006 #ifdef OPENCV
00007 #include "opencv2/highgui/highgui_c.h"
00008 image get_image_from_stream(CvCapture *cap);
00009 #endif
00010
/*
 * Split a pair of left/right videos into per-frame image pairs.
 *
 * Each frame of both streams is cropped to 1920x1080 and written with
 * save_image() as "<prefix>_<frame>_l" and "<prefix>_<frame>_r".
 * - The left crop is centred.
 * - The right crop is shifted horizontally by a fixed offset, and vertically
 *   by `shift`. `shift` is re-estimated every 100 frames with
 *   best_3d_shift_r() and printed.
 * Extraction stops as soon as either stream runs out of frames.
 * Without OpenCV support this only prints a message.
 *
 * lfile:  path of the left video.
 * rfile:  path of the right video.
 * prefix: output path prefix for the saved frames.
 */
void extract_voxel(char *lfile, char *rfile, char *prefix)
{
#ifdef OPENCV
    int w = 1920;
    int h = 1080;
    /* Fixed horizontal offset applied to the right crop. Presumably a
     * rig-specific calibration value — TODO confirm. */
    int right_xoff = 105;
    int shift = 0;
    int count = 0;

    if(!lfile || !rfile || !prefix){
        fprintf(stderr, "extract_voxel: need left video, right video and output prefix\n");
        return;
    }

    CvCapture *lcap = cvCaptureFromFile(lfile);
    CvCapture *rcap = cvCaptureFromFile(rfile);
    if(!lcap || !rcap){
        fprintf(stderr, "extract_voxel: couldn't open %s\n", lcap ? rfile : lfile);
        if(lcap) cvReleaseCapture(&lcap);
        if(rcap) cvReleaseCapture(&rcap);
        return;
    }

    while(1){
        image l = get_image_from_stream(lcap);
        image r = get_image_from_stream(rcap);
        if(!l.w || !r.w){
            /* One stream may still have delivered a frame; don't leak it. */
            if(l.w) free_image(l);
            if(r.w) free_image(r);
            break;
        }
        if(count%100 == 0) {
            shift = best_3d_shift_r(l, r, -l.h/100, l.h/100);
            printf("%d\n", shift);
        }
        image ls = crop_image(l, (l.w - w)/2, (l.h - h)/2, w, h);
        image rs = crop_image(r, right_xoff + (r.w - w)/2, (r.h - h)/2 + shift, w, h);

        /* Bounded formatting: prefix comes from the command line. */
        char buff[256];
        snprintf(buff, sizeof(buff), "%s_%05d_l", prefix, count);
        save_image(ls, buff);
        snprintf(buff, sizeof(buff), "%s_%05d_r", prefix, count);
        save_image(rs, buff);

        free_image(l);
        free_image(r);
        free_image(ls);
        free_image(rs);
        ++count;
    }

    cvReleaseCapture(&lcap);
    cvReleaseCapture(&rcap);
#else
    printf("need OpenCV for extraction\n");
#endif
}
00046
00047 void train_voxel(char *cfgfile, char *weightfile)
00048 {
00049 char *train_images = "/data/imagenet/imagenet1k.train.list";
00050 char *backup_directory = "/home/pjreddie/backup/";
00051 srand(time(0));
00052 char *base = basecfg(cfgfile);
00053 printf("%s\n", base);
00054 float avg_loss = -1;
00055 network net = parse_network_cfg(cfgfile);
00056 if(weightfile){
00057 load_weights(&net, weightfile);
00058 }
00059 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00060 int imgs = net.batch*net.subdivisions;
00061 int i = *net.seen/imgs;
00062 data train, buffer;
00063
00064
00065 list *plist = get_paths(train_images);
00066
00067 char **paths = (char **)list_to_array(plist);
00068
00069 load_args args = {0};
00070 args.w = net.w;
00071 args.h = net.h;
00072 args.scale = 4;
00073 args.paths = paths;
00074 args.n = imgs;
00075 args.m = plist->size;
00076 args.d = &buffer;
00077 args.type = SUPER_DATA;
00078
00079 pthread_t load_thread = load_data_in_thread(args);
00080 clock_t time;
00081
00082 while(get_current_batch(net) < net.max_batches){
00083 i += 1;
00084 time=clock();
00085 pthread_join(load_thread, 0);
00086 train = buffer;
00087 load_thread = load_data_in_thread(args);
00088
00089 printf("Loaded: %lf seconds\n", sec(clock()-time));
00090
00091 time=clock();
00092 float loss = train_network(net, train);
00093 if (avg_loss < 0) avg_loss = loss;
00094 avg_loss = avg_loss*.9 + loss*.1;
00095
00096 printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
00097 if(i%1000==0){
00098 char buff[256];
00099 sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00100 save_weights(net, buff);
00101 }
00102 if(i%100==0){
00103 char buff[256];
00104 sprintf(buff, "%s/%s.backup", backup_directory, base);
00105 save_weights(net, buff);
00106 }
00107 free_data(train);
00108 }
00109 char buff[256];
00110 sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00111 save_weights(net, buff);
00112 }
00113
00114 void test_voxel(char *cfgfile, char *weightfile, char *filename)
00115 {
00116 network net = parse_network_cfg(cfgfile);
00117 if(weightfile){
00118 load_weights(&net, weightfile);
00119 }
00120 set_batch_network(&net, 1);
00121 srand(2222222);
00122
00123 clock_t time;
00124 char buff[256];
00125 char *input = buff;
00126 while(1){
00127 if(filename){
00128 strncpy(input, filename, 256);
00129 }else{
00130 printf("Enter Image Path: ");
00131 fflush(stdout);
00132 input = fgets(input, 256, stdin);
00133 if(!input) return;
00134 strtok(input, "\n");
00135 }
00136 image im = load_image_color(input, 0, 0);
00137 resize_network(&net, im.w, im.h);
00138 printf("%d %d\n", im.w, im.h);
00139
00140 float *X = im.data;
00141 time=clock();
00142 network_predict(net, X);
00143 image out = get_network_image(net);
00144 printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
00145 save_image(out, "out");
00146
00147 free_image(im);
00148 if (filename) break;
00149 }
00150 }
00151
00152
/*
 * Command-line entry point for the voxel sub-commands.
 *
 *   train   <cfg> [weights]                  -> train_voxel
 *   test    <cfg> [weights] [image]          -> test_voxel
 *   extract <left video> <right video> <prefix> -> extract_voxel
 *
 * argv[0] and argv[1] are the program and module names; argv[2] selects the
 * sub-command. Prints usage to stderr when arguments are missing or the
 * sub-command is unknown.
 */
void run_voxel(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/extract] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;
    if(0==strcmp(argv[2], "train")) train_voxel(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_voxel(cfg, weights, filename);
    else if(0==strcmp(argv[2], "extract")){
        /* extract needs three operands; argv[5] is out of bounds for argc < 6. */
        if(argc < 6){
            fprintf(stderr, "usage: %s %s extract [left video] [right video] [output prefix]\n", argv[0], argv[1]);
            return;
        }
        extract_voxel(argv[3], argv[4], argv[5]);
    }
    else fprintf(stderr, "unknown voxel command: %s\n", argv[2]);
}