rules.h
Go to the documentation of this file.
00001 #ifndef IRI_RULE_LEARNER_RULES_H
00002 #define IRI_RULE_LEARNER_RULES_H 1
00003 
00004 #include <iri_rule_learner/RuleLearnerConfig.h>
00005 #include <fstream>
00006 #include <symbols.h>
00007 #include <predicates.h>
00008 #include <transitions.h>
00009 
00010 
00011 class Rule : public Symbol
00012 {
00013 public:
00014         PredicateGroup preconditions_;
00015         OutcomeList outcomes_;
00016         uint m_;
00017         uint total_executions_;
00018         
00019         Rule(const Rule& rule);
00020         
00021         Rule(Symbol symbol, PredicateList preconditions, uint m);
00022         
00023         void add_outcome(PredicateList predicates, float init_probability);
00024         
00025         void add_outcome(Outcome outcome);
00026         
00027         bool satisfiesPreconditions(PredicateList predicates);
00028         
00029         void sanitize_probabities();
00030         
00031         bool has_enough_experiences();
00032         
00033         virtual bool needs_learning() =0;
00034         virtual bool needs_pasula_learning() =0;
00035         virtual void add_old_execution(PredicateList outcome_predicates) =0;
00036         virtual void add_execution(PredicateList outcome_predicates) =0;
00037         virtual void update_probabilities() =0;
00038         
00039         friend std::ostream& operator<<(std::ostream &out, const Rule& r);
00040 };
00041 
00042 typedef std::vector< boost::shared_ptr<Rule> > RuleList;
00043 
00044 
00045 class MRule : public Rule
00046 {
00047 public:
00048         MRule(const Rule& rule) :
00049                 Rule(rule)
00050         { }
00051         
00052         MRule(Symbol symbol, PredicateList preconditions, uint m) :
00053                 Rule(symbol, preconditions, m)
00054         { }
00055         
00056         bool needs_learning()
00057         {
00058                 if (total_executions_ < m_)
00059                         return true;
00060                 else 
00061                         return false;
00062         }
00063         bool needs_pasula_learning() {
00064                 return has_enough_experiences();
00065         }
00066         
00067         void add_old_execution(PredicateList outcome_predicates) { add_execution(outcome_predicates); }
00068         
00069         void add_execution(PredicateList outcome_predicates)
00070         {
00071                 OutcomeList::iterator winner_it;
00072                 
00073                 ROS_DEBUG_STREAM("Adding execution. Search outcome for " << PredicateGroup(outcome_predicates));
00074                 
00075                 int max = -1;
00076                 for ( OutcomeList::iterator it=outcomes_.begin() ; it < outcomes_.end(); it++ ) {
00077                         if ( it->is_satisfied(outcome_predicates) ) {
00078                                 int n_satisfied = it->count_satisfied(outcome_predicates);
00079                                 if (n_satisfied > max) {
00080                                         max = n_satisfied;
00081                                         winner_it = it;
00082                                         ROS_DEBUG_STREAM("New winner is " << *it);
00083                                 }
00084                         }
00085                 }
00086                 
00087                 ROS_DEBUG_STREAM("Updated winner is " << *winner_it);
00088                 
00089                 ROS_INFO_STREAM("Rule before learning " << *this);
00090                 winner_it->add_execution();
00091                 total_executions_++;
00092                 update_probabilities();
00093                 ROS_INFO_STREAM("Rule after learning " << *this);
00094         }
00095         
00096         void update_probabilities() 
00097         {
00098                 double learn_param;
00099                 
00100                 // option 1
00101                 //learn_param = m_;
00102                 
00103                 // option 2
00104                 if (total_executions_ > 0)
00105                         learn_param = m_ / sqrt(total_executions_);
00106                 else
00107                         learn_param = m_;
00108                 
00109                 // option 3
00110                 //if (total_executions_ > 0)
00111                         //learn_param = m_ / total_executions_;
00112                 //else
00113                         //learn_param = m_;
00114                 
00115                 for ( OutcomeList::iterator it=outcomes_.begin() ; it < outcomes_.end(); it++ )
00116                         it->update_probability(total_executions_, learn_param);
00117                 
00118                 sanitize_probabities();
00119         }
00120 };
00121 
00122 
00123 class PasulaRule : public Rule
00124 {
00125 public:
00126         uint new_executions_;
00127         
00128         PasulaRule(const Rule& rule) :
00129                 Rule(rule),
00130                 new_executions_(0)
00131         { }
00132         
00133         bool needs_learning()
00134         {
00135                 return true;
00136         }
00137         bool needs_pasula_learning(){
00138                 return true;
00139         }
00140         void add_old_execution(PredicateList outcome_predicates)
00141         {
00142                 ++total_executions_;
00143         }
00144         void add_execution(PredicateList outcome_predicates)
00145         {
00146                 ++total_executions_;
00147                 ++new_executions_;
00148         }
00149         void update_probabilities()
00150         { }
00151 };
00152 
00153 
00154 class RuleSet
00155 {
00156 public:
00157         RuleList rules_;
00158         uint m_;
00159         
00160         // constructors
00161         RuleSet();
00162         RuleSet(uint m);
00163         RuleSet(uint m, RuleList rules);
00164         RuleSet(std::string file_path, uint m);
00165         
00166         // i/o with files
00167         void read_rules_from_file(std::string file_path);
00168         void write_rules_to_file(std::string file_path);
00169         void write_debugging_rules_to_file(std::string file_path);
00170         void sanitize_probabities();
00171         
00172         // general learning
00173         bool needs_learning(std::string name, std::string preconditions);
00174         bool needs_learning(std::string name, PredicateList preconditions);
00175         void add_execution(Transition transition, bool is_new = true);
00176         
00177         // pasula stuff
00178         std::vector<std::string> actions_needing_pasula_learning();
00179         RuleList extract_pasula_not_learned_rules();
00180         void transform_rules_to_pasula();
00181         
00182         // m_estimate stuff
00183         void set_learning_m(uint m);
00184         void update_probabilities();
00185         
00186         // add and remove rules
00187         void add(RuleList rules);
00188         void remove_action_rules(std::string action);
00189         
00190         // operators
00191         friend std::ostream& operator<<(std::ostream &out, const RuleSet& r);
00192 };
00193 
00194 // typedef boost::shared_ptr<GraspPhase> GraspPhasePtr;
00195 
00196 // typedef std::vector<GraspPhasePtr> GraspPhaseSequence;
00197 
00198 #endif


iri_rule_learner
Author(s): dmartinez
autogenerated on Fri Dec 6 2013 20:43:48