00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004 #include "option_list.h"
00005 #include "blas.h"
00006
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 #endif
00010
/* Display flags read by the board-printing code in this file. */
int inverted = 1;   /* non-zero: rows are labelled 19..1 from the top, otherwise 1..19 */
int noi = 1;        /* non-zero: column letters after 'H' are shifted by one, so 'I' is skipped */

/* Number of top-ranked candidate moves that are tracked and displayed. */
static const int nind = 5;
00014
/* A list of move records as produced by load_go_moves(). */
typedef struct {
    char **data;    /* array of n heap-allocated records */
    int n;          /* number of records in data */
} moves;
00019
/*
 * Read one fixed-size record (94 bytes) from fp.
 *
 * Returns a freshly malloc'd 94-byte buffer owned by the caller, or 0 when
 * fp is NULL, the stream is already at end-of-file, the allocation fails, or
 * fewer than 94 bytes could be read. The buffer holds raw bytes and is not
 * NUL-terminated.
 */
char *fgetgo(FILE *fp)
{
    const size_t size = 94;

    /* A NULL stream (e.g. a failed fopen in the caller) is treated as "no data". */
    if(!fp || feof(fp)) return 0;

    char *line = malloc(size*sizeof(char));
    if(!line) return 0;

    if(fread(line, sizeof(char), size, fp) != size){
        /* Short read: discard the partial record. */
        free(line);
        return 0;
    }
    return line;
}
00032
00033 moves load_go_moves(char *filename)
00034 {
00035 moves m;
00036 m.n = 128;
00037 m.data = calloc(128, sizeof(char*));
00038 FILE *fp = fopen(filename, "rb");
00039 int count = 0;
00040 char *line = 0;
00041 while((line = fgetgo(fp))){
00042 if(count >= m.n){
00043 m.n *= 2;
00044 m.data = realloc(m.data, m.n*sizeof(char*));
00045 }
00046 m.data[count] = line;
00047 ++count;
00048 }
00049 printf("%d\n", count);
00050 m.n = count;
00051 m.data = realloc(m.data, count*sizeof(char*));
00052 return m;
00053 }
00054
/*
 * Unpack a 19x19 position from its packed form into `board`.
 *
 * Each byte of `s` holds four points, two bits per point, lowest bits first:
 * the even bit marks a "me" stone (stored as 1), the odd bit a "you" stone
 * (stored as -1); neither bit means empty (0). If both bits are set the
 * point is stored as 1. Reads s[0..90] and writes board[0..360].
 */
void string_to_board(char *s, float *board)
{
    int cell;
    for(cell = 0; cell < 19*19; ++cell){
        char packed = s[cell/4];
        int bits = (packed >> (2*(cell%4))) & 3;

        float value = 0;
        if (bits & 1) value = 1;
        else if (bits & 2) value = -1;
        board[cell] = value;
    }
}
00073
/*
 * Pack a 19x19 board into 91 bytes, four points per byte.
 *
 * For point k, byte k/4 receives bit 2*(k%4) when board[k] == 1 and bit
 * 2*(k%4)+1 when board[k] == -1; any other value leaves both bits clear.
 * `s` must have room for 91 bytes; it is zeroed first.
 */
