rnn.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "blas.h"
00005 #include "parser.h"
00006 
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 #endif
00010 
/* Input/target buffer pair for one training batch: x holds the network
 * inputs and y the matching targets.  In this file get_rnn_data and
 * get_rnn_token_data calloc both buffers and the training loop frees them. */
typedef struct {
    float *x, *y;
} float_pair;
00015 
/*
 * Read whitespace-separated integer token ids from a text file.
 *
 * filename: path to read.  The process exits with a message if it cannot be
 *           opened or memory runs out.
 * read:     out-parameter receiving the number of tokens stored.
 *
 * Reading stops at end of file or at the first item fscanf cannot parse as
 * an int.  Returns a heap array of *read ints that the caller must free().
 * When no tokens are read, *read is 0 and a valid (unused) buffer is still
 * returned; realloc to zero bytes is implementation-defined, so it is avoided.
 */
int *read_tokenized_data(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;
    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }
    int *d = calloc(size, sizeof(int));
    if(!d){
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(fp);
        exit(EXIT_FAILURE);
    }
    int n;
    while(fscanf(fp, "%d", &n) == 1){
        if(count == size){
            /* Grow geometrically; keep the old pointer until realloc succeeds. */
            size_t new_size = size*2;
            int *grown = realloc(d, new_size*sizeof(int));
            if(!grown){
                fprintf(stderr, "Out of memory reading %s\n", filename);
                free(d);
                fclose(fp);
                exit(EXIT_FAILURE);
            }
            d = grown;
            size = new_size;
        }
        d[count++] = n;
    }
    fclose(fp);
    if(count > 0 && count < size){
        /* Trim the slack; if shrinking fails the larger buffer is still valid. */
        int *shrunk = realloc(d, count*sizeof(int));
        if(shrunk) d = shrunk;
    }
    *read = count;
    return d;
}
00038 
/*
 * Read a vocabulary file, one token per line, via fgetl().
 *
 * filename: path to read.  The process exits with a message if it cannot be
 *           opened or memory runs out.
 * read:     out-parameter receiving the number of lines stored.
 *
 * Returns a heap array of *read line pointers exactly as fgetl() returned
 * them; the array itself is owned by the caller.  When the file is empty,
 * *read is 0 and a valid (unused) array is still returned, because realloc to
 * zero bytes is implementation-defined.
 */
