dice.c
Go to the documentation of this file.
00001 #include "network.h"
00002 #include "utils.h"
00003 #include "parser.h"
00004 
/*
 * Names of the six dice classes. Passed as the label list (k = 6) to
 * load_data_old() and used to print per-class prediction scores.
 */
char *dice_labels[] = {
    "face1", "face2", "face3",
    "face4", "face5", "face6",
};
00006 
00007 void train_dice(char *cfgfile, char *weightfile)
00008 {
00009     srand(time(0));
00010     float avg_loss = -1;
00011     char *base = basecfg(cfgfile);
00012     char *backup_directory = "/home/pjreddie/backup/";
00013     printf("%s\n", base);
00014     network net = parse_network_cfg(cfgfile);
00015     if(weightfile){
00016         load_weights(&net, weightfile);
00017     }
00018     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
00019     int imgs = 1024;
00020     int i = *net.seen/imgs;
00021     char **labels = dice_labels;
00022     list *plist = get_paths("data/dice/dice.train.list");
00023     char **paths = (char **)list_to_array(plist);
00024     printf("%d\n", plist->size);
00025     clock_t time;
00026     while(1){
00027         ++i;
00028         time=clock();
00029         data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h);
00030         printf("Loaded: %lf seconds\n", sec(clock()-time));
00031 
00032         time=clock();
00033         float loss = train_network(net, train);
00034         if(avg_loss == -1) avg_loss = loss;
00035         avg_loss = avg_loss*.9 + loss*.1;
00036         printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
00037         free_data(train);
00038         if((i % 100) == 0) net.learning_rate *= .1;
00039         if(i%100==0){
00040             char buff[256];
00041             sprintf(buff, "%s/%s_%d.weights",backup_directory,base, i);
00042             save_weights(net, buff);
00043         }
00044     }
00045 }
00046 
00047 void validate_dice(char *filename, char *weightfile)
00048 {
00049     network net = parse_network_cfg(filename);
00050     if(weightfile){
00051         load_weights(&net, weightfile);
00052     }
00053     srand(time(0));
00054 
00055     char **labels = dice_labels;
00056     list *plist = get_paths("data/dice/dice.val.list");
00057 
00058     char **paths = (char **)list_to_array(plist);
00059     int m = plist->size;
00060     free_list(plist);
00061 
00062     data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h);
00063     float *acc = network_accuracies(net, val, 2);
00064     printf("Validation Accuracy: %f, %d images\n", acc[0], m);
00065     free_data(val);
00066 }
00067 
00068 void test_dice(char *cfgfile, char *weightfile, char *filename)
00069 {
00070     network net = parse_network_cfg(cfgfile);
00071     if(weightfile){
00072         load_weights(&net, weightfile);
00073     }
00074     set_batch_network(&net, 1);
00075     srand(2222222);
00076     int i = 0;
00077     char **names = dice_labels;
00078     char buff[256];
00079     char *input = buff;
00080     int indexes[6];
00081     while(1){
00082         if(filename){
00083             strncpy(input, filename, 256);
00084         }else{
00085             printf("Enter Image Path: ");
00086             fflush(stdout);
00087             input = fgets(input, 256, stdin);
00088             if(!input) return;
00089             strtok(input, "\n");
00090         }
00091         image im = load_image_color(input, net.w, net.h);
00092         float *X = im.data;
00093         float *predictions = network_predict(net, X);
00094         top_predictions(net, 6, indexes);
00095         for(i = 0; i < 6; ++i){
00096             int index = indexes[i];
00097             printf("%s: %f\n", names[index], predictions[index]);
00098         }
00099         free_image(im);
00100         if (filename) break;
00101     }
00102 }
00103 
/*
 * Command-line dispatcher for the dice example.
 *
 * Expected argv layout:
 *   argv[2] - mode: "train", "test" or "valid"
 *   argv[3] - network cfg file
 *   argv[4] - optional weights file
 *   argv[5] - optional image path (used only by "test")
 *
 * Prints usage to stderr when fewer than 4 arguments are given, and an
 * error when the mode is not recognised.
 */
void run_dice(int argc, char **argv)
{
    if (argc < 4) {
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *mode = argv[2];
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : NULL;
    char *filename = (argc > 5) ? argv[5] : NULL;
    if (0 == strcmp(mode, "test")) test_dice(cfg, weights, filename);
    else if (0 == strcmp(mode, "train")) train_dice(cfg, weights);
    else if (0 == strcmp(mode, "valid")) validate_dice(cfg, weights);
    else fprintf(stderr, "unknown mode '%s': expected train, test or valid\n", mode);
}
00118 


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30