void board_to_string(char *s, float *board)
{
    int cell;
    memset(s, 0, (19*19/4+1)*sizeof(char));
    for(cell = 0; cell < 19*19; ++cell){
        char *byte = s + cell/4;
        int shift = 2*(cell%4);
        float stone = board[cell];

        if (stone == 1) *byte = *byte | (1<<shift);
        if (stone == -1) *byte = *byte | (1<<(shift + 1));
    }
}
00090
00091 void random_go_moves(moves m, float *boards, float *labels, int n)
00092 {
00093 int i;
00094 memset(labels, 0, 19*19*n*sizeof(float));
00095 for(i = 0; i < n; ++i){
00096 char *b = m.data[rand()%m.n];
00097 int row = b[0];
00098 int col = b[1];
00099 labels[col + 19*(row + i*19)] = 1;
00100 string_to_board(b+2, boards+i*19*19);
00101 boards[col + 19*(row + i*19)] = 0;
00102
00103 int flip = rand()%2;
00104 int rotate = rand()%4;
00105 image in = float_to_image(19, 19, 1, boards+i*19*19);
00106 image out = float_to_image(19, 19, 1, labels+i*19*19);
00107 if(flip){
00108 flip_image(in);
00109 flip_image(out);
00110 }
00111 rotate_image_cw(in, rotate);
00112 rotate_image_cw(out, rotate);
00113 }
00114 }
00115
00116
00117 void train_go(char *cfgfile, char *weightfile)
00118 {
00119 srand(time(0));
00120 float avg_loss = -1;
00121 char *base = basecfg(cfgfile);
00122 printf("%s\n", base);
00123 network net = parse_network_cfg(cfgfile);
00124 if(weightfile){
00125 load_weights(&net, weightfile);
00126 }
00127 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00128
00129 char *backup_directory = "/home/pjreddie/backup/";
00130
00131 char buff[256];
00132 float *board = calloc(19*19*net.batch, sizeof(float));
00133 float *move = calloc(19*19*net.batch, sizeof(float));
00134 moves m = load_go_moves("/home/pjreddie/backup/go.train");
00135
00136
00137 int N = m.n;
00138 int epoch = (*net.seen)/N;
00139 while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
00140 clock_t time=clock();
00141
00142 random_go_moves(m, board, move, net.batch);
00143 float loss = train_network_datum(net, board, move) / net.batch;
00144 if(avg_loss == -1) avg_loss = loss;
00145 avg_loss = avg_loss*.95 + loss*.05;
00146 printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
00147 if(*net.seen/N > epoch){
00148 epoch = *net.seen/N;
00149 char buff[256];
00150 sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
00151 save_weights(net, buff);
00152
00153 }
00154 if(get_current_batch(net)%100 == 0){
00155 char buff[256];
00156 sprintf(buff, "%s/%s.backup",backup_directory,base);
00157 save_weights(net, buff);
00158 }
00159 if(get_current_batch(net)%10000 == 0){
00160 char buff[256];
00161 sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
00162 save_weights(net, buff);
00163 }
00164 }
00165 sprintf(buff, "%s/%s.weights", backup_directory, base);
00166 save_weights(net, buff);
00167
00168 free_network(net);
00169 free(base);
00170 }
00171
/*
 * Flood-fill helper: starting at (row, col), add 1 to lib[] for every
 * not-yet-visited point connected through stones equal to `side`, marking
 * each one in visited[]. Off-board points, points holding a different value,
 * and already-visited points stop the fill.
 */
void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
{
    static const int step_row[4] = { 1, -1, 0, 0 };
    static const int step_col[4] = { 0, 0, 1, -1 };
    int k;

    int on_board = (row >= 0 && row <= 18 && col >= 0 && col <= 18);
    if (!on_board) return;

    int index = row*19 + col;
    if (board[index] != side || visited[index]) return;

    visited[index] = 1;
    lib[index] += 1;
    for(k = 0; k < 4; ++k){
        propagate_liberty(board, lib, visited, row + step_row[k], col + step_col[k], side);
    }
}
00185
00186
/*
 * Count liberties for every stone on a 19x19 board.
 *
 * Returns a calloc'd array of 361 ints that the caller must free. For each
 * empty point, every group touching it has the counter of each of its stones
 * raised by one; a group touching the same empty point from several sides is
 * counted once for that point. Empty points keep a count of 0.
 */
