00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "parser.h"
00005 #include "blas.h"
00006
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 image get_image_from_stream(CvCapture *cap);
00010 image ipl_to_image(IplImage* src);
00011
00012 void reconstruct_picture(network net, float *features, image recon, image update, float rate, float momentum, float lambda, int smooth_size, int iters);
00013
00014
/*
 * Two float pointers handed around together.
 *
 * In this file get_rnn_vid_data() sets x to the start of a feature buffer and
 * y to a later offset inside that same buffer, so only x owns the allocation.
 */
typedef struct {
    float *x, *y;
} float_pair;
00019
00020 float_pair get_rnn_vid_data(network net, char **files, int n, int batch, int steps)
00021 {
00022 int b;
00023 assert(net.batch == steps + 1);
00024 image out_im = get_network_image(net);
00025 int output_size = out_im.w*out_im.h*out_im.c;
00026 printf("%d %d %d\n", out_im.w, out_im.h, out_im.c);
00027 float *feats = calloc(net.batch*batch*output_size, sizeof(float));
00028 for(b = 0; b < batch; ++b){
00029 int input_size = net.w*net.h*net.c;
00030 float *input = calloc(input_size*net.batch, sizeof(float));
00031 char *filename = files[rand()%n];
00032 CvCapture *cap = cvCaptureFromFile(filename);
00033 int frames = cvGetCaptureProperty(cap, CV_CAP_PROP_FRAME_COUNT);
00034 int index = rand() % (frames - steps - 2);
00035 if (frames < (steps + 4)){
00036 --b;
00037 free(input);
00038 continue;
00039 }
00040
00041 printf("frames: %d, index: %d\n", frames, index);
00042 cvSetCaptureProperty(cap, CV_CAP_PROP_POS_FRAMES, index);
00043
00044 int i;
00045 for(i = 0; i < net.batch; ++i){
00046 IplImage* src = cvQueryFrame(cap);
00047 image im = ipl_to_image(src);
00048 rgbgr_image(im);
00049 image re = resize_image(im, net.w, net.h);
00050
00051
00052 memcpy(input + i*input_size, re.data, input_size*sizeof(float));
00053 free_image(im);
00054 free_image(re);
00055 }
00056 float *output = network_predict(net, input);
00057
00058 free(input);
00059
00060 for(i = 0; i < net.batch; ++i){
00061 memcpy(feats + (b + i*batch)*output_size, output + i*output_size, output_size*sizeof(float));
00062 }
00063
00064 cvReleaseCapture(&cap);
00065 }
00066
00067
00068 float_pair p = {0};
00069 p.x = feats;
00070 p.y = feats + output_size*batch;
00071
00072 return p;
00073 }
00074
00075
00076 void train_vid_rnn(char *cfgfile, char *weightfile)
00077 {
00078 char *train_videos = "data/vid/train.txt";
00079 char *backup_directory = "/home/pjreddie/backup/";
00080 srand(time(0));
00081 char *base = basecfg(cfgfile);
00082 printf("%s\n", base);
00083 float avg_loss = -1;
00084 network net = parse_network_cfg(cfgfile);
00085 if(weightfile){
00086 load_weights(&net, weightfile);
00087 }
00088 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00089 int imgs = net.batch*net.subdivisions;
00090 int i = *net.seen/imgs;
00091
00092 list *plist = get_paths(train_videos);
00093 int N = plist->size;
00094 char **paths = (char **)list_to_array(plist);
00095 clock_t time;
00096 int steps = net.time_steps;
00097 int batch = net.batch / net.time_steps;
00098
00099 network extractor = parse_network_cfg("cfg/extractor.cfg");
00100 load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv");
00101
00102 while(get_current_batch(net) < net.max_batches){
00103 i += 1;
00104 time=clock();
00105 float_pair p = get_rnn_vid_data(extractor, paths, N, batch, steps);
00106
00107 float loss = train_network_datum(net, p.x, p.y) / (net.batch);
00108
00109
00110 free(p.x);
00111 if (avg_loss < 0) avg_loss = loss;
00112 avg_loss = avg_loss*.9 + loss*.1;
00113
00114 fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
00115 if(i%100==0){
00116 char buff[256];
00117 sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00118 save_weights(net, buff);
00119 }
00120 if(i%10==0){
00121 char buff[256];
00122 sprintf(buff, "%s/%s.backup", backup_directory, base);
00123 save_weights(net, buff);
00124 }
00125 }
00126 char buff[256];
00127 sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00128 save_weights(net, buff);
00129 }
00130
00131
00132 image save_reconstruction(network net, image *init, float *feat, char *name, int i)
00133 {
00134 image recon;
00135 if (init) {
00136 recon = copy_image(*init);
00137 } else {
00138 recon = make_random_image(net.w, net.h, 3);
00139 }
00140
00141 image update = make_image(net.w, net.h, 3);
00142 reconstruct_picture(net, feat, recon, update, .01, .9, .1, 2, 50);
00143 char buff[256];
00144 sprintf(buff, "%s%d", name, i);
00145 save_image(recon, buff);
00146 free_image(update);
00147 return recon;
00148 }
00149
00150 void generate_vid_rnn(char *cfgfile, char *weightfile)
00151 {
00152 network extractor = parse_network_cfg("cfg/extractor.recon.cfg");
00153 load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv");
00154
00155 network net = parse_network_cfg(cfgfile);
00156 if(weightfile){
00157 load_weights(&net, weightfile);
00158 }
00159 set_batch_network(&extractor, 1);
00160 set_batch_network(&net, 1);
00161
00162 int i;
00163 CvCapture *cap = cvCaptureFromFile("/extra/vid/ILSVRC2015/Data/VID/snippets/val/ILSVRC2015_val_00007030.mp4");
00164 float *feat;
00165 float *next;
00166 image last;
00167 for(i = 0; i < 25; ++i){
00168 image im = get_image_from_stream(cap);
00169 image re = resize_image(im, extractor.w, extractor.h);
00170 feat = network_predict(extractor, re.data);
00171 if(i > 0){
00172 printf("%f %f\n", mean_array(feat, 14*14*512), variance_array(feat, 14*14*512));
00173 printf("%f %f\n", mean_array(next, 14*14*512), variance_array(next, 14*14*512));
00174 printf("%f\n", mse_array(feat, 14*14*512));
00175 axpy_cpu(14*14*512, -1, feat, 1, next, 1);
00176 printf("%f\n", mse_array(next, 14*14*512));
00177 }
00178 next = network_predict(net, feat);
00179
00180 free_image(im);
00181
00182 free_image(save_reconstruction(extractor, 0, feat, "feat", i));
00183 free_image(save_reconstruction(extractor, 0, next, "next", i));
00184 if (i==24) last = copy_image(re);
00185 free_image(re);
00186 }
00187 for(i = 0; i < 30; ++i){
00188 next = network_predict(net, next);
00189 image new = save_reconstruction(extractor, &last, next, "new", i);
00190 free_image(last);
00191 last = new;
00192 }
00193 }
00194
/*
 * Command-line entry point for the video RNN demo.
 *
 * Expects: argv[0] argv[1] <command> <cfg> [weights]
 *   command  "train" or "generate"
 *   cfg      network configuration file
 *   weights  optional weight file; 0 is passed on when absent
 *
 * Prints a usage line when fewer than four arguments are given and an error
 * when the command is not recognised.
 */
void run_vid_rnn(int argc, char **argv)
{
    if(argc < 4){
        /* Lists the commands actually dispatched below. */
        fprintf(stderr, "usage: %s %s [train/generate] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;

    if(0==strcmp(argv[2], "train")){
        train_vid_rnn(cfg, weights);
    } else if(0==strcmp(argv[2], "generate")){
        generate_vid_rnn(cfg, weights);
    } else {
        fprintf(stderr, "%s %s: unknown command '%s'\n", argv[0], argv[1], argv[2]);
    }
}
00208 #else
/* Built without OPENCV: the video RNN demo is unavailable, so this is a no-op. */
void run_vid_rnn(int argc, char **argv)
{
    (void)argc;
    (void)argv;
}
00210 #endif
00211