00001 #include "network.h"
00002 #include "cost_layer.h"
00003 #include "utils.h"
00004 #include "blas.h"
00005 #include "parser.h"
00006
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 #endif
00010
/*
 * A pair of float buffers handed from the batch builders in this file to the
 * training loop: `x` receives the one-hot rows for the current symbols and
 * `y` the one-hot rows for the symbols that follow them. Both buffers are
 * allocated by the builder and released by the caller with free().
 */
typedef struct {
    float *x;   /* network inputs  */
    float *y;   /* expected outputs */
} float_pair;
00015
/*
 * Read a whitespace-separated list of decimal integers from `filename`.
 *
 * Parsing stops at end of file or at the first item that is not an integer.
 * The number of values read is stored in *read and the values themselves are
 * returned in a heap array that the caller releases with free(). The
 * returned pointer is never NULL, even when the file holds no tokens
 * (*read is 0 in that case).
 *
 * The process is terminated with a message on stderr if the file cannot be
 * opened or memory cannot be allocated.
 */
int *read_tokenized_data(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;

    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }

    int *d = calloc(size, sizeof *d);
    if(!d){
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(fp);
        exit(EXIT_FAILURE);
    }

    int n;
    while(fscanf(fp, "%d", &n) == 1){
        if(count == size){
            /* Grow through a temporary so the old buffer is not lost on failure. */
            size *= 2;
            int *grown = realloc(d, size * sizeof *d);
            if(!grown){
                fprintf(stderr, "Out of memory reading %s\n", filename);
                free(d);
                fclose(fp);
                exit(EXIT_FAILURE);
            }
            d = grown;
        }
        d[count++] = n;
    }
    fclose(fp);

    /* Trim to the exact length; skip when empty because realloc(p, 0) may
     * free the buffer and return NULL. A failed shrink keeps the old buffer. */
    if(count > 0){
        int *fit = realloc(d, count * sizeof *d);
        if(fit) d = fit;
    }

    *read = count;
    return d;
}
00038
/*
 * Read `filename` line by line with fgetl() and return the lines as an array
 * of strings, one entry per line, in file order.
 *
 * The number of lines is stored in *read. The returned array is heap
 * allocated and never NULL, even for an empty file (*read is 0 then). The
 * entries are the pointers returned by fgetl(), stored unchanged.
 *
 * The process is terminated with a message on stderr if the file cannot be
 * opened or memory cannot be allocated.
 */
