network.h
Go to the documentation of this file.
00001 // Oh boy, why am I about to do this....
00002 #ifndef NETWORK_H
00003 #define NETWORK_H
00004 
00005 #include "image.h"
00006 #include "layer.h"
00007 #include "data.h"
00008 #include "tree.h"
00009 
00010 typedef enum { // Selects the learning-rate schedule; used as the type of network.policy
00011     CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM // NOTE(review): per-policy formulas are not visible here; presumably get_current_rate() implements them using gamma/scale/power/step/steps/scales — confirm in the .c file
00012 } learning_rate_policy;
00013 
00014 typedef struct network{ // Whole-model configuration/state; most API functions below receive it BY VALUE (a shallow copy: pointer members still alias the caller's buffers)
00015     float *workspace; // scratch buffer; NOTE(review): presumably shared by all layers — confirm
00016     int n; // presumably the number of entries in `layers` — confirm
00017     int batch; // batch size; set_batch_network() takes a network* presumably to change this — confirm
00018     int *seen; // held by pointer, so by-value copies of the network share one counter; presumably samples seen so far — confirm
00019     float epoch; // presumably training progress measured in epochs — confirm
00020     int subdivisions; // NOTE(review): presumably each batch is split into this many sub-batches — confirm
00021     float momentum; // optimizer momentum coefficient (presumably SGD momentum) — confirm
00022     float decay; // presumably weight-decay coefficient — confirm
00023     layer *layers; // layer array (type from layer.h); length presumably n
00024     int outputs; // presumably element count of `output` — confirm
00025     float *output; // NOTE(review): presumably aliases the final layer's output buffer — confirm
00026     learning_rate_policy policy; // which learning-rate schedule applies (see the enum above)
00027     // Learning-rate schedule parameters; presumably consumed by get_current_rate() according to `policy` — confirm
00028     float learning_rate; // base learning rate
00029     float gamma;
00030     float scale;
00031     float power;
00032     int time_steps; // NOTE(review): presumably used by recurrent layers — confirm
00033     int step;
00034     int max_batches; // presumably the training length in batches — confirm
00035     float *scales; // presumably one multiplier per entry of `steps` (num_steps entries) — confirm
00036     int   *steps; // presumably batch numbers at which `scales` are applied — confirm
00037     int num_steps; // presumably the length of both `steps` and `scales` — confirm
00038     int burn_in; // NOTE(review): presumably warm-up batches before the schedule takes over — confirm
00039     // Adam optimizer settings
00040     int adam; // used as a flag; presumably nonzero selects Adam updates — confirm
00041     float B1; // presumably Adam beta1
00042     float B2; // presumably Adam beta2
00043     float eps; // presumably Adam epsilon
00044     // Input geometry and data-augmentation parameters
00045     int inputs; // presumably h*w*c — confirm
00046     int h, w, c; // presumably input height, width, channels; resize_network(net, w, h) presumably updates w/h — confirm
00047     int max_crop;
00048     int min_crop;
00049     float angle;
00050     float aspect;
00051     float exposure;
00052     float saturation;
00053     float hue;
00054     // Device selection and label hierarchy
00055     int gpu_index; // NOTE(review): presumably the GPU device id this network uses — confirm
00056     tree *hierarchy; // label hierarchy (type from tree.h); NOTE(review): presumably NULL when unused — confirm
00057 
00058     #ifdef GPU // GPU-only members: sizeof(network) differs between GPU and CPU builds, so every translation unit must agree on GPU
00059     float **input_gpu; // NOTE(review): presumably device-side input buffer(s) — confirm
00060     float **truth_gpu; // NOTE(review): presumably device-side target buffer(s) — confirm
00061     #endif
00062 } network;
00063 
00064 typedef struct network_state { // Per-pass context passed by value to forward_network/backward_network (and their _gpu variants)
00065     float *truth; // presumably ground-truth targets for the current batch — confirm
00066     float *input; // presumably input to the layer currently being processed — confirm
00067     float *delta; // presumably gradient buffer to propagate into — confirm
00068     float *workspace; // NOTE(review): presumably aliases network.workspace — confirm
00069     int train; // used as a flag; presumably nonzero during training passes — confirm
00070     int index; // presumably index of the layer currently being processed — confirm
00071     network net; // embedded BY VALUE: a shallow copy of the network (its pointer members alias the original's buffers)
00072 } network_state;
00073 
00074 #ifdef GPU // GPU-only API: these declarations exist only when GPU is defined
00075 float train_networks(network *nets, int n, data d, int interval); // takes an array of n networks; NOTE(review): presumably multi-GPU training that syncs every `interval` — confirm
00076 void sync_nets(network *nets, int n, int interval); // NOTE(review): presumably synchronizes weights across the n networks — confirm
00077 float train_network_datum_gpu(network net, float *x, float *y); // returns a float, presumably the loss — confirm
00078 float *network_predict_gpu(network net, float *input); // NOTE(review): ownership of the returned buffer is not visible here — confirm before freeing
00079 float * get_network_output_gpu_layer(network net, int i); // per-layer accessor; i is a layer index
00080 float * get_network_delta_gpu_layer(network net, int i); // per-layer accessor; i is a layer index
00081 float *get_network_output_gpu(network net);
00082 void forward_network_gpu(network net, network_state state);
00083 void backward_network_gpu(network net, network_state state);
00084 void update_network_gpu(network net); // net is a copy: only effects through its pointer members reach the caller
00085 #endif
00086 
00087 float get_current_rate(network net); // learning rate for the current training point; presumably computed from net.policy and the schedule fields — confirm
00088 int get_current_batch(network net); // NOTE(review): presumably derived from *net.seen and the batch size — confirm
00089 void free_network(network net); // receives a copy, so it cannot NULL the caller's pointer members; presumably frees their targets, leaving the caller's pointers dangling — confirm and do not reuse
00090 void compare_networks(network n1, network n2, data d);
00091 char *get_layer_string(LAYER_TYPE a); // NOTE(review): presumably returns a string literal (do not free or modify) — confirm
00092 
00093 network make_network(int n); // returns the struct by value; n presumably the layer count — confirm
00094 void forward_network(network net, network_state state);
00095 void backward_network(network net, network_state state);
00096 void update_network(network net); // net is a copy: only effects through its pointer members (layers, seen, ...) reach the caller
00097 
00098 float train_network(network net, data d); // returns a float, presumably the average loss over d — confirm
00099 float train_network_batch(network net, data d, int n); // NOTE(review): meaning of n not visible here — confirm
00100 float train_network_sgd(network net, data d, int n); // NOTE(review): meaning of n not visible here — confirm
00101 float train_network_datum(network net, float *x, float *y); // presumably x = inputs, y = targets for one batch — confirm
00102 
00103 matrix network_predict_data(network net, data test); // returns a matrix (data.h) by value; NOTE(review): presumably the caller owns its storage — confirm
00104 float *network_predict(network net, float *input); // NOTE(review): returned pointer presumably aliases a network-owned output buffer overwritten by the next call — confirm before freeing or caching
00105 float network_accuracy(network net, data d);
00106 float *network_accuracies(network net, data d, int n); // NOTE(review): ownership and length of the returned array are not visible here — confirm
00107 float network_accuracy_multi(network net, data d, int n);
00108 void top_predictions(network net, int n, int *index); // writes into index; presumably the caller provides room for n ints — confirm
00109 float *get_network_output(network net);
00110 float *get_network_output_layer(network net, int i); // the *_layer variants take a layer index i
00111 float *get_network_delta_layer(network net, int i);
00112 float *get_network_delta(network net);
00113 int get_network_output_size_layer(network net, int i);
00114 int get_network_output_size(network net);
00115 image get_network_image(network net); // returns an image (image.h) by value
00116 image get_network_image_layer(network net, int i);
00117 int get_predicted_class_network(network net);
00118 void print_network(network net);
00119 void visualize_network(network net);
00120 int resize_network(network *net, int w, int h); // takes a pointer so changes persist; note the (w, h) argument order; int result presumably a status — confirm
00121 void set_batch_network(network *net, int b); // takes a pointer so the change persists in the caller's struct; presumably sets the batch size to b — confirm
00122 int get_network_input_size(network net);
00123 float get_network_cost(network net);
00124 
00125 int get_network_nuisance(network net); // NOTE(review): semantics not visible here — confirm in the implementation (and that a definition exists)
00126 int get_network_background(network net); // NOTE(review): semantics not visible here — confirm in the implementation (and that a definition exists)
00127 
00128 #endif
00129 


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30