go.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004 #include "option_list.h"
00005 #include "blas.h"
00006 
00007 #ifdef OPENCV
00008 #include "opencv2/highgui/highgui_c.h"
00009 #endif
00010 
/* When set, row labels count down from 19 at the top row (19 - row);
 * otherwise rows are numbered 1..19 from the top. */
int inverted = 1;
/* When set, column letters skip 'I' (A..H, J..T). */
int noi = 1;
/* Number of top candidate moves shown/considered. An enum constant rather
 * than `static const int`, so arrays declared as `int indexes[nind]` are
 * ordinary fixed-size arrays instead of variable-length arrays. */
enum { nind = 5 };
00014 
/* Game records loaded by load_go_moves. Each data[i] points to one record
 * as returned by fgetgo: byte 0 is the row and byte 1 the column of the
 * move that was played, and the packed board (decoded by string_to_board)
 * starts at byte 2. */
typedef struct {
    char **data;  /* record pointers, each malloc'd by fgetgo */
    int n;        /* number of records (used as the array capacity while loading) */
} moves;
00019 
/* Read one fixed-size 94-byte game record from fp.
 *
 * Returns a malloc'd buffer holding exactly 94 bytes (NOT NUL-terminated)
 * that the caller must free. Returns 0 if fp is NULL, at end of file, on a
 * short read, or if the allocation fails. */
char *fgetgo(FILE *fp)
{
    const size_t size = 94;
    if(!fp || feof(fp)) return 0;
    char *line = malloc(size);
    if(!line) return 0;
    if(fread(line, sizeof(char), size, fp) != size){
        /* Truncated trailing record (or read error): discard it. */
        free(line);
        return 0;
    }
    return line;
}
00032 
00033 moves load_go_moves(char *filename)
00034 {
00035     moves m;
00036     m.n = 128;
00037     m.data = calloc(128, sizeof(char*));
00038     FILE *fp = fopen(filename, "rb");
00039     int count = 0;
00040     char *line = 0;
00041     while((line = fgetgo(fp))){
00042         if(count >= m.n){
00043             m.n *= 2;
00044             m.data = realloc(m.data, m.n*sizeof(char*));
00045         }
00046         m.data[count] = line;
00047         ++count;
00048     }
00049     printf("%d\n", count);
00050     m.n = count;
00051     m.data = realloc(m.data, count*sizeof(char*));
00052     return m;
00053 }
00054 
/* Decode a packed board string into 361 floats (row-major 19x19).
 *
 * Each byte of s packs 4 points at 2 bits per point, low bit pair first.
 * For point k within a byte:
 *   - bit 2k set gives +1;
 *   - otherwise bit 2k+1 set gives -1;
 *   - otherwise the point is 0.
 * Reads s[0..90]; the last byte contributes only one point. */
void string_to_board(char *s, float *board)
{
    int point;
    for(point = 0; point < 19*19; ++point){
        char packed = s[point / 4];
        int shift = 2 * (point % 4);
        float value = 0;
        if((packed >> shift) & 1) value = 1;
        else if((packed >> (shift + 1)) & 1) value = -1;
        board[point] = value;
    }
}
00073 
/* Encode a 19x19 float board into the packed 91-byte format that
 * string_to_board reads. It uses 2 bits per point and 4 points per byte,
 * low bits first. Exactly +1 sets bit 2k, exactly -1 sets bit 2k+1, and any
 * other value leaves both bits clear. s must hold 91 bytes; it is zeroed
 * first. */