char **read_tokens(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;

    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }

    char **d = calloc(size, sizeof *d);
    if(!d){
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(fp);
        exit(EXIT_FAILURE);
    }

    char *line;
    while((line = fgetl(fp)) != 0){
        if(count == size){
            /* Grow through a temporary so the old array is not lost on failure. */
            size *= 2;
            char **grown = realloc(d, size * sizeof *d);
            if(!grown){
                fprintf(stderr, "Out of memory reading %s\n", filename);
                free(d);
                fclose(fp);
                exit(EXIT_FAILURE);
            }
            d = grown;
        }
        d[count++] = line;
    }
    fclose(fp);

    /* Trim to the exact length; skip when empty because realloc(p, 0) may
     * free the array and return NULL. A failed shrink keeps the old array. */
    if(count > 0){
        char **fit = realloc(d, count * sizeof *d);
        if(fit) d = fit;
    }

    *read = count;
    return d;
}
00059
00060 float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
00061 {
00062 float *x = calloc(batch * steps * characters, sizeof(float));
00063 float *y = calloc(batch * steps * characters, sizeof(float));
00064 int i,j;
00065 for(i = 0; i < batch; ++i){
00066 for(j = 0; j < steps; ++j){
00067 int curr = tokens[(offsets[i])%len];
00068 int next = tokens[(offsets[i] + 1)%len];
00069
00070 x[(j*batch + i)*characters + curr] = 1;
00071 y[(j*batch + i)*characters + next] = 1;
00072
00073 offsets[i] = (offsets[i] + 1) % len;
00074
00075 if(curr >= characters || curr < 0 || next >= characters || next < 0){
00076 error("Bad char");
00077 }
00078 }
00079 }
00080 float_pair p;
00081 p.x = x;
00082 p.y = y;
00083 return p;
00084 }
00085
00086 float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
00087 {
00088 float *x = calloc(batch * steps * characters, sizeof(float));
00089 float *y = calloc(batch * steps * characters, sizeof(float));
00090 int i,j;
00091 for(i = 0; i < batch; ++i){
00092 for(j = 0; j < steps; ++j){
00093 unsigned char curr = text[(offsets[i])%len];
00094 unsigned char next = text[(offsets[i] + 1)%len];
00095
00096 x[(j*batch + i)*characters + curr] = 1;
00097 y[(j*batch + i)*characters + next] = 1;
00098
00099 offsets[i] = (offsets[i] + 1) % len;
00100
00101 if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
00102
00103
00104
00105
00106 error("Bad char");
00107 }
00108 }
00109 }
00110 float_pair p;
00111 p.x = x;
00112 p.y = y;
00113 return p;
00114 }
00115
00116 void reset_rnn_state(network net, int b)
00117 {
00118 int i;
00119 for (i = 0; i < net.n; ++i) {
00120 #ifdef GPU
00121 layer l = net.layers[i];
00122 if(l.state_gpu){
00123 fill_ongpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
00124 }
00125 #endif
00126 }
00127 }
00128
00129 void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
00130 {
00131 srand(time(0));
00132 unsigned char *text = 0;
00133 int *tokens = 0;
00134 size_t size;
00135 if(tokenized){
00136 tokens = read_tokenized_data(filename, &size);
00137 } else {
00138 FILE *fp = fopen(filename, "rb");
00139
00140 fseek(fp, 0, SEEK_END);
00141 size = ftell(fp);
00142 fseek(fp, 0, SEEK_SET);
00143
00144 text = calloc(size+1, sizeof(char));
00145 fread(text, 1, size, fp);
00146 fclose(fp);
00147 }
00148
00149 char *backup_directory = "/home/pjreddie/backup/";
00150 char *base = basecfg(cfgfile);
00151 fprintf(stderr, "%s\n", base);
00152 float avg_loss = -1;
00153 network net = parse_network_cfg(cfgfile);
00154 if(weightfile){
00155 load_weights(&net, weightfile);
00156 }
00157
00158 int inputs = get_network_input_size(net);
00159 fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00160 int batch = net.batch;
00161 int steps = net.time_steps;
00162 if(clear) *net.seen = 0;
00163 int i = (*net.seen)/net.batch;
00164
00165 int streams = batch/steps;
00166 size_t *offsets = calloc(streams, sizeof(size_t));
00167 int j;
00168 for(j = 0; j < streams; ++j){
00169 offsets[j] = rand_size_t()%size;
00170 }
00171
00172 clock_t time;
00173 while(get_current_batch(net) < net.max_batches){
00174 i += 1;
00175 time=clock();
00176 float_pair p;
00177 if(tokenized){
00178 p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
00179 }else{
00180 p = get_rnn_data(text, offsets, inputs, size, streams, steps);
00181 }
00182
00183 float loss = train_network_datum(net, p.x, p.y) / (batch);
00184 free(p.x);
00185 free(p.y);
00186 if (avg_loss < 0) avg_loss = loss;
00187 avg_loss = avg_loss*.9 + loss*.1;
00188
00189 int chars = get_current_batch(net)*batch;
00190 fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
00191
00192 for(j = 0; j < streams; ++j){
00193
00194 if(rand()%10 == 0){
00195
00196 offsets[j] = rand_size_t()%size;
00197 reset_rnn_state(net, j);
00198 }
00199 }
00200
00201 if(i%1000==0){
00202 char buff[256];
00203 sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
00204 save_weights(net, buff);
00205 }
00206 if(i%10==0){
00207 char buff[256];
00208 sprintf(buff, "%s/%s.backup", backup_directory, base);
00209 save_weights(net, buff);
00210 }
00211 }
00212 char buff[256];
00213 sprintf(buff, "%s/%s_final.weights", backup_directory, base);
00214 save_weights(net, buff);
00215 }
00216
/*
 * Write symbol `n` to stdout.
 *
 * Without a token table the symbol is printed as a single character. With a
 * token table, entry `n` of the table is printed followed by one space.
 */