int *calculate_liberties(float *board)
{
    int *lib = calloc(19*19, sizeof(int));
    int visited[361];
    int row, col;
    for(row = 0; row < 19; ++row){
        for(col = 0; col < 19; ++col){
            int index = row*19 + col;
            if(board[index] != 0) continue;

            /* Fresh visited map per empty point so each group is credited once. */
            memset(visited, 0, 19*19*sizeof(int));
            if ((col > 0) && board[index - 1]) propagate_liberty(board, lib, visited, row, col-1, board[index-1]);
            if ((col < 18) && board[index + 1]) propagate_liberty(board, lib, visited, row, col+1, board[index+1]);
            if ((row > 0) && board[index - 19]) propagate_liberty(board, lib, visited, row-1, col, board[index-19]);
            if ((row < 18) && board[index + 19]) propagate_liberty(board, lib, visited, row+1, col, board[index+19]);
        }
    }
    return lib;
}
00206
00207 void print_board(float *board, int swap, int *indexes)
00208 {
00209
00210 FILE *stream = stderr;
00211 int i,j,n;
00212 fprintf(stream, "\n\n");
00213 fprintf(stream, " ");
00214 for(i = 0; i < 19; ++i){
00215 fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
00216 }
00217 fprintf(stream, "\n");
00218 for(j = 0; j < 19; ++j){
00219 fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
00220 for(i = 0; i < 19; ++i){
00221 int index = j*19 + i;
00222 if(indexes){
00223 int found = 0;
00224 for(n = 0; n < nind; ++n){
00225 if(index == indexes[n]){
00226 found = 1;
00227
00228
00229
00230
00231
00232
00233
00234 if(n == 0) fprintf(stream, " 1");
00235 else if(n == 1) fprintf(stream, " 2");
00236 else if(n == 2) fprintf(stream, " 3");
00237 else if(n == 3) fprintf(stream, " 4");
00238 else if(n == 4) fprintf(stream, " 5");
00239 }
00240 }
00241 if(found) continue;
00242 }
00243
00244
00245 if(board[index]*-swap > 0) fprintf(stream, " O");
00246 else if(board[index]*-swap < 0) fprintf(stream, " X");
00247 else fprintf(stream, " ");
00248 }
00249 fprintf(stream, "\n");
00250 }
00251 }
00252
/* Negate all 361 values of a 19x19 board in place, swapping the two sides. */
void flip_board(float *board)
{
    float *cell = board;
    float *end = board + 19*19;
    while(cell < end){
        *cell = -*cell;
        ++cell;
    }
}
00260
00261 void predict_move(network net, float *board, float *move, int multi)
00262 {
00263 float *output = network_predict(net, board);
00264 copy_cpu(19*19, output, 1, move, 1);
00265 int i;
00266 if(multi){
00267 image bim = float_to_image(19, 19, 1, board);
00268 for(i = 1; i < 8; ++i){
00269 rotate_image_cw(bim, i);
00270 if(i >= 4) flip_image(bim);
00271
00272 float *output = network_predict(net, board);
00273 image oim = float_to_image(19, 19, 1, output);
00274
00275 if(i >= 4) flip_image(oim);
00276 rotate_image_cw(oim, -i);
00277
00278 axpy_cpu(19*19, 1, output, 1, move, 1);
00279
00280 if(i >= 4) flip_image(bim);
00281 rotate_image_cw(bim, -i);
00282 }
00283 scal_cpu(19*19, 1./8., move, 1);
00284 }
00285 for(i = 0; i < 19*19; ++i){
00286 if(board[i]) move[i] = 0;
00287 }
00288 }
00289
/*
 * Remove from board `b`, starting at (r, c), every connected stone equal to
 * `p` whose entry in lib[] is exactly 1. Off-board points, points with a
 * different value, and stones with any other liberty count stop the walk.
 */
void remove_connected(float *b, int *lib, int p, int r, int c)
{
    int inside = (r >= 0 && r < 19 && c >= 0 && c < 19);
    if (!inside) return;

    int at = r*19 + c;
    if (b[at] != p || lib[at] != 1) return;

    b[at] = 0;
    remove_connected(b, lib, p, r+1, c);
    remove_connected(b, lib, p, r-1, c);
    remove_connected(b, lib, p, r, c+1);
    remove_connected(b, lib, p, r, c-1);
}
00301
00302
/*
 * Play a stone for player `p` (1 or -1) at row r, column c on board `b`.
 *
 * Liberties are computed on the position before the stone is placed, so a
 * neighbouring opponent group with exactly one liberty is the one this move
 * captures; such groups are removed. Coordinates outside the 19x19 board are
 * ignored and leave the board untouched, because callers pass values parsed
 * from external input.
 */