void board_to_string(char *s, float *board)
{
    int point;
    memset(s, 0, (19*19/4+1)*sizeof(char));
    for(point = 0; point < 19*19; ++point){
        int byte = point / 4;
        int shift = 2 * (point % 4);
        if(board[point] == 1) s[byte] = s[byte] | (1 << shift);
        if(board[point] == -1) s[byte] = s[byte] | (1 << (shift + 1));
    }
}
00090 
00091 void random_go_moves(moves m, float *boards, float *labels, int n)
00092 {
00093     int i;
00094     memset(labels, 0, 19*19*n*sizeof(float));
00095     for(i = 0; i < n; ++i){
00096         char *b = m.data[rand()%m.n];
00097         int row = b[0];
00098         int col = b[1];
00099         labels[col + 19*(row + i*19)] = 1;
00100         string_to_board(b+2, boards+i*19*19);
00101         boards[col + 19*(row + i*19)] = 0;
00102 
00103         int flip = rand()%2;
00104         int rotate = rand()%4;
00105         image in = float_to_image(19, 19, 1, boards+i*19*19);
00106         image out = float_to_image(19, 19, 1, labels+i*19*19);
00107         if(flip){
00108             flip_image(in);
00109             flip_image(out);
00110         }
00111         rotate_image_cw(in, rotate);
00112         rotate_image_cw(out, rotate);
00113     }
00114 }
00115 
00116 
00117 void train_go(char *cfgfile, char *weightfile)
00118 {
00119     srand(time(0));
00120     float avg_loss = -1;
00121     char *base = basecfg(cfgfile);
00122     printf("%s\n", base);
00123     network net = parse_network_cfg(cfgfile);
00124     if(weightfile){
00125         load_weights(&net, weightfile);
00126     }
00127     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00128 
00129     char *backup_directory = "/home/pjreddie/backup/";
00130 
00131     char buff[256];
00132     float *board = calloc(19*19*net.batch, sizeof(float));
00133     float *move = calloc(19*19*net.batch, sizeof(float));
00134     moves m = load_go_moves("/home/pjreddie/backup/go.train");
00135     //moves m = load_go_moves("games.txt");
00136 
00137     int N = m.n;
00138     int epoch = (*net.seen)/N;
00139     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
00140         clock_t time=clock();
00141 
00142         random_go_moves(m, board, move, net.batch);
00143         float loss = train_network_datum(net, board, move) / net.batch;
00144         if(avg_loss == -1) avg_loss = loss;
00145         avg_loss = avg_loss*.95 + loss*.05;
00146         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
00147         if(*net.seen/N > epoch){
00148             epoch = *net.seen/N;
00149             char buff[256];
00150             sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
00151             save_weights(net, buff);
00152 
00153         }
00154         if(get_current_batch(net)%100 == 0){
00155             char buff[256];
00156             sprintf(buff, "%s/%s.backup",backup_directory,base);
00157             save_weights(net, buff);
00158         }
00159         if(get_current_batch(net)%10000 == 0){
00160             char buff[256];
00161             sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
00162             save_weights(net, buff);
00163         }
00164     }
00165     sprintf(buff, "%s/%s.weights", backup_directory, base);
00166     save_weights(net, buff);
00167 
00168     free_network(net);
00169     free(base);
00170 }
00171 
/* Flood-fill the group of stones equal to `side` that contains (row, col),
 * adding 1 to lib[] for every stone reached.
 *
 * visited[] (361 ints, row-major) keeps a stone from being counted twice;
 * the caller resets it. The fill stops at off-board points, points with
 * another value, and already-visited points. */
void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
{
    static const int step_r[4] = {1, -1, 0, 0};
    static const int step_c[4] = {0, 0, 1, -1};
    int point, k;

    if (row < 0 || row >= 19 || col < 0 || col >= 19) return;
    point = row*19 + col;
    if (board[point] != side || visited[point]) return;

    visited[point] = 1;
    ++lib[point];
    for (k = 0; k < 4; ++k){
        propagate_liberty(board, lib, visited, row + step_r[k], col + step_c[k], side);
    }
}
00185 
00186 
/* Count liberties for every stone on a 19x19 board.
 *
 * Returns a calloc'd array of 361 ints that the caller frees. Each stone's
 * entry is the number of empty points adjacent to its group; empty points
 * stay 0. */
