00001
00002 #ifndef NETWORK_H
00003 #define NETWORK_H
00004
00005 #include "image.h"
00006 #include "layer.h"
00007 #include "data.h"
00008 #include "tree.h"
00009
/*
 * learning_rate_policy: selector type of the `policy` member of struct network.
 * The enumerators carry the default C values, CONSTANT = 0 through RANDOM = 6.
 * NOTE(review): presumably each value names a learning-rate schedule evaluated
 * by get_current_rate(); the formulas are not visible in this header -- confirm.
 */
00010 typedef enum {
00011 CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
00012 } learning_rate_policy;
00013
/*
 * network: holds a pointer to the layer array together with scalar
 * configuration fields and a few buffer pointers.
 *
 * Most functions declared in this header take a network BY VALUE, so the
 * struct itself is copied per call while whatever the pointer members refer
 * to is shared between the copies.
 *
 * The last two members exist only when GPU is defined, so the size of this
 * struct depends on that macro; every translation unit that uses the type
 * must be compiled with the same setting.
 *
 * Field meanings below are inferred from names and are marked as such;
 * confirm against the implementation before relying on them.
 */
00014 typedef struct network{
00015 float *workspace;       /* presumably scratch memory shared by the layers -- confirm */
00016 int n;                  /* presumably the element count of `layers` (cf. make_network(int n)) */
00017 int batch;              /* presumably samples per forward/backward pass -- confirm */
00018 int *seen;              /* presumably a running count of samples processed; a pointer, so by-value copies share it */
00019 float epoch;            /* NOTE(review): meaning not visible in this header */
00020 int subdivisions;       /* presumably how many pieces one batch is split into -- confirm */
00021 float momentum;         /* presumably optimizer momentum -- confirm */
00022 float decay;            /* presumably weight decay -- confirm */
00023 layer *layers;          /* pointer to `layer` objects (type from layer.h) */
00024 int outputs;            /* presumably the element count of `output` -- confirm */
00025 float *output;          /* presumably the final output buffer -- confirm */
00026 learning_rate_policy policy; /* selects one of the learning_rate_policy enumerators */
00027
00028 float learning_rate;    /* presumably the base rate fed into the policy -- confirm */
00029 float gamma;            /* the next seven members are presumably policy parameters -- confirm which policy reads which */
00030 float scale;
00031 float power;
00032 int time_steps;
00033 int step;
00034 int max_batches;
00035 float *scales;          /* presumably parallel to `steps`, `num_steps` entries each -- confirm */
00036 int *steps;
00037 int num_steps;
00038 int burn_in;            /* presumably a warm-up length in batches -- confirm */
00039
00040 int adam;               /* presumably a flag enabling the Adam optimizer -- confirm */
00041 float B1;               /* presumably Adam beta1 -- confirm */
00042 float B2;               /* presumably Adam beta2 -- confirm */
00043 float eps;              /* presumably Adam epsilon -- confirm */
00044
00045 int inputs;             /* presumably the number of input values per sample -- confirm */
00046 int h, w, c;            /* presumably input height, width and channel count -- confirm */
00047 int max_crop;           /* max_crop through hue are presumably data-augmentation settings -- confirm */
00048 int min_crop;
00049 float angle;
00050 float aspect;
00051 float exposure;
00052 float saturation;
00053 float hue;
00054
00055 int gpu_index;          /* presumably the device this network is bound to -- confirm */
00056 tree *hierarchy;        /* pointer to a `tree` (type from tree.h); NOTE(review): may be unset for flat label sets -- confirm */
00057
00058 #ifdef GPU
00059 float **input_gpu;      /* present only when GPU is defined; pointee ownership not visible here */
00060 float **truth_gpu;      /* present only when GPU is defined; pointee ownership not visible here */
00061 #endif
00062 } network;
00063
/*
 * network_state: context object passed by value to forward_network(),
 * backward_network() and their *_gpu counterparts.
 *
 * It embeds a complete `network` struct (`net`), not a pointer to one, so
 * building a state copies the network struct; pointer members inside that
 * copy still refer to the same memory as the source network.
 *
 * Field meanings below are inferred from names -- confirm in the callers.
 */
