Go to the documentation of this file.
00011 #ifndef CBA_H_
00012 #define CBA_H_
00013
00014 #include <ANN/ANN.h>
00015 #include <ros/ros.h>
00016 #include <lfd_common/action_complete.h>
00017 #include <lfd_common/state.h>
00018 #include <vector>
00019
/*
 * Names looked up by the learner. The leading '~' makes each one a private
 * (node-relative) ROS name; presumably these are read as ROS parameters --
 * confirm in the implementation.
 */

/** @brief Name of the setting that caps the number of stored data points. */
#define MAX_DATA_POINTS "~max_data_points"

/** @brief Name of the setting holding the distance-threshold multiplier. */
#define DIST_THRESH_MULT "~dist_thresh_mult"

/*
 * Default values paired with the names above (presumably used when the
 * corresponding setting is absent -- confirm).
 */

/** @brief Default for MAX_DATA_POINTS. */
#define DEFAULT_MAX_POINTS 1024

/** @brief Default for DIST_THRESH_MULT. */
#define DEFAULT_DIST_MULT 1.5

/**
 * @brief Approximation tolerance for ANN nearest-neighbor queries.
 *
 * NOTE(review): presumably passed as the `eps` argument of ANN searches,
 * where 0 requests exact (non-approximate) neighbors -- confirm.
 */
#define ANN_EPSILON 0
00045
/**
 * @brief Result record produced when the learner classifies a state.
 *
 * NOTE(review): field meanings below are inferred from the field names and
 * from cba_learner::conf_thresh(int l, int db) -- confirm against the
 * implementation.
 */
struct prediction
{
  /** Presumably the predicted action label. */
  int l;
  /** Presumably the classifier's confidence in the label. */
  double c;
  /** Presumably the decision boundary the prediction is associated with. */
  int db;
};
00056
/**
 * @brief One confidence-threshold record.
 *
 * cba_learner keeps these in a std::vector<conf*>, and
 * cba_learner::conf_thresh(int l, int db) returns a double. Presumably that
 * lookup matches records on (l, db) and yields `thresh` -- confirm.
 */
struct conf
{
  /** Presumably the action label this threshold applies to. */
  int l;
  /** Presumably the decision boundary this threshold applies to. */
  int db;
  /** Presumably a count of samples contributing to the threshold. */
  int cnt;
  /** The threshold value itself. */
  double thresh;
};
00068
00075 class cba_learner
00076 {
00077 public:
00083 cba_learner();
00084
00090 virtual ~cba_learner();
00091
00097 void step();
00098
00099 private:
00107 prediction *classify_state();
00108
00116 double nearest_neighbor();
00117
00127 double conf_thresh(int l, int db);
00128
00134 void update_thresholds();
00135
00143 void state_listener_callback(const lfd_common::state::ConstPtr &msg);
00144
00154 bool a_complete_callback(lfd_common::action_complete::Request &req, lfd_common::action_complete::Response &resp);
00155
00156 ros::NodeHandle node;
00158 ros::Publisher execute, add_point, change_point;
00159 ros::Subscriber state_listener;
00160 ros::ServiceClient classify, correction, demonstration;
00161 ros::ServiceServer a_complete;
00163 float *s, *sc;
00164 int s_size, max_pts, pts, a;
00165 bool action_complete, autonomous_action;
00166 double dist_thresh, dist_mult;
00167 std::vector<conf*> conf_thresholds;
00169 ANNpointArray ann_data;
00170 int32_t *labels;
00171 };
00172
/**
 * @brief Program entry point for the learner node (defined in the
 *        implementation file).
 * @param argc Number of command-line arguments.
 * @param argv Command-line argument vector.
 * @return Process exit status.
 */
int main(int argc, char *argv[]);
00181
00182 #endif