captcha.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004 
00005 void fix_data_captcha(data d, int mask)
00006 {
00007     matrix labels = d.y;
00008     int i, j;
00009     for(i = 0; i < d.y.rows; ++i){
00010         for(j = 0; j < d.y.cols; j += 2){
00011             if (mask){
00012                 if(!labels.vals[i][j]){
00013                     labels.vals[i][j] = SECRET_NUM;
00014                     labels.vals[i][j+1] = SECRET_NUM;
00015                 }else if(labels.vals[i][j+1]){
00016                     labels.vals[i][j] = 0;
00017                 }
00018             } else{
00019                 if (labels.vals[i][j]) {
00020                     labels.vals[i][j+1] = 0;
00021                 } else {
00022                     labels.vals[i][j+1] = 1;
00023                 }
00024             }
00025         }
00026     }
00027 }
00028 
00029 void train_captcha(char *cfgfile, char *weightfile)
00030 {
00031     srand(time(0));
00032     float avg_loss = -1;
00033     char *base = basecfg(cfgfile);
00034     printf("%s\n", base);
00035     network net = parse_network_cfg(cfgfile);
00036     if(weightfile){
00037         load_weights(&net, weightfile);
00038     }
00039     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00040     int imgs = 1024;
00041     int i = *net.seen/imgs;
00042     int solved = 1;
00043     list *plist;
00044     char **labels = get_labels("/data/captcha/reimgs.labels.list");
00045     if (solved){
00046         plist = get_paths("/data/captcha/reimgs.solved.list");
00047     }else{
00048         plist = get_paths("/data/captcha/reimgs.raw.list");
00049     }
00050     char **paths = (char **)list_to_array(plist);
00051     printf("%d\n", plist->size);
00052     clock_t time;
00053     pthread_t load_thread;
00054     data train;
00055     data buffer;
00056 
00057     load_args args = {0};
00058     args.w = net.w;
00059     args.h = net.h;
00060     args.paths = paths;
00061     args.classes = 26;
00062     args.n = imgs;
00063     args.m = plist->size;
00064     args.labels = labels;
00065     args.d = &buffer;
00066     args.type = CLASSIFICATION_DATA;
00067 
00068     load_thread = load_data_in_thread(args);
00069     while(1){
00070         ++i;
00071         time=clock();
00072         pthread_join(load_thread, 0);
00073         train = buffer;
00074         fix_data_captcha(train, solved);
00075 
00076         /*
00077            image im = float_to_image(256, 256, 3, train.X.vals[114]);
00078            show_image(im, "training");
00079            cvWaitKey(0);
00080          */
00081 
00082         load_thread = load_data_in_thread(args);
00083         printf("Loaded: %lf seconds\n", sec(clock()-time));
00084         time=clock();
00085         float loss = train_network(net, train);
00086         if(avg_loss == -1) avg_loss = loss;
00087         avg_loss = avg_loss*.9 + loss*.1;
00088         printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
00089         free_data(train);
00090         if(i%100==0){
00091             char buff[256];
00092             sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
00093             save_weights(net, buff);
00094         }
00095     }
00096 }
00097 
00098 void test_captcha(char *cfgfile, char *weightfile, char *filename)
00099 {
00100     network net = parse_network_cfg(cfgfile);
00101     if(weightfile){
00102         load_weights(&net, weightfile);
00103     }
00104     set_batch_network(&net, 1);
00105     srand(2222222);
00106     int i = 0;
00107     char **names = get_labels("/data/captcha/reimgs.labels.list");
00108     char buff[256];
00109     char *input = buff;
00110     int indexes[26];
00111     while(1){
00112         if(filename){
00113             strncpy(input, filename, 256);
00114         }else{
00115             //printf("Enter Image Path: ");
00116             //fflush(stdout);
00117             input = fgets(input, 256, stdin);
00118             if(!input) return;
00119             strtok(input, "\n");
00120         }
00121         image im = load_image_color(input, net.w, net.h);
00122         float *X = im.data;
00123         float *predictions = network_predict(net, X);
00124         top_predictions(net, 26, indexes);
00125         //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
00126         for(i = 0; i < 26; ++i){
00127             int index = indexes[i];
00128             if(i != 0) printf(", ");
00129             printf("%s %f", names[index], predictions[index]);
00130         }
00131         printf("\n");
00132         fflush(stdout);
00133         free_image(im);
00134         if (filename) break;
00135     }
00136 }
00137 
00138 void valid_captcha(char *cfgfile, char *weightfile, char *filename)
00139 {
00140     char **labels = get_labels("/data/captcha/reimgs.labels.list");
00141     network net = parse_network_cfg(cfgfile);
00142     if(weightfile){
00143         load_weights(&net, weightfile);
00144     }
00145     list *plist = get_paths("/data/captcha/reimgs.fg.list");
00146     char **paths = (char **)list_to_array(plist);
00147     int N = plist->size;
00148     int outputs = net.outputs;
00149 
00150     set_batch_network(&net, 1);
00151     srand(2222222);
00152     int i, j;
00153     for(i = 0; i < N; ++i){
00154         if (i%100 == 0) fprintf(stderr, "%d\n", i);
00155         image im = load_image_color(paths[i], net.w, net.h);
00156         float *X = im.data;
00157         float *predictions = network_predict(net, X);
00158         //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
00159         int truth = -1;
00160         for(j = 0; j < 13; ++j){
00161             if (strstr(paths[i], labels[j])) truth = j;
00162         }
00163         if (truth == -1){
00164             fprintf(stderr, "bad: %s\n", paths[i]);
00165             return;
00166         }
00167         printf("%d, ", truth);
00168         for(j = 0; j < outputs; ++j){
00169             if (j != 0) printf(", ");
00170             printf("%f", predictions[j]);
00171         }
00172         printf("\n");
00173         fflush(stdout);
00174         free_image(im);
00175         if (filename) break;
00176     }
00177 }
00178 
00179 /*
00180    void train_captcha(char *cfgfile, char *weightfile)
00181    {
00182    float avg_loss = -1;
00183    srand(time(0));
00184    char *base = basecfg(cfgfile);
00185    printf("%s\n", base);
00186    network net = parse_network_cfg(cfgfile);
00187    if(weightfile){
00188    load_weights(&net, weightfile);
00189    }
00190    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00191    int imgs = 1024;
00192    int i = net.seen/imgs;
00193    list *plist = get_paths("/data/captcha/train.auto5");
00194    char **paths = (char **)list_to_array(plist);
00195    printf("%d\n", plist->size);
00196    clock_t time;
00197    while(1){
00198    ++i;
00199    time=clock();
00200    data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
00201    translate_data_rows(train, -128);
00202    scale_data_rows(train, 1./128);
00203    printf("Loaded: %lf seconds\n", sec(clock()-time));
00204    time=clock();
00205    float loss = train_network(net, train);
00206    net.seen += imgs;
00207    if(avg_loss == -1) avg_loss = loss;
00208    avg_loss = avg_loss*.9 + loss*.1;
00209    printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
00210    free_data(train);
00211    if(i%10==0){
00212    char buff[256];
00213    sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
00214    save_weights(net, buff);
00215    }
00216    }
00217    }
00218 
00219    void decode_captcha(char *cfgfile, char *weightfile)
00220    {
00221    setbuf(stdout, NULL);
00222    srand(time(0));
00223    network net = parse_network_cfg(cfgfile);
00224    set_batch_network(&net, 1);
00225    if(weightfile){
00226    load_weights(&net, weightfile);
00227    }
00228    char filename[256];
00229    while(1){
00230    printf("Enter filename: ");
00231    fgets(filename, 256, stdin);
00232    strtok(filename, "\n");
00233    image im = load_image_color(filename, 300, 57);
00234    scale_image(im, 1./255.);
00235    float *X = im.data;
00236    float *predictions = network_predict(net, X);
00237    image out  = float_to_image(300, 57, 1, predictions);
00238    show_image(out, "decoded");
00239 #ifdef OPENCV
00240 cvWaitKey(0);
00241 #endif
00242 free_image(im);
00243 }
00244 }
00245 
00246 void encode_captcha(char *cfgfile, char *weightfile)
00247 {
00248 float avg_loss = -1;
00249 srand(time(0));
00250 char *base = basecfg(cfgfile);
00251 printf("%s\n", base);
00252 network net = parse_network_cfg(cfgfile);
00253 if(weightfile){
00254     load_weights(&net, weightfile);
00255 }
00256 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00257 int imgs = 1024;
00258 int i = net.seen/imgs;
00259 list *plist = get_paths("/data/captcha/encode.list");
00260 char **paths = (char **)list_to_array(plist);
00261 printf("%d\n", plist->size);
00262 clock_t time;
00263 while(1){
00264     ++i;
00265     time=clock();
00266     data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
00267     scale_data_rows(train, 1./255);
00268     printf("Loaded: %lf seconds\n", sec(clock()-time));
00269     time=clock();
00270     float loss = train_network(net, train);
00271     net.seen += imgs;
00272     if(avg_loss == -1) avg_loss = loss;
00273     avg_loss = avg_loss*.9 + loss*.1;
00274     printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
00275     free_matrix(train.X);
00276     if(i%100==0){
00277         char buff[256];
00278         sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
00279         save_weights(net, buff);
00280     }
00281 }
00282 }
00283 
00284 void validate_captcha(char *cfgfile, char *weightfile)
00285 {
00286     srand(time(0));
00287     char *base = basecfg(cfgfile);
00288     printf("%s\n", base);
00289     network net = parse_network_cfg(cfgfile);
00290     if(weightfile){
00291         load_weights(&net, weightfile);
00292     }
00293     int numchars = 37;
00294     list *plist = get_paths("/data/captcha/solved.hard");
00295     char **paths = (char **)list_to_array(plist);
00296     int imgs = plist->size;
00297     data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
00298     translate_data_rows(valid, -128);
00299     scale_data_rows(valid, 1./128);
00300     matrix pred = network_predict_data(net, valid);
00301     int i, k;
00302     int correct = 0;
00303     int total = 0;
00304     int accuracy = 0;
00305     for(i = 0; i < imgs; ++i){
00306         int allcorrect = 1;
00307         for(k = 0; k < 10; ++k){
00308             char truth = int_to_alphanum(max_index(valid.y.vals[i]+k*numchars, numchars));
00309             char prediction = int_to_alphanum(max_index(pred.vals[i]+k*numchars, numchars));
00310             if (truth != prediction) allcorrect=0;
00311             if (truth != '.' && truth == prediction) ++correct;
00312             if (truth != '.' || truth != prediction) ++total;
00313         }
00314         accuracy += allcorrect;
00315     }
00316     printf("Word Accuracy: %f, Char Accuracy %f\n", (float)accuracy/imgs, (float)correct/total);
00317     free_data(valid);
00318 }
00319 
00320 void test_captcha(char *cfgfile, char *weightfile)
00321 {
00322     setbuf(stdout, NULL);
00323     srand(time(0));
00324     //char *base = basecfg(cfgfile);
00325     //printf("%s\n", base);
00326     network net = parse_network_cfg(cfgfile);
00327     set_batch_network(&net, 1);
00328     if(weightfile){
00329         load_weights(&net, weightfile);
00330     }
00331     char filename[256];
00332     while(1){
00333         //printf("Enter filename: ");
00334         fgets(filename, 256, stdin);
00335         strtok(filename, "\n");
00336         image im = load_image_color(filename, 200, 60);
00337         translate_image(im, -128);
00338         scale_image(im, 1/128.);
00339         float *X = im.data;
00340         float *predictions = network_predict(net, X);
00341         print_letters(predictions, 10);
00342         free_image(im);
00343     }
00344 }
00345     */
/*
 * Command-line dispatcher for the captcha subcommands.
 *
 * Expected arguments:
 *   argv[2] - "train", "test" or "valid"
 *   argv[3] - network cfg file
 *   argv[4] - weights file (optional)
 *   argv[5] - image file name (optional; forwarded to test and valid only)
 *
 * Prints a usage line and returns when fewer than four arguments are given.
 * An unrecognised subcommand is reported on stderr instead of being ignored
 * silently. The older encode/decode/validate dispatch is disabled; those
 * implementations are commented out in this file.
 */
void run_captcha(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5]: 0;
    if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
    else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
    else fprintf(stderr, "%s %s: unknown command '%s' (expected train, test or valid)\n", argv[0], argv[1], argv[2]);
}
00364 


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:29