00064 typedef struct network_state {
00065 float *truth;      /* presumably the expected-output buffer for the current batch -- confirm */
00066 float *input;      /* presumably the input buffer handed to the current layer -- confirm */
00067 float *delta;      /* presumably the gradient buffer for the current layer's input -- confirm */
00068 float *workspace;  /* presumably scratch memory, cf. network.workspace -- confirm */
00069 int train;         /* presumably non-zero while training -- confirm */
00070 int index;         /* presumably the index of the layer being processed -- confirm */
00071 network net;       /* by-value copy of the network struct */
00072 } network_state;
00073
/*
 * GPU entry points: these prototypes are visible only when GPU is defined.
 * train_networks() and sync_nets() take a `network *` plus a count `n`
 * (presumably an array of n networks -- confirm); every other function here
 * takes a single network by value.
 * NOTE(review): whether the returned float pointers refer to host or device
 * memory, and who owns them, is not visible in this header.
 */
00074 #ifdef GPU
00075 float train_networks(network *nets, int n, data d, int interval);
00076 void sync_nets(network *nets, int n, int interval);
00077 float train_network_datum_gpu(network net, float *x, float *y);
00078 float *network_predict_gpu(network net, float *input);
00079 float * get_network_output_gpu_layer(network net, int i);
00080 float * get_network_delta_gpu_layer(network net, int i);
00081 float *get_network_output_gpu(network net);
00082 void forward_network_gpu(network net, network_state state);
00083 void backward_network_gpu(network net, network_state state);
00084 void update_network_gpu(network net);
00085 #endif
00086
/*
 * Queries and housekeeping. All network arguments are passed by value.
 * NOTE(review): free_network() presumably releases the buffers the struct
 * points to; because the struct is copied into the call, the caller's own
 * copy would keep its (then stale) pointer values -- confirm.
 * NOTE(review): ownership of the string returned by get_layer_string() is not
 * visible here -- confirm before freeing or modifying it.
 */
00087 float get_current_rate(network net);
00088 int get_current_batch(network net);
00089 void free_network(network net);
00090 void compare_networks(network n1, network n2, data d);
00091 char *get_layer_string(LAYER_TYPE a);
00092
/*
 * Construction and the core passes.
 * make_network() returns a network by value; `n` is presumably the number of
 * layers to allocate -- confirm.
 * forward_network(), backward_network() and update_network() take the network
 * (and, for the first two, a network_state) by value.
 */
00093 network make_network(int n);
00094 void forward_network(network net, network_state state);
00095 void backward_network(network net, network_state state);
00096 void update_network(network net);
00097
/*
 * Training entry points; each returns a float.
 * NOTE(review): the return value is presumably a loss/cost figure, `n` a
 * count of batches or samples, and x / y the input and truth buffers for one
 * batch -- none of this is visible in the header; confirm in network.c.
 */
00098 float train_network(network net, data d);
00099 float train_network_batch(network net, data d, int n);
00100 float train_network_sgd(network net, data d, int n);
00101 float train_network_datum(network net, float *x, float *y);
00102
/*
 * Prediction, evaluation and accessors.
 * Every function here takes the network by value except resize_network() and
 * set_batch_network(), which take a `network *`.
 * NOTE(review): the float pointers returned by the get_* and
 * network_predict* functions presumably alias buffers owned by the network
 * rather than fresh allocations, whereas network_accuracies() may allocate --
 * ownership is not visible in this header; confirm before freeing.
 */
00103 matrix network_predict_data(network net, data test);
00104 float *network_predict(network net, float *input);
00105 float network_accuracy(network net, data d);
00106 float *network_accuracies(network net, data d, int n);
00107 float network_accuracy_multi(network net, data d, int n);
00108 void top_predictions(network net, int n, int *index); /* `index` is presumably an output array of n entries -- confirm */
00109 float *get_network_output(network net);
00110 float *get_network_output_layer(network net, int i);
00111 float *get_network_delta_layer(network net, int i);
00112 float *get_network_delta(network net);
00113 int get_network_output_size_layer(network net, int i);
00114 int get_network_output_size(network net);
00115 image get_network_image(network net);
00116 image get_network_image_layer(network net, int i);
00117 int get_predicted_class_network(network net);
00118 void print_network(network net);
00119 void visualize_network(network net);
00120 int resize_network(network *net, int w, int h); /* takes a pointer; meaning of the int result is not visible here */
00121 void set_batch_network(network *net, int b); /* takes a pointer; `b` is presumably the new batch size -- confirm */
00122 int get_network_input_size(network net);
00123 float get_network_cost(network net);
00124
/*
 * NOTE(review): the meaning of the returned ints is not visible in this
 * header (possibly indices or flags for special output classes) -- confirm.
 */
00125 int get_network_nuisance(network net);
00126 int get_network_background(network net);
00127
00128 #endif
00129