void move_go(float *b, int p, int r, int c)
{
    if (r < 0 || r >= 19 || c < 0 || c >= 19) return;

    int *l = calculate_liberties(b);
    b[r*19 + c] = p;
    remove_connected(b, l, -p, r+1, c);
    remove_connected(b, l, -p, r-1, c);
    remove_connected(b, l, -p, r, c+1);
    remove_connected(b, l, -p, r, c-1);
    free(l);
}
00313
/*
 * Decide whether the neighbour at (r, c) keeps a stone of player `p` placed
 * next to it alive. Returns 1 when the neighbour
 *   - holds an opposing stone (-p) with at most one liberty,
 *   - is empty, or
 *   - holds any other stone with more than one liberty;
 * returns 0 otherwise, including when (r, c) is off the board.
 */
int makes_safe_go(float *b, int *lib, int p, int r, int c){
    if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;

    int at = r*19 + c;
    if (b[at] == -p) return (lib[at] > 1) ? 0 : 1;
    if (b[at] == 0) return 1;
    return (lib[at] > 1) ? 1 : 0;
}
00324
/*
 * Returns 1 when none of the four neighbours of (r, c) makes a stone of
 * player `p` safe there (see makes_safe_go), 0 otherwise. Liberties are
 * computed on the board as it currently stands.
 */
int suicide_go(float *b, int p, int r, int c)
{
    int *l = calculate_liberties(b);
    int safe = makes_safe_go(b, l, p, r+1, c)
            || makes_safe_go(b, l, p, r-1, c)
            || makes_safe_go(b, l, p, r, c+1)
            || makes_safe_go(b, l, p, r, c-1);
    free(l);
    return !safe;
}
00336
/*
 * Returns 1 if player `p` may play at (r, c): the point must be empty and the
 * position after the move must differ from the packed position in `ko`.
 *
 * The move is tried on `b` itself and then undone by restoring the packed
 * snapshot, so the board ends up as board_to_string/string_to_board
 * reproduce it (values 1, -1 and 0).
 */
