Go to the documentation of this file.00001 #ifndef IRI_RULE_LEARNER_RULES_H
00002 #define IRI_RULE_LEARNER_RULES_H 1
00003
00004 #include <iri_rule_learner/RuleLearnerConfig.h>
00005 #include <fstream>
00006 #include <symbols.h>
00007 #include <predicates.h>
00008 #include <transitions.h>
00009
00010
00011 class Rule : public Symbol
00012 {
00013 public:
00014 PredicateGroup preconditions_;
00015 OutcomeList outcomes_;
00016 uint m_;
00017 uint total_executions_;
00018
00019 Rule(const Rule& rule);
00020
00021 Rule(Symbol symbol, PredicateList preconditions, uint m);
00022
00023 void add_outcome(PredicateList predicates, float init_probability);
00024
00025 void add_outcome(Outcome outcome);
00026
00027 bool satisfiesPreconditions(PredicateList predicates);
00028
00029 void sanitize_probabities();
00030
00031 bool has_enough_experiences();
00032
00033 virtual bool needs_learning() =0;
00034 virtual bool needs_pasula_learning() =0;
00035 virtual void add_old_execution(PredicateList outcome_predicates) =0;
00036 virtual void add_execution(PredicateList outcome_predicates) =0;
00037 virtual void update_probabilities() =0;
00038
00039 friend std::ostream& operator<<(std::ostream &out, const Rule& r);
00040 };
00041
00042 typedef std::vector< boost::shared_ptr<Rule> > RuleList;
00043
00044
00045 class MRule : public Rule
00046 {
00047 public:
00048 MRule(const Rule& rule) :
00049 Rule(rule)
00050 { }
00051
00052 MRule(Symbol symbol, PredicateList preconditions, uint m) :
00053 Rule(symbol, preconditions, m)
00054 { }
00055
00056 bool needs_learning()
00057 {
00058 if (total_executions_ < m_)
00059 return true;
00060 else
00061 return false;
00062 }
00063 bool needs_pasula_learning() {
00064 return has_enough_experiences();
00065 }
00066
00067 void add_old_execution(PredicateList outcome_predicates) { add_execution(outcome_predicates); }
00068
00069 void add_execution(PredicateList outcome_predicates)
00070 {
00071 OutcomeList::iterator winner_it;
00072
00073 ROS_DEBUG_STREAM("Adding execution. Search outcome for " << PredicateGroup(outcome_predicates));
00074
00075 int max = -1;
00076 for ( OutcomeList::iterator it=outcomes_.begin() ; it < outcomes_.end(); it++ ) {
00077 if ( it->is_satisfied(outcome_predicates) ) {
00078 int n_satisfied = it->count_satisfied(outcome_predicates);
00079 if (n_satisfied > max) {
00080 max = n_satisfied;
00081 winner_it = it;
00082 ROS_DEBUG_STREAM("New winner is " << *it);
00083 }
00084 }
00085 }
00086
00087 ROS_DEBUG_STREAM("Updated winner is " << *winner_it);
00088
00089 ROS_INFO_STREAM("Rule before learning " << *this);
00090 winner_it->add_execution();
00091 total_executions_++;
00092 update_probabilities();
00093 ROS_INFO_STREAM("Rule after learning " << *this);
00094 }
00095
00096 void update_probabilities()
00097 {
00098 double learn_param;
00099
00100
00101
00102
00103
00104 if (total_executions_ > 0)
00105 learn_param = m_ / sqrt(total_executions_);
00106 else
00107 learn_param = m_;
00108
00109
00110
00111
00112
00113
00114
00115 for ( OutcomeList::iterator it=outcomes_.begin() ; it < outcomes_.end(); it++ )
00116 it->update_probability(total_executions_, learn_param);
00117
00118 sanitize_probabities();
00119 }
00120 };
00121
00122
00123 class PasulaRule : public Rule
00124 {
00125 public:
00126 uint new_executions_;
00127
00128 PasulaRule(const Rule& rule) :
00129 Rule(rule),
00130 new_executions_(0)
00131 { }
00132
00133 bool needs_learning()
00134 {
00135 return true;
00136 }
00137 bool needs_pasula_learning(){
00138 return true;
00139 }
00140 void add_old_execution(PredicateList outcome_predicates)
00141 {
00142 ++total_executions_;
00143 }
00144 void add_execution(PredicateList outcome_predicates)
00145 {
00146 ++total_executions_;
00147 ++new_executions_;
00148 }
00149 void update_probabilities()
00150 { }
00151 };
00152
00153
00154 class RuleSet
00155 {
00156 public:
00157 RuleList rules_;
00158 uint m_;
00159
00160
00161 RuleSet();
00162 RuleSet(uint m);
00163 RuleSet(uint m, RuleList rules);
00164 RuleSet(std::string file_path, uint m);
00165
00166
00167 void read_rules_from_file(std::string file_path);
00168 void write_rules_to_file(std::string file_path);
00169 void write_debugging_rules_to_file(std::string file_path);
00170 void sanitize_probabities();
00171
00172
00173 bool needs_learning(std::string name, std::string preconditions);
00174 bool needs_learning(std::string name, PredicateList preconditions);
00175 void add_execution(Transition transition, bool is_new = true);
00176
00177
00178 std::vector<std::string> actions_needing_pasula_learning();
00179 RuleList extract_pasula_not_learned_rules();
00180 void transform_rules_to_pasula();
00181
00182
00183 void set_learning_m(uint m);
00184 void update_probabilities();
00185
00186
00187 void add(RuleList rules);
00188 void remove_action_rules(std::string action);
00189
00190
00191 friend std::ostream& operator<<(std::ostream &out, const RuleSet& r);
00192 };
00193
00194
00195
00196
00197
00198 #endif