void print_symbol(int n, char **tokens)
{
    if(!tokens){
        printf("%c", n);
        return;
    }
    printf("%s ", tokens[n]);
}
00224
00225 void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
00226 {
00227 char **tokens = 0;
00228 if(token_file){
00229 size_t n;
00230 tokens = read_tokens(token_file, &n);
00231 }
00232
00233 srand(rseed);
00234 char *base = basecfg(cfgfile);
00235 fprintf(stderr, "%s\n", base);
00236
00237 network net = parse_network_cfg(cfgfile);
00238 if(weightfile){
00239 load_weights(&net, weightfile);
00240 }
00241 int inputs = get_network_input_size(net);
00242
00243 int i, j;
00244 for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00245 int c = 0;
00246 int len = strlen(seed);
00247 float *input = calloc(inputs, sizeof(float));
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257 for(i = 0; i < len-1; ++i){
00258 c = seed[i];
00259 input[c] = 1;
00260 network_predict(net, input);
00261 input[c] = 0;
00262 print_symbol(c, tokens);
00263 }
00264 if(len) c = seed[len-1];
00265 print_symbol(c, tokens);
00266 for(i = 0; i < num; ++i){
00267 input[c] = 1;
00268 float *out = network_predict(net, input);
00269 input[c] = 0;
00270 for(j = 32; j < 127; ++j){
00271
00272 }
00273 for(j = 0; j < inputs; ++j){
00274 if (out[j] < .0001) out[j] = 0;
00275 }
00276 c = sample_array(out, inputs);
00277 print_symbol(c, tokens);
00278 }
00279 printf("\n");
00280 }
00281
00282 void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
00283 {
00284 char **tokens = 0;
00285 if(token_file){
00286 size_t n;
00287 tokens = read_tokens(token_file, &n);
00288 }
00289
00290 srand(rseed);
00291 char *base = basecfg(cfgfile);
00292 fprintf(stderr, "%s\n", base);
00293
00294 network net = parse_network_cfg(cfgfile);
00295 if(weightfile){
00296 load_weights(&net, weightfile);
00297 }
00298 int inputs = get_network_input_size(net);
00299
00300 int i, j;
00301 for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00302 int c = 0;
00303 float *input = calloc(inputs, sizeof(float));
00304 float *out = 0;
00305
00306 while((c = getc(stdin)) != EOF){
00307 input[c] = 1;
00308 out = network_predict(net, input);
00309 input[c] = 0;
00310 }
00311 for(i = 0; i < num; ++i){
00312 for(j = 0; j < inputs; ++j){
00313 if (out[j] < .0001) out[j] = 0;
00314 }
00315 int next = sample_array(out, inputs);
00316 if(c == '.' && next == '\n') break;
00317 c = next;
00318 print_symbol(c, tokens);
00319
00320 input[c] = 1;
00321 out = network_predict(net, input);
00322 input[c] = 0;
00323 }
00324 printf("\n");
00325 }
00326
00327 void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
00328 {
00329 char *base = basecfg(cfgfile);
00330 fprintf(stderr, "%s\n", base);
00331
00332 network net = parse_network_cfg(cfgfile);
00333 if(weightfile){
00334 load_weights(&net, weightfile);
00335 }
00336 int inputs = get_network_input_size(net);
00337
00338 int count = 0;
00339 int words = 1;
00340 int c;
00341 int len = strlen(seed);
00342 float *input = calloc(inputs, sizeof(float));
00343 int i;
00344 for(i = 0; i < len; ++i){
00345 c = seed[i];
00346 input[(int)c] = 1;
00347 network_predict(net, input);
00348 input[(int)c] = 0;
00349 }
00350 float sum = 0;
00351 c = getc(stdin);
00352 float log2 = log(2);
00353 int in = 0;
00354 while(c != EOF){
00355 int next = getc(stdin);
00356 if(next == EOF) break;
00357 if(next < 0 || next >= 255) error("Out of range character");
00358
00359 input[c] = 1;
00360 float *out = network_predict(net, input);
00361 input[c] = 0;
00362
00363 if(c == '.' && next == '\n') in = 0;
00364 if(!in) {
00365 if(c == '>' && next == '>'){
00366 in = 1;
00367 ++words;
00368 }
00369 c = next;
00370 continue;
00371 }
00372 ++count;
00373 sum += log(out[next])/log2;
00374 c = next;
00375 printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
00376 }
00377 }
00378
00379 void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
00380 {
00381 char *base = basecfg(cfgfile);
00382 fprintf(stderr, "%s\n", base);
00383
00384 network net = parse_network_cfg(cfgfile);
00385 if(weightfile){
00386 load_weights(&net, weightfile);
00387 }
00388 int inputs = get_network_input_size(net);
00389
00390 int count = 0;
00391 int words = 1;
00392 int c;
00393 int len = strlen(seed);
00394 float *input = calloc(inputs, sizeof(float));
00395 int i;
00396 for(i = 0; i < len; ++i){
00397 c = seed[i];
00398 input[(int)c] = 1;
00399 network_predict(net, input);
00400 input[(int)c] = 0;
00401 }
00402 float sum = 0;
00403 c = getc(stdin);
00404 float log2 = log(2);
00405 while(c != EOF){
00406 int next = getc(stdin);
00407 if(next == EOF) break;
00408 if(next < 0 || next >= 255) error("Out of range character");
00409 ++count;
00410 if(next == ' ' || next == '\n' || next == '\t') ++words;
00411 input[c] = 1;
00412 float *out = network_predict(net, input);
00413 input[c] = 0;
00414 sum += log(out[next])/log2;
00415 c = next;
00416 printf("%d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, pow(2, -sum/count), pow(2, -sum/words));
00417 }
00418 }
00419
00420 void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
00421 {
00422 char *base = basecfg(cfgfile);
00423 fprintf(stderr, "%s\n", base);
00424
00425 network net = parse_network_cfg(cfgfile);
00426 if(weightfile){
00427 load_weights(&net, weightfile);
00428 }
00429 int inputs = get_network_input_size(net);
00430
00431 int c;
00432 int seed_len = strlen(seed);
00433 float *input = calloc(inputs, sizeof(float));
00434 int i;
00435 char *line;
00436 while((line=fgetl(stdin)) != 0){
00437 reset_rnn_state(net, 0);
00438 for(i = 0; i < seed_len; ++i){
00439 c = seed[i];
00440 input[(int)c] = 1;
00441 network_predict(net, input);
00442 input[(int)c] = 0;
00443 }
00444 strip(line);
00445 int str_len = strlen(line);
00446 for(i = 0; i < str_len; ++i){
00447 c = line[i];
00448 input[(int)c] = 1;
00449 network_predict(net, input);
00450 input[(int)c] = 0;
00451 }
00452 c = ' ';
00453 input[(int)c] = 1;
00454 network_predict(net, input);
00455 input[(int)c] = 0;
00456
00457 layer l = net.layers[0];
00458 #ifdef GPU
00459 cuda_pull_array(l.output_gpu, l.output, l.outputs);
00460 #endif
00461 printf("%s", line);
00462 for(i = 0; i < l.outputs; ++i){
00463 printf(",%g", l.output[i]);
00464 }
00465 printf("\n");
00466 }
00467 }
00468
/*
 * Command-line entry point for the RNN tools.
 *
 * Expects argv[2] to name the sub-command, argv[3] the cfg file and an
 * optional argv[4] the weights file. Recognised sub-commands are train,
 * valid, validtactic, vec, generate and generatetactic. The options -file,
 * -seed, -len, -temp, -srand, -clear, -tokenized and -tokens are read from
 * the argument list and passed to whichever sub-command uses them.
 *
 * Prints a usage line to stderr when too few arguments are given or the
 * sub-command is not recognised.
 */
void run_char_rnn(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/validtactic/vec/generate/generatetactic] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }
    char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
    char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
    int len = find_int_arg(argc, argv, "-len", 1000);
    float temp = find_float_arg(argc, argv, "-temp", .7);
    int rseed = find_int_arg(argc, argv, "-srand", time(0));
    int clear = find_arg(argc, argv, "-clear");
    int tokenized = find_arg(argc, argv, "-tokenized");
    char *tokens = find_char_arg(argc, argv, "-tokens", 0);

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
    else {
        /* Report a mistyped sub-command instead of returning silently. */
        fprintf(stderr, "Unknown command: %s\n", argv[2]);
        fprintf(stderr, "usage: %s %s [train/valid/validtactic/vec/generate/generatetactic] [cfg] [weights (optional)]\n", argv[0], argv[1]);
    }
}