int *calculate_liberties(float *board)
{
    int *lib = calloc(19*19, sizeof(int));
    int visited[361];
    int point;
    for(point = 0; point < 19*19; ++point){
        int row = point / 19;
        int col = point % 19;
        memset(visited, 0, sizeof visited);
        if(board[point] != 0) continue;
        /* Credit each neighbouring group once: visited[] is shared by the
         * four calls, so a group touching this point on several sides only
         * gains one liberty from it. */
        if(col > 0 && board[point - 1]) propagate_liberty(board, lib, visited, row, col - 1, board[point - 1]);
        if(col < 18 && board[point + 1]) propagate_liberty(board, lib, visited, row, col + 1, board[point + 1]);
        if(row > 0 && board[point - 19]) propagate_liberty(board, lib, visited, row - 1, col, board[point - 19]);
        if(row < 18 && board[point + 19]) propagate_liberty(board, lib, visited, row + 1, col, board[point + 19]);
    }
    return lib;
}
00206 
00207 void print_board(float *board, int swap, int *indexes)
00208 {
00209     //FILE *stream = stdout;
00210     FILE *stream = stderr;
00211     int i,j,n;
00212     fprintf(stream, "\n\n");
00213     fprintf(stream, "   ");
00214     for(i = 0; i < 19; ++i){
00215         fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
00216     }
00217     fprintf(stream, "\n");
00218     for(j = 0; j < 19; ++j){
00219         fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
00220         for(i = 0; i < 19; ++i){
00221             int index = j*19 + i;
00222             if(indexes){
00223                 int found = 0;
00224                 for(n = 0; n < nind; ++n){
00225                     if(index == indexes[n]){
00226                         found = 1;
00227                         /*
00228                         if(n == 0) fprintf(stream, "\uff11");
00229                         else if(n == 1) fprintf(stream, "\uff12");
00230                         else if(n == 2) fprintf(stream, "\uff13");
00231                         else if(n == 3) fprintf(stream, "\uff14");
00232                         else if(n == 4) fprintf(stream, "\uff15");
00233                         */
00234                         if(n == 0) fprintf(stream, " 1");
00235                         else if(n == 1) fprintf(stream, " 2");
00236                         else if(n == 2) fprintf(stream, " 3");
00237                         else if(n == 3) fprintf(stream, " 4");
00238                         else if(n == 4) fprintf(stream, " 5");
00239                     }
00240                 }
00241                 if(found) continue;
00242             }
00243             //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
00244             //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
00245             if(board[index]*-swap > 0) fprintf(stream, " O");
00246             else if(board[index]*-swap < 0) fprintf(stream, " X");
00247             else fprintf(stream, "  ");
00248         }
00249         fprintf(stream, "\n");
00250     }
00251 }
00252 
/* Negate all 361 points of a 19x19 board in place (+1 <-> -1). */
void flip_board(float *board)
{
    float *end = board + 19*19;
    for(; board < end; ++board){
        *board = -*board;
    }
}
00260 
00261 void predict_move(network net, float *board, float *move, int multi)
00262 {
00263     float *output = network_predict(net, board);
00264     copy_cpu(19*19, output, 1, move, 1);
00265     int i;
00266     if(multi){
00267         image bim = float_to_image(19, 19, 1, board);
00268         for(i = 1; i < 8; ++i){
00269             rotate_image_cw(bim, i);
00270             if(i >= 4) flip_image(bim);
00271 
00272             float *output = network_predict(net, board);
00273             image oim = float_to_image(19, 19, 1, output);
00274 
00275             if(i >= 4) flip_image(oim);
00276             rotate_image_cw(oim, -i);
00277 
00278             axpy_cpu(19*19, 1, output, 1, move, 1);
00279 
00280             if(i >= 4) flip_image(bim);
00281             rotate_image_cw(bim, -i);
00282         }
00283         scal_cpu(19*19, 1./8., move, 1);
00284     }
00285     for(i = 0; i < 19*19; ++i){
00286         if(board[i]) move[i] = 0;
00287     }
00288 }
00289 
/* Clear (set to 0) the connected stones of colour p that contain (r, c),
 * following only stones whose lib entry is exactly 1. The walk stops at
 * off-board points, points not equal to p, and stones with lib != 1. */
void remove_connected(float *b, int *lib, int p, int r, int c)
{
    static const int step_r[4] = {1, -1, 0, 0};
    static const int step_c[4] = {0, 0, 1, -1};
    int point, k;

    if (r < 0 || r >= 19 || c < 0 || c >= 19) return;
    point = r*19 + c;
    if (b[point] != p || lib[point] != 1) return;

    b[point] = 0;
    for (k = 0; k < 4; ++k){
        remove_connected(b, lib, p, r + step_r[k], c + step_c[k]);
    }
}
00301 
00302 
/* Place a stone of colour p at (r, c) on board b. Then remove every
 * opponent (-p) group next to it that had exactly one liberty before the
 * move. Legality, suicide and the mover's own stones are not checked. */
