00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004
00005 void fix_data_captcha(data d, int mask)
00006 {
00007 matrix labels = d.y;
00008 int i, j;
00009 for(i = 0; i < d.y.rows; ++i){
00010 for(j = 0; j < d.y.cols; j += 2){
00011 if (mask){
00012 if(!labels.vals[i][j]){
00013 labels.vals[i][j] = SECRET_NUM;
00014 labels.vals[i][j+1] = SECRET_NUM;
00015 }else if(labels.vals[i][j+1]){
00016 labels.vals[i][j] = 0;
00017 }
00018 } else{
00019 if (labels.vals[i][j]) {
00020 labels.vals[i][j+1] = 0;
00021 } else {
00022 labels.vals[i][j+1] = 1;
00023 }
00024 }
00025 }
00026 }
00027 }
00028
00029 void train_captcha(char *cfgfile, char *weightfile)
00030 {
00031 srand(time(0));
00032 float avg_loss = -1;
00033 char *base = basecfg(cfgfile);
00034 printf("%s\n", base);
00035 network net = parse_network_cfg(cfgfile);
00036 if(weightfile){
00037 load_weights(&net, weightfile);
00038 }
00039 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00040 int imgs = 1024;
00041 int i = *net.seen/imgs;
00042 int solved = 1;
00043 list *plist;
00044 char **labels = get_labels("/data/captcha/reimgs.labels.list");
00045 if (solved){
00046 plist = get_paths("/data/captcha/reimgs.solved.list");
00047 }else{
00048 plist = get_paths("/data/captcha/reimgs.raw.list");
00049 }
00050 char **paths = (char **)list_to_array(plist);
00051 printf("%d\n", plist->size);
00052 clock_t time;
00053 pthread_t load_thread;
00054 data train;
00055 data buffer;
00056
00057 load_args args = {0};
00058 args.w = net.w;
00059 args.h = net.h;
00060 args.paths = paths;
00061 args.classes = 26;
00062 args.n = imgs;
00063 args.m = plist->size;
00064 args.labels = labels;
00065 args.d = &buffer;
00066 args.type = CLASSIFICATION_DATA;
00067
00068 load_thread = load_data_in_thread(args);
00069 while(1){
00070 ++i;
00071 time=clock();
00072 pthread_join(load_thread, 0);
00073 train = buffer;
00074 fix_data_captcha(train, solved);
00075
00076
00077
00078
00079
00080
00081
00082 load_thread = load_data_in_thread(args);
00083 printf("Loaded: %lf seconds\n", sec(clock()-time));
00084 time=clock();
00085 float loss = train_network(net, train);
00086 if(avg_loss == -1) avg_loss = loss;
00087 avg_loss = avg_loss*.9 + loss*.1;
00088 printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
00089 free_data(train);
00090 if(i%100==0){
00091 char buff[256];
00092 sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
00093 save_weights(net, buff);
00094 }
00095 }
00096 }
00097
00098 void test_captcha(char *cfgfile, char *weightfile, char *filename)
00099 {
00100 network net = parse_network_cfg(cfgfile);
00101 if(weightfile){
00102 load_weights(&net, weightfile);
00103 }
00104 set_batch_network(&net, 1);
00105 srand(2222222);
00106 int i = 0;
00107 char **names = get_labels("/data/captcha/reimgs.labels.list");
00108 char buff[256];
00109 char *input = buff;
00110 int indexes[26];
00111 while(1){
00112 if(filename){
00113 strncpy(input, filename, 256);
00114 }else{
00115
00116
00117 input = fgets(input, 256, stdin);
00118 if(!input) return;
00119 strtok(input, "\n");
00120 }
00121 image im = load_image_color(input, net.w, net.h);
00122 float *X = im.data;
00123 float *predictions = network_predict(net, X);
00124 top_predictions(net, 26, indexes);
00125
00126 for(i = 0; i < 26; ++i){
00127 int index = indexes[i];
00128 if(i != 0) printf(", ");
00129 printf("%s %f", names[index], predictions[index]);
00130 }
00131 printf("\n");
00132 fflush(stdout);
00133 free_image(im);
00134 if (filename) break;
00135 }
00136 }
00137
00138 void valid_captcha(char *cfgfile, char *weightfile, char *filename)
00139 {
00140 char **labels = get_labels("/data/captcha/reimgs.labels.list");
00141 network net = parse_network_cfg(cfgfile);
00142 if(weightfile){
00143 load_weights(&net, weightfile);
00144 }
00145 list *plist = get_paths("/data/captcha/reimgs.fg.list");
00146 char **paths = (char **)list_to_array(plist);
00147 int N = plist->size;
00148 int outputs = net.outputs;
00149
00150 set_batch_network(&net, 1);
00151 srand(2222222);
00152 int i, j;
00153 for(i = 0; i < N; ++i){
00154 if (i%100 == 0) fprintf(stderr, "%d\n", i);
00155 image im = load_image_color(paths[i], net.w, net.h);
00156 float *X = im.data;
00157 float *predictions = network_predict(net, X);
00158
00159 int truth = -1;
00160 for(j = 0; j < 13; ++j){
00161 if (strstr(paths[i], labels[j])) truth = j;
00162 }
00163 if (truth == -1){
00164 fprintf(stderr, "bad: %s\n", paths[i]);
00165 return;
00166 }
00167 printf("%d, ", truth);
00168 for(j = 0; j < outputs; ++j){
00169 if (j != 0) printf(", ");
00170 printf("%f", predictions[j]);
00171 }
00172 printf("\n");
00173 fflush(stdout);
00174 free_image(im);
00175 if (filename) break;
00176 }
00177 }
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
/*
 * Command-line dispatcher for the captcha sub-commands.
 *
 * Expected argv layout:
 *   argv[2]  mode: "train", "test" or "valid"
 *   argv[3]  network cfg file
 *   argv[4]  weights file (optional)
 *   argv[5]  image file name (optional; passed to test/valid only)
 *
 * Prints a usage line to stderr when fewer than four arguments are given,
 * and reports an unrecognized mode on stderr instead of silently doing
 * nothing.
 */
void run_captcha(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *mode = argv[2];
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;

    if(0==strcmp(mode, "train")){
        train_captcha(cfg, weights);
    }else if(0==strcmp(mode, "test")){
        test_captcha(cfg, weights, filename);
    }else if(0==strcmp(mode, "valid")){
        valid_captcha(cfg, weights, filename);
    }else{
        fprintf(stderr, "unknown captcha mode: %s\n", mode);
    }
}
00364