int legal_go(float *b, char *ko, int p, int r, int c)
{
    char before[91];
    char after[91];

    if (b[r*19 + c]) return 0;

    board_to_string(before, b);
    move_go(b, p, r, c);
    board_to_string(after, b);
    string_to_board(before, b);

    return (memcmp(after, ko, 91) == 0) ? 0 : 1;
}
00349
00350 int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
00351 {
00352 int i, j;
00353 for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00354
00355 float move[361];
00356 if (player < 0) flip_board(board);
00357 predict_move(net, board, move, multi);
00358 if (player < 0) flip_board(board);
00359
00360
00361 for(i = 0; i < 19; ++i){
00362 for(j = 0; j < 19; ++j){
00363 if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
00364 }
00365 }
00366
00367 int indexes[nind];
00368 top_k(move, 19*19, nind, indexes);
00369 if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
00370
00371 for(i = 0; i < 19; ++i){
00372 for(j = 0; j < 19; ++j){
00373 if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
00374 }
00375 }
00376
00377
00378 int max = max_index(move, 19*19);
00379 int row = max / 19;
00380 int col = max % 19;
00381 int index = sample_array(move, 19*19);
00382
00383 if(print){
00384 top_k(move, 19*19, nind, indexes);
00385 for(i = 0; i < nind; ++i){
00386 if (!move[indexes[i]]) indexes[i] = -1;
00387 }
00388 print_board(board, player, indexes);
00389 for(i = 0; i < nind; ++i){
00390 fprintf(stderr, "%d: %f\n", i+1, move[indexes[i]]);
00391 }
00392 }
00393
00394 if(suicide_go(board, player, row, col)){
00395 return -1;
00396 }
00397 if(suicide_go(board, player, index/19, index%19)) index = max;
00398 return index;
00399 }
00400
00401 void valid_go(char *cfgfile, char *weightfile, int multi)
00402 {
00403 srand(time(0));
00404 char *base = basecfg(cfgfile);
00405 printf("%s\n", base);
00406 network net = parse_network_cfg(cfgfile);
00407 if(weightfile){
00408 load_weights(&net, weightfile);
00409 }
00410 set_batch_network(&net, 1);
00411 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00412
00413 float *board = calloc(19*19, sizeof(float));
00414 float *move = calloc(19*19, sizeof(float));
00415 moves m = load_go_moves("/home/pjreddie/backup/go.test");
00416
00417 int N = m.n;
00418 int i;
00419 int correct = 0;
00420 for(i = 0; i <N; ++i){
00421 char *b = m.data[i];
00422 int row = b[0];
00423 int col = b[1];
00424 int truth = col + 19*row;
00425 string_to_board(b+2, board);
00426 predict_move(net, board, move, multi);
00427 int index = max_index(move, 19*19);
00428 if(index == truth) ++correct;
00429 printf("%d Accuracy %f\n", i, (float) correct/(i+1));
00430 }
00431 }
00432
00433 void engine_go(char *filename, char *weightfile, int multi)
00434 {
00435 network net = parse_network_cfg(filename);
00436 if(weightfile){
00437 load_weights(&net, weightfile);
00438 }
00439 srand(time(0));
00440 set_batch_network(&net, 1);
00441 float *board = calloc(19*19, sizeof(float));
00442 char *one = calloc(91, sizeof(char));
00443 char *two = calloc(91, sizeof(char));
00444 int passed = 0;
00445 while(1){
00446 char buff[256];
00447 int id = 0;
00448 int has_id = (scanf("%d", &id) == 1);
00449 scanf("%s", buff);
00450 if (feof(stdin)) break;
00451 char ids[256];
00452 sprintf(ids, "%d", id);
00453
00454 if (!has_id) ids[0] = 0;
00455 if (!strcmp(buff, "protocol_version")){
00456 printf("=%s 2\n\n", ids);
00457 } else if (!strcmp(buff, "name")){
00458 printf("=%s DarkGo\n\n", ids);
00459 } else if (!strcmp(buff, "version")){
00460 printf("=%s 1.0\n\n", ids);
00461 } else if (!strcmp(buff, "known_command")){
00462 char comm[256];
00463 scanf("%s", comm);
00464 int known = (!strcmp(comm, "protocol_version") ||
00465 !strcmp(comm, "name") ||
00466 !strcmp(comm, "version") ||
00467 !strcmp(comm, "known_command") ||
00468 !strcmp(comm, "list_commands") ||
00469 !strcmp(comm, "quit") ||
00470 !strcmp(comm, "boardsize") ||
00471 !strcmp(comm, "clear_board") ||
00472 !strcmp(comm, "komi") ||
00473 !strcmp(comm, "final_status_list") ||
00474 !strcmp(comm, "play") ||
00475 !strcmp(comm, "genmove"));
00476 if(known) printf("=%s true\n\n", ids);
00477 else printf("=%s false\n\n", ids);
00478 } else if (!strcmp(buff, "list_commands")){
00479 printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
00480 } else if (!strcmp(buff, "quit")){
00481 break;
00482 } else if (!strcmp(buff, "boardsize")){
00483 int boardsize = 0;
00484 scanf("%d", &boardsize);
00485
00486 if(boardsize != 19){
00487 printf("?%s unacceptable size\n\n", ids);
00488 } else {
00489 printf("=%s \n\n", ids);
00490 }
00491 } else if (!strcmp(buff, "clear_board")){
00492 passed = 0;
00493 memset(board, 0, 19*19*sizeof(float));
00494 printf("=%s \n\n", ids);
00495 } else if (!strcmp(buff, "komi")){
00496 float komi = 0;
00497 scanf("%f", &komi);
00498 printf("=%s \n\n", ids);
00499 } else if (!strcmp(buff, "play")){
00500 char color[256];
00501 scanf("%s ", color);
00502 char c;
00503 int r;
00504 int count = scanf("%c%d", &c, &r);
00505 int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
00506 if(c == 'p' && count < 2) {
00507 passed = 1;
00508 printf("=%s \n\n", ids);
00509 char *line = fgetl(stdin);
00510 free(line);
00511 fflush(stdout);
00512 fflush(stderr);
00513 continue;
00514 } else {
00515 passed = 0;
00516 }
00517 if(c >= 'A' && c <= 'Z') c = c - 'A';
00518 if(c >= 'a' && c <= 'z') c = c - 'a';
00519 if(c >= 8) --c;
00520 r = 19 - r;
00521 fprintf(stderr, "move: %d %d\n", r, c);
00522
00523 char *swap = two;
00524 two = one;
00525 one = swap;
00526 move_go(board, player, r, c);
00527 board_to_string(one, board);
00528
00529 printf("=%s \n\n", ids);
00530 print_board(board, 1, 0);
00531 } else if (!strcmp(buff, "genmove")){
00532 char color[256];
00533 scanf("%s", color);
00534 int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
00535
00536 int index = generate_move(net, player, board, multi, .1, .7, two, 1);
00537 if(passed || index < 0){
00538 printf("=%s pass\n\n", ids);
00539 passed = 0;
00540 } else {
00541 int row = index / 19;
00542 int col = index % 19;
00543
00544 char *swap = two;
00545 two = one;
00546 one = swap;
00547
00548 move_go(board, player, row, col);
00549 board_to_string(one, board);
00550 row = 19 - row;
00551 if (col >= 8) ++col;
00552 printf("=%s %c%d\n\n", ids, 'A' + col, row);
00553 print_board(board, 1, 0);
00554 }
00555
00556 } else if (!strcmp(buff, "p")){
00557
00558 } else if (!strcmp(buff, "final_status_list")){
00559 char type[256];
00560 scanf("%s", type);
00561 fprintf(stderr, "final_status\n");
00562 char *line = fgetl(stdin);
00563 free(line);
00564 if(type[0] == 'd' || type[0] == 'D'){
00565 FILE *f = fopen("game.txt", "w");
00566 int i, j;
00567 int count = 2;
00568 fprintf(f, "boardsize 19\n");
00569 fprintf(f, "clear_board\n");
00570 for(j = 0; j < 19; ++j){
00571 for(i = 0; i < 19; ++i){
00572 if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
00573 if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
00574 if(board[j*19 + i]) ++count;
00575 }
00576 }
00577 fprintf(f, "final_status_list dead\n");
00578 fclose(f);
00579 FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
00580 for(i = 0; i < count; ++i){
00581 free(fgetl(p));
00582 free(fgetl(p));
00583 }
00584 char *l = 0;
00585 while((l = fgetl(p))){
00586 printf("%s\n", l);
00587 free(l);
00588 }
00589 } else {
00590 printf("?%s unknown command\n\n", ids);
00591 }
00592 } else {
00593 char *line = fgetl(stdin);
00594 free(line);
00595 printf("?%s unknown command\n\n", ids);
00596 }
00597 fflush(stdout);
00598 fflush(stderr);
00599 }
00600 }
00601
00602 void test_go(char *cfg, char *weights, int multi)
00603 {
00604 network net = parse_network_cfg(cfg);
00605 if(weights){
00606 load_weights(&net, weights);
00607 }
00608 srand(time(0));
00609 set_batch_network(&net, 1);
00610 float *board = calloc(19*19, sizeof(float));
00611 float *move = calloc(19*19, sizeof(float));
00612 int color = 1;
00613 while(1){
00614 float *output = network_predict(net, board);
00615 copy_cpu(19*19, output, 1, move, 1);
00616 int i;
00617 if(multi){
00618 image bim = float_to_image(19, 19, 1, board);
00619 for(i = 1; i < 8; ++i){
00620 rotate_image_cw(bim, i);
00621 if(i >= 4) flip_image(bim);
00622
00623 float *output = network_predict(net, board);
00624 image oim = float_to_image(19, 19, 1, output);
00625
00626 if(i >= 4) flip_image(oim);
00627 rotate_image_cw(oim, -i);
00628
00629 axpy_cpu(19*19, 1, output, 1, move, 1);
00630
00631 if(i >= 4) flip_image(bim);
00632 rotate_image_cw(bim, -i);
00633 }
00634 scal_cpu(19*19, 1./8., move, 1);
00635 }
00636 for(i = 0; i < 19*19; ++i){
00637 if(board[i]) move[i] = 0;
00638 }
00639
00640 int indexes[nind];
00641 int row, col;
00642 top_k(move, 19*19, nind, indexes);
00643 print_board(board, color, indexes);
00644 for(i = 0; i < nind; ++i){
00645 int index = indexes[i];
00646 row = index / 19;
00647 col = index % 19;
00648 printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
00649 }
00650
00651
00652 if(color == 1) printf("X Enter move: ");
00653 else printf("O Enter move: ");
00654
00655 char c;
00656 char *line = fgetl(stdin);
00657 int picked = 1;
00658 int dnum = sscanf(line, "%d", &picked);
00659 int cnum = sscanf(line, "%c", &c);
00660 if (strlen(line) == 0 || dnum) {
00661 --picked;
00662 if (picked < nind){
00663 int index = indexes[picked];
00664 row = index / 19;
00665 col = index % 19;
00666 board[row*19 + col] = 1;
00667 }
00668 } else if (cnum){
00669 if (c <= 'T' && c >= 'A'){
00670 int num = sscanf(line, "%c %d", &c, &row);
00671 row = (inverted)?19 - row : row-1;
00672 col = c - 'A';
00673 if (col > 7 && noi) col -= 1;
00674 if (num == 2) board[row*19 + col] = 1;
00675 } else if (c == 'p') {
00676
00677 } else if(c=='b' || c == 'w'){
00678 char g;
00679 int num = sscanf(line, "%c %c %d", &g, &c, &row);
00680 row = (inverted)?19 - row : row-1;
00681 col = c - 'A';
00682 if (col > 7 && noi) col -= 1;
00683 if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
00684 } else if(c == 'c'){
00685 char g;
00686 int num = sscanf(line, "%c %c %d", &g, &c, &row);
00687 row = (inverted)?19 - row : row-1;
00688 col = c - 'A';
00689 if (col > 7 && noi) col -= 1;
00690 if (num == 3) board[row*19 + col] = 0;
00691 }
00692 }
00693 free(line);
00694 flip_board(board);
00695 color = -color;
00696 }
00697 }
00698
/*
 * Score a finished position by replaying it into ./gnugo.
 *
 * The position is written to game.txt as a command script (komi 6.5,
 * boardsize 19, clear_board, one "play" per stone, final_score) and gnugo's
 * output is scanned for a line of the form "= <player>+<score>".
 *
 * Returns the score, negated when the reported player is 'W'. Returns 0 when
 * game.txt cannot be written, gnugo cannot be started, or no score line is
 * found.
 */