void move_go(float *b, int p, int r, int c)
{
    static const int step_r[4] = {1, -1, 0, 0};
    static const int step_c[4] = {0, 0, 1, -1};
    int k;
    int *liberties = calculate_liberties(b);

    b[r*19 + c] = p;
    for(k = 0; k < 4; ++k){
        remove_connected(b, liberties, -p, r + step_r[k], c + step_c[k]);
    }
    free(liberties);
}
00313 
/* Report whether neighbour (r, c) lets a stone of colour p played beside it
 * avoid suicide.
 * Returns 1 when the neighbour is:
 *   - empty;
 *   - an opponent stone with at most one liberty;
 *   - any other stone with more than one liberty.
 * Returns 0 when the neighbour is off-board or any of the remaining cases. */
int makes_safe_go(float *b, int *lib, int p, int r, int c){
    int on_board = (r >= 0 && r < 19 && c >= 0 && c < 19);
    if (!on_board) return 0;

    float stone = b[r*19 + c];
    int libs = lib[r*19 + c];
    if (stone == -p) return libs <= 1;
    if (stone == 0) return 1;
    return libs > 1;
}
00324 
/* Return 1 if colour p playing at (r, c) is suicide, meaning no neighbour
 * passes makes_safe_go; otherwise return 0. Whether (r, c) itself is empty
 * is not checked. */
int suicide_go(float *b, int p, int r, int c)
{
    static const int step_r[4] = {1, -1, 0, 0};
    static const int step_c[4] = {0, 0, 1, -1};
    int *lib = calculate_liberties(b);
    int k;
    int escape = 0;

    for(k = 0; k < 4 && !escape; ++k){
        escape = makes_safe_go(b, lib, p, r + step_r[k], c + step_c[k]);
    }
    free(lib);
    return !escape;
}
00336 
/* Return 1 if colour p may play at (r, c) on board b.
 * Return 0 if the point is occupied, or if the resulting position (packed
 * as in board_to_string) equals ko. Callers in this file pass the position
 * from before the opponent's last move, making this a simple ko check.
 *
 * The move is tried on b directly. b is then rebuilt from a packed copy,
 * which is exact for boards holding only -1/0/+1. Suicide is not checked. */