char **read_tokens(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;
    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }
    char **d = calloc(size, sizeof(char *));
    if(!d){
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(fp);
        exit(EXIT_FAILURE);
    }
    char *line;
    while((line = fgetl(fp)) != 0){
        if(count == size){
            /* Grow geometrically; keep the old pointer until realloc succeeds. */
            size_t new_size = size*2;
            char **grown = realloc(d, new_size*sizeof(char *));
            if(!grown){
                fprintf(stderr, "Out of memory reading %s\n", filename);
                free(d);
                fclose(fp);
                exit(EXIT_FAILURE);
            }
            d = grown;
            size = new_size;
        }
        d[count++] = line;
    }
    fclose(fp);
    if(count > 0 && count < size){
        /* Trim the slack; if shrinking fails the larger array is still valid. */
        char **shrunk = realloc(d, count*sizeof(char *));
        if(shrunk) d = shrunk;
    }
    *read = count;
    return d;
}
00059 
00060 float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
00061 {
00062     float *x = calloc(batch * steps * characters, sizeof(float));
00063     float *y = calloc(batch * steps * characters, sizeof(float));
00064     int i,j;
00065     for(i = 0; i < batch; ++i){
00066         for(j = 0; j < steps; ++j){
00067             int curr = tokens[(offsets[i])%len];
00068             int next = tokens[(offsets[i] + 1)%len];
00069 
00070             x[(j*batch + i)*characters + curr] = 1;
00071             y[(j*batch + i)*characters + next] = 1;
00072 
00073             offsets[i] = (offsets[i] + 1) % len;
00074 
00075             if(curr >= characters || curr < 0 || next >= characters || next < 0){
00076                 error("Bad char");
00077             }
00078         }
00079     }
00080     float_pair p;
00081     p.x = x;
00082     p.y = y;
00083     return p;
00084 }
00085 
00086 float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
00087 {
00088     float *x = calloc(batch * steps * characters, sizeof(float));
00089     float *y = calloc(batch * steps * characters, sizeof(float));
00090     int i,j;
00091     for(i = 0; i < batch; ++i){
00092         for(j = 0; j < steps; ++j){
00093             unsigned char curr = text[(offsets[i])%len];
00094             unsigned char next = text[(offsets[i] + 1)%len];
00095 
00096             x[(j*batch + i)*characters + curr] = 1;
00097             y[(j*batch + i)*characters + next] = 1;
00098 
00099             offsets[i] = (offsets[i] + 1) % len;
00100 
00101             if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
00102                 /*text[(index+j+2)%len] = 0;
00103                 printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
00104                 printf("%s", text+index);
00105                 */
00106                 error("Bad char");
00107             }
00108         }
00109     }
00110     float_pair p;
00111     p.x = x;
00112     p.y = y;
00113     return p;
00114 }
00115 
00116 void reset_rnn_state(network net, int b)
00117 {
00118     int i;
00119     for (i = 0; i < net.n; ++i) {
00120         #ifdef GPU
00121         layer l = net.layers[i];
00122         if(l.state_gpu){
00123             fill_ongpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
00124         }
00125         #endif
00126     }
00127 }
00128 
00129 void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
00130 {
00131     srand(time(0));
00132     unsigned char *text = 0;
00133     int *tokens = 0;
00134     size_t size;
00135     if(tokenized){
00136         tokens = read_tokenized_data(filename, &size);
00137     } else {
00138         FILE *fp = fopen(filename, "rb");
00139 
00140         fseek(fp, 0, SEEK_END); 
00141         size = ftell(fp);
00142         fseek(fp, 0, SEEK_SET); 
00143 
00144         text = calloc(size+1, sizeof(char));
00145         fread(text, 1, size, fp);
00146         fclose(fp);
00147     }
00148 
00149     char *backup_directory = "/home/pjreddie/backup/";
00150     char *base = basecfg(cfgfile);
00151     fprintf(stderr, "%s\n", base);
00152     float avg_loss = -1;
00153     network net = parse_network_cfg(cfgfile);
00154     if(weightfile){
00155         load_weights(&net, weightfile);
00156     }
00157 
00158     int inputs = get_network_input_size(net);
00159     fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00160     int batch = net.batch;
00161     int steps = net.time_steps;
00162     if(clear) *net.seen = 0;
00163     int i = (*net.seen)/net.batch;
00164 
00165     int streams = batch/steps;
00166     size_t *offsets = calloc(streams, sizeof(size_t));
00167     int j;
00168     for(j = 0; j < streams; ++j){
00169         offsets[j] = rand_size_t()%size;
00170     }
00171 
00172     clock_t time;
00173     while(get_current_batch(net) < net.max_batches){
00174         i += 1;
00175         time=clock();
00176         float_pair p;
00177         if(tokenized){
00178             p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
00179         }else{
00180             p = get_rnn_data(text, offsets, inputs, size, streams, steps);
00181         }
00182 
00183         float loss = train_network_datum(net, p.x, p.y) / (batch);
00184         free(p.x);
00185         free(p.y);
00186         if (avg_loss < 0) avg_loss = loss;
00187         avg_loss = avg_loss*.9 + loss*.1;
00188 
00189         int chars = get_current_batch(net)*batch;
00190         fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
00191 
00192         for(j = 0; j < streams; ++j){
00193             //printf("%d\n", j);
00194             if(rand()%10 == 0){
00195                 //fprintf(stderr, "Reset\n");
00196                 offsets[j] = rand_size_t()%size;
00197                 reset_rnn_state(net, j);
00198             }
00199         }
00200 
00201         if(i%1000==0){
00202             char buff[256];
00203             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00204             save_weights(net, buff);
00205         }
00206         if(i%10==0){
00207             char buff[256];
00208             sprintf(buff, "%s/%s.backup", backup_directory, base);
00209             save_weights(net, buff);
00210         }
00211     }
00212     char buff[256];
00213     sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00214     save_weights(net, buff);
00215 }
00216 
/*
 * Print symbol `n` to stdout.
 * Without a vocabulary (tokens == NULL), `n` is printed as a single
 * character.  Otherwise tokens[n] is printed followed by one space.
 */