float score_game(float *board)
{
    FILE *f = fopen("game.txt", "w");
    if(!f){
        fprintf(stderr, "score_game: cannot write game.txt\n");
        return 0;
    }
    int i, j;
    int count = 3;
    fprintf(f, "komi 6.5\n");
    fprintf(f, "boardsize 19\n");
    fprintf(f, "clear_board\n");
    for(j = 0; j < 19; ++j){
        for(i = 0; i < 19; ++i){
            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i]) ++count;
        }
    }
    fprintf(f, "final_score\n");
    fclose(f);

    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
    if(!p){
        fprintf(stderr, "score_game: cannot run gnugo\n");
        return 0;
    }
    /* Skip the two-line reply to each of the `count` setup commands. */
    for(i = 0; i < count; ++i){
        free(fgetl(p));
        free(fgetl(p));
    }
    char *l = 0;
    float score = 0;
    char player = 0;
    while((l = fgetl(p))){
        fprintf(stderr, "%s \t", l);
        int n = sscanf(l, "= %c+%f", &player, &score);
        free(l);
        if (n == 2) break;
    }
    if(player == 'W') score = -score;
    pclose(p);
    return score;
}
00734
00735 void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
00736 {
00737 network net = parse_network_cfg(filename);
00738 if(weightfile){
00739 load_weights(&net, weightfile);
00740 }
00741
00742 network net2 = net;
00743 if(f2){
00744 net2 = parse_network_cfg(f2);
00745 if(w2){
00746 load_weights(&net2, w2);
00747 }
00748 }
00749 srand(time(0));
00750 char boards[300][93];
00751 int count = 0;
00752 set_batch_network(&net, 1);
00753 set_batch_network(&net2, 1);
00754 float *board = calloc(19*19, sizeof(float));
00755 char *one = calloc(91, sizeof(char));
00756 char *two = calloc(91, sizeof(char));
00757 int done = 0;
00758 int player = 1;
00759 int p1 = 0;
00760 int p2 = 0;
00761 int total = 0;
00762 while(1){
00763 if (done || count >= 300){
00764 float score = score_game(board);
00765 int i = (score > 0)? 0 : 1;
00766 if((score > 0) == (total%2==0)) ++p1;
00767 else ++p2;
00768 ++total;
00769 fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
00770 int j;
00771 for(; i < count; i += 2){
00772 for(j = 0; j < 93; ++j){
00773 printf("%c", boards[i][j]);
00774 }
00775 printf("\n");
00776 }
00777 memset(board, 0, 19*19*sizeof(float));
00778 player = 1;
00779 done = 0;
00780 count = 0;
00781 fflush(stdout);
00782 fflush(stderr);
00783 }
00784
00785
00786 network use = ((total%2==0) == (player==1)) ? net : net2;
00787 int index = generate_move(use, player, board, multi, .1, .7, two, 0);
00788 if(index < 0){
00789 done = 1;
00790 continue;
00791 }
00792 int row = index / 19;
00793 int col = index % 19;
00794
00795 char *swap = two;
00796 two = one;
00797 one = swap;
00798
00799 if(player < 0) flip_board(board);
00800 boards[count][0] = row;
00801 boards[count][1] = col;
00802 board_to_string(boards[count] + 2, board);
00803 if(player < 0) flip_board(board);
00804 ++count;
00805
00806 move_go(board, player, row, col);
00807 board_to_string(one, board);
00808
00809 player = -player;
00810 }
00811 }
00812
/*
 * Command-line dispatcher for the go sub-commands.
 *
 * argv[2] selects the mode (train, valid, self, test, engine), argv[3] is the
 * network cfg, argv[4] optional weights, argv[5]/argv[6] an optional second
 * cfg/weights pair used by "self". The "-multi" flag is looked up with
 * find_arg(). Fewer than four arguments prints a usage line; an unrecognised
 * mode does nothing.
 */
void run_go(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    /* Positional arguments are captured before find_arg() is called. */
    char *cfg = argv[3];
    char *weights = 0;
    char *c2 = 0;
    char *w2 = 0;
    if(argc > 4) weights = argv[4];
    if(argc > 5) c2 = argv[5];
    if(argc > 6) w2 = argv[6];

    int multi = find_arg(argc, argv, "-multi");
    char *mode = argv[2];

    if(!strcmp(mode, "train")){
        train_go(cfg, weights);
    } else if(!strcmp(mode, "valid")){
        valid_go(cfg, weights, multi);
    } else if(!strcmp(mode, "self")){
        self_go(cfg, weights, c2, w2, multi);
    } else if(!strcmp(mode, "test")){
        test_go(cfg, weights, multi);
    } else if(!strcmp(mode, "engine")){
        engine_go(cfg, weights, multi);
    }
}
00832
00833