int legal_go(float *b, char *ko, int p, int r, int c)
{
    char before[91];
    char after[91];

    if (b[r*19 + c]) return 0;

    board_to_string(before, b);
    move_go(b, p, r, c);
    board_to_string(after, b);
    string_to_board(before, b);

    return memcmp(after, ko, 91) != 0;
}
00349 
00350 int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
00351 {
00352     int i, j;
00353     for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
00354 
00355     float move[361];
00356     if (player < 0) flip_board(board);
00357     predict_move(net, board, move, multi);
00358     if (player < 0) flip_board(board);
00359 
00360     
00361     for(i = 0; i < 19; ++i){
00362         for(j = 0; j < 19; ++j){
00363             if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
00364         }
00365     }
00366 
00367     int indexes[nind];
00368     top_k(move, 19*19, nind, indexes);
00369     if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
00370 
00371     for(i = 0; i < 19; ++i){
00372         for(j = 0; j < 19; ++j){
00373             if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
00374         }
00375     }
00376 
00377 
00378     int max = max_index(move, 19*19);
00379     int row = max / 19;
00380     int col = max % 19;
00381     int index = sample_array(move, 19*19);
00382 
00383     if(print){
00384         top_k(move, 19*19, nind, indexes);
00385         for(i = 0; i < nind; ++i){
00386             if (!move[indexes[i]]) indexes[i] = -1;
00387         }
00388         print_board(board, player, indexes);
00389         for(i = 0; i < nind; ++i){
00390             fprintf(stderr, "%d: %f\n", i+1, move[indexes[i]]);
00391         }
00392     }
00393 
00394     if(suicide_go(board, player, row, col)){
00395         return -1; 
00396     }
00397     if(suicide_go(board, player, index/19, index%19)) index = max;
00398     return index;
00399 }
00400 
00401 void valid_go(char *cfgfile, char *weightfile, int multi)
00402 {
00403     srand(time(0));
00404     char *base = basecfg(cfgfile);
00405     printf("%s\n", base);
00406     network net = parse_network_cfg(cfgfile);
00407     if(weightfile){
00408         load_weights(&net, weightfile);
00409     }
00410     set_batch_network(&net, 1);
00411     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00412 
00413     float *board = calloc(19*19, sizeof(float));
00414     float *move = calloc(19*19, sizeof(float));
00415     moves m = load_go_moves("/home/pjreddie/backup/go.test");
00416 
00417     int N = m.n;
00418     int i;
00419     int correct = 0;
00420     for(i = 0; i <N; ++i){
00421         char *b = m.data[i];
00422         int row = b[0];
00423         int col = b[1];
00424         int truth = col + 19*row;
00425         string_to_board(b+2, board);
00426         predict_move(net, board, move, multi);
00427         int index = max_index(move, 19*19);
00428         if(index == truth) ++correct;
00429         printf("%d Accuracy %f\n", i, (float) correct/(i+1));
00430     }
00431 }
00432 
00433 void engine_go(char *filename, char *weightfile, int multi)
00434 {
00435     network net = parse_network_cfg(filename);
00436     if(weightfile){
00437         load_weights(&net, weightfile);
00438     }
00439     srand(time(0));
00440     set_batch_network(&net, 1);
00441     float *board = calloc(19*19, sizeof(float));
00442     char *one = calloc(91, sizeof(char));
00443     char *two = calloc(91, sizeof(char));
00444     int passed = 0;
00445     while(1){
00446         char buff[256];
00447         int id = 0;
00448         int has_id = (scanf("%d", &id) == 1);
00449         scanf("%s", buff);
00450         if (feof(stdin)) break;
00451         char ids[256];
00452         sprintf(ids, "%d", id);
00453         //fprintf(stderr, "%s\n", buff);
00454         if (!has_id) ids[0] = 0;
00455         if (!strcmp(buff, "protocol_version")){
00456             printf("=%s 2\n\n", ids);
00457         } else if (!strcmp(buff, "name")){
00458             printf("=%s DarkGo\n\n", ids);
00459         } else if (!strcmp(buff, "version")){
00460             printf("=%s 1.0\n\n", ids);
00461         } else if (!strcmp(buff, "known_command")){
00462             char comm[256];
00463             scanf("%s", comm);
00464             int known = (!strcmp(comm, "protocol_version") || 
00465                     !strcmp(comm, "name") || 
00466                     !strcmp(comm, "version") || 
00467                     !strcmp(comm, "known_command") || 
00468                     !strcmp(comm, "list_commands") || 
00469                     !strcmp(comm, "quit") || 
00470                     !strcmp(comm, "boardsize") || 
00471                     !strcmp(comm, "clear_board") || 
00472                     !strcmp(comm, "komi") || 
00473                     !strcmp(comm, "final_status_list") || 
00474                     !strcmp(comm, "play") || 
00475                     !strcmp(comm, "genmove"));
00476             if(known) printf("=%s true\n\n", ids);
00477             else printf("=%s false\n\n", ids);
00478         } else if (!strcmp(buff, "list_commands")){
00479             printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
00480         } else if (!strcmp(buff, "quit")){
00481             break;
00482         } else if (!strcmp(buff, "boardsize")){
00483             int boardsize = 0;
00484             scanf("%d", &boardsize);
00485             //fprintf(stderr, "%d\n", boardsize);
00486             if(boardsize != 19){
00487                 printf("?%s unacceptable size\n\n", ids);
00488             } else {
00489                 printf("=%s \n\n", ids);
00490             }
00491         } else if (!strcmp(buff, "clear_board")){
00492             passed = 0;
00493             memset(board, 0, 19*19*sizeof(float));
00494             printf("=%s \n\n", ids);
00495         } else if (!strcmp(buff, "komi")){
00496             float komi = 0;
00497             scanf("%f", &komi);
00498             printf("=%s \n\n", ids);
00499         } else if (!strcmp(buff, "play")){
00500             char color[256];
00501             scanf("%s ", color);
00502             char c;
00503             int r;
00504             int count = scanf("%c%d", &c, &r);
00505             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
00506             if(c == 'p' && count < 2) {
00507                 passed = 1;
00508                 printf("=%s \n\n", ids);
00509                 char *line = fgetl(stdin);
00510                 free(line);
00511                 fflush(stdout);
00512                 fflush(stderr);
00513                 continue;
00514             } else {
00515                 passed = 0;
00516             }
00517             if(c >= 'A' && c <= 'Z') c = c - 'A';
00518             if(c >= 'a' && c <= 'z') c = c - 'a';
00519             if(c >= 8) --c;
00520             r = 19 - r;
00521             fprintf(stderr, "move: %d %d\n", r, c);
00522 
00523             char *swap = two;
00524             two = one;
00525             one = swap;
00526             move_go(board, player, r, c);
00527             board_to_string(one, board);
00528 
00529             printf("=%s \n\n", ids);
00530             print_board(board, 1, 0);
00531         } else if (!strcmp(buff, "genmove")){
00532             char color[256];
00533             scanf("%s", color);
00534             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
00535 
00536             int index = generate_move(net, player, board, multi, .1, .7, two, 1);
00537             if(passed || index < 0){
00538                 printf("=%s pass\n\n", ids);
00539                 passed = 0;
00540             } else {
00541                 int row = index / 19;
00542                 int col = index % 19;
00543 
00544                 char *swap = two;
00545                 two = one;
00546                 one = swap;
00547 
00548                 move_go(board, player, row, col);
00549                 board_to_string(one, board);
00550                 row = 19 - row;
00551                 if (col >= 8) ++col;
00552                 printf("=%s %c%d\n\n", ids, 'A' + col, row);
00553                 print_board(board, 1, 0);
00554             }
00555 
00556         } else if (!strcmp(buff, "p")){
00557             //print_board(board, 1, 0);
00558         } else if (!strcmp(buff, "final_status_list")){
00559             char type[256];
00560             scanf("%s", type);
00561             fprintf(stderr, "final_status\n");
00562             char *line = fgetl(stdin);
00563             free(line);
00564             if(type[0] == 'd' || type[0] == 'D'){
00565                 FILE *f = fopen("game.txt", "w");
00566                 int i, j;
00567                 int count = 2;
00568                 fprintf(f, "boardsize 19\n");
00569                 fprintf(f, "clear_board\n");
00570                 for(j = 0; j < 19; ++j){
00571                     for(i = 0; i < 19; ++i){
00572                         if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
00573                         if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
00574                         if(board[j*19 + i]) ++count;
00575                     }
00576                 }
00577                 fprintf(f, "final_status_list dead\n");
00578                 fclose(f);
00579                 FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
00580                 for(i = 0; i < count; ++i){
00581                     free(fgetl(p));
00582                     free(fgetl(p));
00583                 }
00584                 char *l = 0;
00585                 while((l = fgetl(p))){
00586                     printf("%s\n", l);
00587                     free(l);
00588                 }
00589             } else {
00590                 printf("?%s unknown command\n\n", ids);
00591             }
00592         } else {
00593             char *line = fgetl(stdin);
00594             free(line);
00595             printf("?%s unknown command\n\n", ids);
00596         }
00597         fflush(stdout);
00598         fflush(stderr);
00599     }
00600 }
00601 
00602 void test_go(char *cfg, char *weights, int multi)
00603 {
00604     network net = parse_network_cfg(cfg);
00605     if(weights){
00606         load_weights(&net, weights);
00607     }
00608     srand(time(0));
00609     set_batch_network(&net, 1);
00610     float *board = calloc(19*19, sizeof(float));
00611     float *move = calloc(19*19, sizeof(float));
00612     int color = 1;
00613     while(1){
00614         float *output = network_predict(net, board);
00615         copy_cpu(19*19, output, 1, move, 1);
00616         int i;
00617         if(multi){
00618             image bim = float_to_image(19, 19, 1, board);
00619             for(i = 1; i < 8; ++i){
00620                 rotate_image_cw(bim, i);
00621                 if(i >= 4) flip_image(bim);
00622 
00623                 float *output = network_predict(net, board);
00624                 image oim = float_to_image(19, 19, 1, output);
00625 
00626                 if(i >= 4) flip_image(oim);
00627                 rotate_image_cw(oim, -i);
00628 
00629                 axpy_cpu(19*19, 1, output, 1, move, 1);
00630 
00631                 if(i >= 4) flip_image(bim);
00632                 rotate_image_cw(bim, -i);
00633             }
00634             scal_cpu(19*19, 1./8., move, 1);
00635         }
00636         for(i = 0; i < 19*19; ++i){
00637             if(board[i]) move[i] = 0;
00638         }
00639 
00640         int indexes[nind];
00641         int row, col;
00642         top_k(move, 19*19, nind, indexes);
00643         print_board(board, color, indexes);
00644         for(i = 0; i < nind; ++i){
00645             int index = indexes[i];
00646             row = index / 19;
00647             col = index % 19;
00648             printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
00649         }
00650         //if(color == 1) printf("\u25EF Enter move: ");
00651         //else printf("\u25C9 Enter move: ");
00652         if(color == 1) printf("X Enter move: ");
00653         else printf("O Enter move: ");
00654 
00655         char c;
00656         char *line = fgetl(stdin);
00657         int picked = 1;
00658         int dnum = sscanf(line, "%d", &picked);
00659         int cnum = sscanf(line, "%c", &c);
00660         if (strlen(line) == 0 || dnum) {
00661             --picked;
00662             if (picked < nind){
00663                 int index = indexes[picked];
00664                 row = index / 19;
00665                 col = index % 19;
00666                 board[row*19 + col] = 1;
00667             }
00668         } else if (cnum){
00669             if (c <= 'T' && c >= 'A'){
00670                 int num = sscanf(line, "%c %d", &c, &row);
00671                 row = (inverted)?19 - row : row-1;
00672                 col = c - 'A';
00673                 if (col > 7 && noi) col -= 1;
00674                 if (num == 2) board[row*19 + col] = 1;
00675             } else if (c == 'p') {
00676                 // Pass
00677             } else if(c=='b' || c == 'w'){
00678                 char g;
00679                 int num = sscanf(line, "%c %c %d", &g, &c, &row);
00680                 row = (inverted)?19 - row : row-1;
00681                 col = c - 'A';
00682                 if (col > 7 && noi) col -= 1;
00683                 if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
00684             } else if(c == 'c'){
00685                 char g;
00686                 int num = sscanf(line, "%c %c %d", &g, &c, &row);
00687                 row = (inverted)?19 - row : row-1;
00688                 col = c - 'A';
00689                 if (col > 7 && noi) col -= 1;
00690                 if (num == 3) board[row*19 + col] = 0;
00691             }
00692         }
00693         free(line);
00694         flip_board(board);
00695         color = -color;
00696     }
00697 }
00698 
/* Score the position on board with GNU Go (komi 6.5).
 *
 * Writes the position as GTP commands to game.txt, runs ./gnugo on it, and
 * parses the "final_score" reply ("= B+x" / "= W+x"). Lines read while
 * searching for the score are echoed to stderr.
 *
 * Returns the margin: positive when black (+1 stones) wins, negative when
 * white wins. Returns 0 if game.txt cannot be written, gnugo cannot be
 * started, or no score line is found. */