void print_symbol(int n, char **tokens)
{
    if (!tokens) {
        printf("%c", n);
        return;
    }
    printf("%s ", tokens[n]);
}
00224 
00225 void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
00226 {
00227     char **tokens = 0;
00228     if(token_file){
00229         size_t n;
00230         tokens = read_tokens(token_file, &n);
00231     }
00232 
00233     srand(rseed);
00234     char *base = basecfg(cfgfile);
00235     fprintf(stderr, "%s\n", base);
00236 
00237     network net = parse_network_cfg(cfgfile);
00238     if(weightfile){
00239         load_weights(&net, weightfile);
00240     }
00241     int inputs = get_network_input_size(net);
00242 
00243     int i, j;
00244     for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00245     int c = 0;
00246     int len = strlen(seed);
00247     float *input = calloc(inputs, sizeof(float));
00248 
00249     /*
00250        fill_cpu(inputs, 0, input, 1);
00251        for(i = 0; i < 10; ++i){
00252        network_predict(net, input);
00253        }
00254        fill_cpu(inputs, 0, input, 1);
00255      */
00256 
00257     for(i = 0; i < len-1; ++i){
00258         c = seed[i];
00259         input[c] = 1;
00260         network_predict(net, input);
00261         input[c] = 0;
00262         print_symbol(c, tokens);
00263     }
00264     if(len) c = seed[len-1];
00265     print_symbol(c, tokens);
00266     for(i = 0; i < num; ++i){
00267         input[c] = 1;
00268         float *out = network_predict(net, input);
00269         input[c] = 0;
00270         for(j = 32; j < 127; ++j){
00271             //printf("%d %c %f\n",j, j, out[j]);
00272         }
00273         for(j = 0; j < inputs; ++j){
00274             if (out[j] < .0001) out[j] = 0;
00275         }
00276         c = sample_array(out, inputs);
00277         print_symbol(c, tokens);
00278     }
00279     printf("\n");
00280 }
00281 
00282 void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
00283 {
00284     char **tokens = 0;
00285     if(token_file){
00286         size_t n;
00287         tokens = read_tokens(token_file, &n);
00288     }
00289 
00290     srand(rseed);
00291     char *base = basecfg(cfgfile);
00292     fprintf(stderr, "%s\n", base);
00293 
00294     network net = parse_network_cfg(cfgfile);
00295     if(weightfile){
00296         load_weights(&net, weightfile);
00297     }
00298     int inputs = get_network_input_size(net);
00299 
00300     int i, j;
00301     for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00302     int c = 0;
00303     float *input = calloc(inputs, sizeof(float));
00304     float *out = 0;
00305 
00306     while((c = getc(stdin)) != EOF){
00307         input[c] = 1;
00308         out = network_predict(net, input);
00309         input[c] = 0;
00310     }
00311     for(i = 0; i < num; ++i){
00312         for(j = 0; j < inputs; ++j){
00313             if (out[j] < .0001) out[j] = 0;
00314         }
00315         int next = sample_array(out, inputs);
00316         if(c == '.' && next == '\n') break;
00317         c = next;
00318         print_symbol(c, tokens);
00319 
00320         input[c] = 1;
00321         out = network_predict(net, input);
00322         input[c] = 0;
00323     }
00324     printf("\n");
00325 }
00326 
00327 void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
00328 {
00329     char *base = basecfg(cfgfile);
00330     fprintf(stderr, "%s\n", base);
00331 
00332     network net = parse_network_cfg(cfgfile);
00333     if(weightfile){
00334         load_weights(&net, weightfile);
00335     }
00336     int inputs = get_network_input_size(net);
00337 
00338     int count = 0;
00339     int words = 1;
00340     int c;
00341     int len = strlen(seed);
00342     float *input = calloc(inputs, sizeof(float));
00343     int i;
00344     for(i = 0; i < len; ++i){
00345         c = seed[i];
00346         input[(int)c] = 1;
00347         network_predict(net, input);
00348         input[(int)c] = 0;
00349     }
00350     float sum = 0;
00351     c = getc(stdin);
00352     float log2 = log(2);
00353     int in = 0;
00354     while(c != EOF){
00355         int next = getc(stdin);
00356         if(next == EOF) break;
00357         if(next < 0 || next >= 255) error("Out of range character");
00358 
00359         input[c] = 1;
00360         float *out = network_predict(net, input);
00361         input[c] = 0;
00362 
00363         if(c == '.' && next == '\n') in = 0;
00364         if(!in) {
00365             if(c == '>' && next == '>'){
00366                 in = 1;
00367                 ++words;
00368             }
00369             c = next;
00370             continue;
00371         }
00372         ++count;
00373         sum += log(out[next])/log2;
00374         c = next;
00375         printf("%d %d Perplexity: %4.4f    Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
00376     }
00377 }
00378 
00379 void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
00380 {
00381     char *base = basecfg(cfgfile);
00382     fprintf(stderr, "%s\n", base);
00383 
00384     network net = parse_network_cfg(cfgfile);
00385     if(weightfile){
00386         load_weights(&net, weightfile);
00387     }
00388     int inputs = get_network_input_size(net);
00389 
00390     int count = 0;
00391     int words = 1;
00392     int c;
00393     int len = strlen(seed);
00394     float *input = calloc(inputs, sizeof(float));
00395     int i;
00396     for(i = 0; i < len; ++i){
00397         c = seed[i];
00398         input[(int)c] = 1;
00399         network_predict(net, input);
00400         input[(int)c] = 0;
00401     }
00402     float sum = 0;
00403     c = getc(stdin);
00404     float log2 = log(2);
00405     while(c != EOF){
00406         int next = getc(stdin);
00407         if(next == EOF) break;
00408         if(next < 0 || next >= 255) error("Out of range character");
00409         ++count;
00410         if(next == ' ' || next == '\n' || next == '\t') ++words;
00411         input[c] = 1;
00412         float *out = network_predict(net, input);
00413         input[c] = 0;
00414         sum += log(out[next])/log2;
00415         c = next;
00416         printf("%d Perplexity: %4.4f    Word Perplexity: %4.4f\n", count, pow(2, -sum/count), pow(2, -sum/words));
00417     }
00418 }
00419 
00420 void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
00421 {
00422     char *base = basecfg(cfgfile);
00423     fprintf(stderr, "%s\n", base);
00424 
00425     network net = parse_network_cfg(cfgfile);
00426     if(weightfile){
00427         load_weights(&net, weightfile);
00428     }
00429     int inputs = get_network_input_size(net);
00430 
00431     int c;
00432     int seed_len = strlen(seed);
00433     float *input = calloc(inputs, sizeof(float));
00434     int i;
00435     char *line;
00436     while((line=fgetl(stdin)) != 0){
00437         reset_rnn_state(net, 0);
00438         for(i = 0; i < seed_len; ++i){
00439             c = seed[i];
00440             input[(int)c] = 1;
00441             network_predict(net, input);
00442             input[(int)c] = 0;
00443         }
00444         strip(line);
00445         int str_len = strlen(line);
00446         for(i = 0; i < str_len; ++i){
00447             c = line[i];
00448             input[(int)c] = 1;
00449             network_predict(net, input);
00450             input[(int)c] = 0;
00451         }
00452         c = ' ';
00453         input[(int)c] = 1;
00454         network_predict(net, input);
00455         input[(int)c] = 0;
00456 
00457         layer l = net.layers[0];
00458         #ifdef GPU
00459         cuda_pull_array(l.output_gpu, l.output, l.outputs);
00460         #endif
00461         printf("%s", line);
00462         for(i = 0; i < l.outputs; ++i){
00463             printf(",%g", l.output[i]);
00464         }
00465         printf("\n");
00466     }
00467 }
00468 
/*
 * Command-line entry point for the char-RNN tools.
 *
 * argv[2] selects the mode:
 *   - train
 *   - valid
 *   - validtactic
 *   - vec
 *   - generate
 *   - generatetactic
 * argv[3] is the cfg file and the optional argv[4] the weights file.
 * Recognised flags, with defaults:
 *   - -file (data/shakespeare.txt)
 *   - -seed ("\n\n")
 *   - -len (1000)
 *   - -temp (.7)
 *   - -srand (current time)
 *   - -clear
 *   - -tokenized
 *   - -tokens
 * An unknown mode is now reported instead of being silently ignored.
 */
void run_char_rnn(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/validtactic/vec/generate/generatetactic] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }
    char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
    char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
    int len = find_int_arg(argc, argv, "-len", 1000);
    float temp = find_float_arg(argc, argv, "-temp", .7);
    int rseed = find_int_arg(argc, argv, "-srand", time(0));
    int clear = find_arg(argc, argv, "-clear");
    int tokenized = find_arg(argc, argv, "-tokenized");
    char *tokens = find_char_arg(argc, argv, "-tokens", 0);

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
    else fprintf(stderr, "Unknown mode: %s\n", argv[2]);
}


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30