float score_game(float *board)
{
    FILE *f = fopen("game.txt", "w");
    if(!f){
        fprintf(stderr, "score_game: couldn't write game.txt\n");
        return 0;
    }
    int i, j;
    /* komi, boardsize and clear_board, plus one "play" per stone below. */
    int count = 3;
    fprintf(f, "komi 6.5\n");
    fprintf(f, "boardsize 19\n");
    fprintf(f, "clear_board\n");
    for(j = 0; j < 19; ++j){
        for(i = 0; i < 19; ++i){
            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i]) ++count;
        }
    }
    fprintf(f, "final_score\n");
    fclose(f);
    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
    if(!p){
        fprintf(stderr, "score_game: couldn't run gnugo\n");
        return 0;
    }
    /* Skip the reply line and blank line gnugo emits per setup command. */
    for(i = 0; i < count; ++i){
        free(fgetl(p));
        free(fgetl(p));
    }
    char *l = 0;
    float score = 0;
    char player = 0;
    while((l = fgetl(p))){
        fprintf(stderr, "%s  \t", l);
        int n = sscanf(l, "= %c+%f", &player, &score);
        free(l);
        if (n == 2) break;
    }
    if(player == 'W') score = -score;
    pclose(p);
    return score;
}
00734 
00735 void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
00736 {
00737     network net = parse_network_cfg(filename);
00738     if(weightfile){
00739         load_weights(&net, weightfile);
00740     }
00741 
00742     network net2 = net;
00743     if(f2){
00744         net2 = parse_network_cfg(f2);
00745         if(w2){
00746             load_weights(&net2, w2);
00747         }
00748     }
00749     srand(time(0));
00750     char boards[300][93];
00751     int count = 0;
00752     set_batch_network(&net, 1);
00753     set_batch_network(&net2, 1);
00754     float *board = calloc(19*19, sizeof(float));
00755     char *one = calloc(91, sizeof(char));
00756     char *two = calloc(91, sizeof(char));
00757     int done = 0;
00758     int player = 1;
00759     int p1 = 0;
00760     int p2 = 0;
00761     int total = 0;
00762     while(1){
00763         if (done || count >= 300){
00764             float score = score_game(board);
00765             int i = (score > 0)? 0 : 1;
00766             if((score > 0) == (total%2==0)) ++p1;
00767             else ++p2;
00768             ++total;
00769             fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
00770             int j;
00771             for(; i < count; i += 2){
00772                 for(j = 0; j < 93; ++j){
00773                     printf("%c", boards[i][j]);
00774                 }
00775                 printf("\n");
00776             }
00777             memset(board, 0, 19*19*sizeof(float));
00778             player = 1;
00779             done = 0;
00780             count = 0;
00781             fflush(stdout);
00782             fflush(stderr);
00783         }
00784         //print_board(board, 1, 0);
00785         //sleep(1);
00786         network use = ((total%2==0) == (player==1)) ? net : net2;
00787         int index = generate_move(use, player, board, multi, .1, .7, two, 0);
00788         if(index < 0){
00789             done = 1;
00790             continue;
00791         }
00792         int row = index / 19;
00793         int col = index % 19;
00794 
00795         char *swap = two;
00796         two = one;
00797         one = swap;
00798 
00799         if(player < 0) flip_board(board);
00800         boards[count][0] = row;
00801         boards[count][1] = col;
00802         board_to_string(boards[count] + 2, board);
00803         if(player < 0) flip_board(board);
00804         ++count;
00805 
00806         move_go(board, player, row, col);
00807         board_to_string(one, board);
00808 
00809         player = -player;
00810     }
00811 }
00812 
/* Entry point for the "go" command.
 *
 * Dispatches argv[2] (train, valid, test, engine or self) with cfg
 * argv[3] and optional weights argv[4]. "self" also takes an opponent cfg
 * argv[5] and weights argv[6]. The -multi flag enables symmetry-averaged
 * prediction. Unknown modes are reported on stderr. */
void run_go(int argc, char **argv)
{
    static const char *usage = "usage: %s %s [train/valid/test/engine/self] [cfg] [weights (optional)] [cfg2 weights2 (self only, optional)]\n";
    if(argc < 4){
        fprintf(stderr, usage, argv[0], argv[1]);
        return;
    }

    /* Consume the flag before reading positional arguments.
     * NOTE(review): darknet's find_arg deletes the matched flag from argv;
     * reading it first keeps "-multi" from being taken as a cfg/weights
     * path. If find_arg leaves argv untouched, this order is harmless. */
    int multi = find_arg(argc, argv, "-multi");

    char *cfg = argv[3];
    if(!cfg){
        fprintf(stderr, usage, argv[0], argv[1]);
        return;
    }
    char *weights = (argc > 4) ? argv[4] : 0;
    char *c2 = (argc > 5) ? argv[5] : 0;
    char *w2 = (argc > 6) ? argv[6] : 0;
    if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi);
    else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
    else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
    else fprintf(stderr, "unknown go mode: %s\n", argv[2]);
}
00832 
00833 


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30