rules.cpp
Go to the documentation of this file.
00001 #include "rules.h"
00002 #include "symbols.h"
00003 // #include "symbols.cpp"
00004 
00005 
00011 std::ostream& operator<<(std::ostream &out, const Rule& r)
00012 {
00013         out << "ACTION:\n";
00014         out << *static_cast<const Symbol*>(&r);
00015         out << "\n";
00016         out << "CONTEXT:\n";
00017         out << r.preconditions_;
00018         out << "\n";
00019         out << "OUTCOMES:\n";
00020 //      for ( OutcomeList::iterator it=r.outcomes_.begin() ; it != r.outcomes_.end(); it++ )
00021         for ( size_t i = 0; i < r.outcomes_.size(); i++)
00022                 out << r.outcomes_[i] << "\n";
00023         out << "\n";
00024         
00025         return out;
00026 }
00027 
00028 std::ostream& operator<<(std::ostream &out, const RuleSet& r)
00029 {
00030 //      for ( RuleList::iterator it=r.rules_.begin() ; it != r.rules_.end(); it++ )
00031 //              out << *it << "\n";
00032         for ( size_t i = 0; i < r.rules_.size(); i++)
00033                 out << (*r.rules_[i]);
00034         
00035         return out;
00036 }
00037 
00043 Rule::Rule(const Rule& rule) :
00044         Symbol(rule),
00045         preconditions_(rule.preconditions_),
00046         outcomes_(rule.outcomes_),
00047         m_(rule.m_)
00048 { 
00049         total_executions_ = 0;
00050         for ( OutcomeList::const_iterator it = rule.outcomes_.begin(); it != rule.outcomes_.end(); it++) {
00051                 total_executions_ += it->num_executions_;
00052         }
00053 }
00054 
00055 Rule::Rule(Symbol symbol, PredicateList preconditions, uint m) :
00056         Symbol(symbol),
00057         preconditions_(preconditions),
00058         m_(m),
00059         total_executions_(0)
00060 { }
00061 
00062 void Rule::add_outcome(PredicateList predicates, float init_probability)
00063 {
00064         outcomes_.push_back(Outcome(predicates, init_probability));
00065 }
00066 
00067 void Rule::add_outcome(Outcome outcome)
00068 {
00069         outcomes_.push_back(outcome);
00070 }
00071 
00072 bool Rule::satisfiesPreconditions(PredicateList predicates)
00073 {
00074         if (preconditions_.is_satisfied(predicates))
00075                 return true;
00076         else
00077                 return false;
00078 }
00079 
00080 bool Rule::has_enough_experiences()
00081 {
00082         if (total_executions_ > m_)
00083                 return true;
00084         else
00085                 return false;
00086 }
00087 
00088 void Rule::sanitize_probabities()
00089 {
00090         float sum_probs = 0.0;
00091         for ( OutcomeList::iterator it = this->outcomes_.begin(); it != this->outcomes_.end(); it++) {
00092                 sum_probs += it->init_probability_;
00093         }
00094         if (sum_probs != 1.0) {
00095                 this->outcomes_[0].init_probability_ -= (sum_probs - 1.0);
00096                 update_probabilities();
00097         }
00098 }
00099 
00105 RuleSet::RuleSet() :
00106         m_(0)
00107 { }
00108 
00109 RuleSet::RuleSet(uint m) :
00110         m_(m)
00111 { }
00112 
00113 RuleSet::RuleSet(uint m, RuleList rules) :
00114         m_(m)
00115 {
00116         add(rules);
00117 }
00118 
00119 RuleSet::RuleSet(std::string file_path, uint m) :
00120         m_(m)
00121 { 
00122         read_rules_from_file(file_path);
00123 }
00124 
00125 void RuleSet::read_rules_from_file(std::string file_path)
00126 {
00127         std::string line;
00128         std::ifstream rules_file(file_path.c_str());
00129         
00130         // Clear previous rules
00131         rules_.clear();
00132         
00133         if (rules_file.is_open())
00134         {
00135                 std::string line1;
00136                 std::string line2;
00137                 std::string line_aux;
00138                 
00139                 getline (rules_file,line_aux);
00140                 while ( rules_file.good() )
00141                 {
00142                         // ACTION:
00143                         // read until ACTION: found
00144                         while (line_aux.compare("ACTION:") != 0) {
00145                                 getline (rules_file,line_aux);
00146                         }
00147                         
00148                         // Action name
00149                         getline (rules_file,line1);
00150                         
00151                         // CONTEXT:
00152                         getline (rules_file,line_aux);
00153                         
00154                         // Preconditions
00155                         getline (rules_file,line2);
00156                         
00157                         MRule rule(get_symbol_from_string(line1), 
00158                                 get_predicates_from_string(line2), 
00159                                 m_);
00160                         
00161                         // OUTCOMES:
00162                         getline (rules_file,line_aux);
00163                         
00164                         // Outcome i
00165                         getline (rules_file,line1);
00166                         while (line1.compare("") != 0) {
00167                                 rule.add_outcome(get_outcome_from_string(line1));
00168                                 getline (rules_file,line1);
00169                         }
00170                         
00171                         if ( (!rule.has_name(DUMMY_ACTION_NAME)) && (!rule.has_name(DEFAULT_ACTION_NAME))) {
00172                                 rules_.push_back(boost::make_shared<MRule>(rule));
00173                         }
00174                         
00175                         // read lines until not empty line
00176                         getline (rules_file,line_aux);
00177                         line_aux.erase(remove_if(line_aux.begin(), line_aux.end(), isspace), line_aux.end());
00178                         while ((not rules_file.eof()) && rules_file.good() && (line_aux.compare("") == 0)) {
00179                                 getline (rules_file,line_aux);
00180                         }
00181                 }
00182         }
00183         
00184         sanitize_probabities();
00185         
00186         ROS_DEBUG_STREAM("Read rules are " << *this);
00187         rules_file.close();
00188 }
00189 
00190 void RuleSet::write_rules_to_file(std::string file_path)
00191 {
00192         std::ofstream rules_file(file_path.c_str());
00193         ROS_DEBUG_STREAM("Writting rules to file" << file_path);
00194         if (rules_file.is_open())
00195         {
00196                 rules_file << *this;
00197         }
00198         else
00199                 ROS_ERROR_STREAM("Couldn't write rules to " << file_path);
00200         rules_file.close();
00201 }
00202 
00203 void RuleSet::write_debugging_rules_to_file(std::string file_path)
00204 {
00205         std::ofstream rules_file(file_path.c_str());
00206         if (rules_file.is_open())
00207         {
00208                 for ( size_t i = 0; i < this->rules_.size(); i++) {
00209                         rules_file << *this->rules_[i];
00210                         rules_file << "Total " << this->rules_[i]->total_executions_ << std::endl;
00211                         for ( size_t j = 0; j < this->rules_[i]->outcomes_.size(); j++)
00212                                 rules_file << "Out" << i << " " <<  this->rules_[i]->outcomes_[j].num_executions_ << " - " << this->rules_[i]->outcomes_[j].init_probability_ << "\n";
00213                         rules_file << std::endl;
00214                         rules_file << std::endl;
00215                 }
00216         }
00217         else {
00218                 ROS_ERROR_STREAM("Couldn't write debugging rules to " << file_path);
00219         }
00220         rules_file.close();
00221 }
00222 
00223 
00224 void RuleSet::sanitize_probabities()
00225 {
00226         for ( RuleList::iterator it=rules_.begin() ; it != rules_.end(); ++it ) {
00227                 (*it)->sanitize_probabities();
00228         }
00229 }
00230 
00231 
00232 // void RuleSet::add_execution(Symbol action, PredicateList preconditions, PredicateList outcome_predicates)
00233 void RuleSet::add_execution(Transition transition, bool is_new)
00234 {
00235         PredicateGroup prec_group(transition.prev_state_);
00236         PredicateGroup out_group(transition.next_state_);
00237         
00238         // grounded predicates
00239         // 1 -> X, etc...
00240         std::string DEFAULT_VARIABLES[] = {"X", "Y", "Z"};
00241         for (size_t i = 0; i < transition.action_.param_names_.size(); i++) {
00242                 prec_group.change_variable(transition.action_.param_names_[i], DEFAULT_VARIABLES[i]);
00243                 out_group.change_variable(transition.action_.param_names_[i], DEFAULT_VARIABLES[i]);
00244                 transition.action_.param_names_[i] = DEFAULT_VARIABLES[i];
00245         }
00246         
00247         // get preconditions?
00248         PredicateGroup important_prec = prec_group.get_predicates_with_vars(transition.action_.param_names_);
00249         
00250         // get differences
00251         PredicateGroup important_out = out_group.get_predicates_with_vars(transition.action_.param_names_);
00252         PredicateGroup differences = important_prec.get_differences(important_out);
00253         
00254         if (is_new) {
00255                 ROS_INFO_STREAM("Adding execution " << transition.action_ << "\nPreconditions " << PredicateGroup(important_prec) << "\nOutcome " << PredicateGroup(differences));
00256         }
00257         for ( RuleList::iterator it=rules_.begin() ; it != rules_.end(); it++ ) {
00258                 if ( ((*it)->has_name(transition.action_.name_)) && ((*it)->satisfiesPreconditions(important_prec.predicates_)) ) {
00259                         if (is_new)
00260                                 (*it)->add_execution(differences.predicates_);
00261                         else
00262                                 (*it)->add_old_execution(differences.predicates_);
00263                 }
00264         }
00265 }
00266 
00267 
00268 std::vector<std::string> RuleSet::actions_needing_pasula_learning()
00269 {
00270         std::vector<std::string> result;
00271         for ( RuleList::iterator it=rules_.begin() ; it != rules_.end(); ++it ) {
00272                 if ((*it)->needs_pasula_learning()) {
00273                         result.push_back((*it)->name_);
00274                 }
00275         }
00276         return result;
00277 }
00278 
00279 RuleList RuleSet::extract_pasula_not_learned_rules()
00280 {
00281         RuleList unlearned_rules;
00282         
00283         for ( RuleList::iterator rule=rules_.begin() ; rule != rules_.end(); ) {
00284                 if (!(*rule)->has_enough_experiences()) {
00285                         (*rule)->total_executions_=0;
00286                         unlearned_rules.push_back(boost::make_shared<MRule>(MRule(**rule)));
00287                         rule = rules_.erase(rule);
00288                 }
00289                 else {
00290                         ++rule;
00291                 }
00292         }
00293         
00294         return unlearned_rules;
00295 }
00296 
00297 
00298 void RuleSet::add(RuleList rules)
00299 {
00300         for ( RuleList::iterator rule=rules.begin() ; rule != rules.end(); rule++ ) {
00301                 rules_.push_back(*rule);
00302         }
00303 }
00304 
00305 
00306 void RuleSet::remove_action_rules(std::string action)
00307 {
00308         for ( RuleList::iterator rule=rules_.begin() ; rule != rules_.end();  ) {
00309                 if ((*rule)->name_.compare(action) == 0) {
00310                         rule = rules_.erase(rule);
00311                 }
00312                 else {
00313                         ++rule;
00314                 }
00315         }
00316 }
00317 
00318 bool RuleSet::needs_learning(std::string name, PredicateList preconditions)
00319 {
00320         bool result = false;
00321         
00322         for ( RuleList::iterator it=rules_.begin() ; it != rules_.end(); it++ ) {
00323                 if ( ((*it)->has_name(name)) && ((*it)->satisfiesPreconditions(preconditions)) )
00324                         if ((*it)->needs_learning())
00325                                 result = true;
00326         }
00327         
00328         return result;
00329 }
00330 
00331 
00332 void RuleSet::transform_rules_to_pasula()
00333 {
00334         RuleList new_rules;
00335         for ( RuleList::iterator rule=rules_.begin() ; rule != rules_.end();  ) {
00336                 new_rules.push_back(boost::make_shared<PasulaRule>(PasulaRule(**rule)));
00337                 rule = rules_.erase(rule);
00338         }
00339         rules_ = new_rules;
00340 }
00341 
00342 
00343 void RuleSet::set_learning_m(uint m)
00344 {
00345         this->m_ = m;
00346         update_probabilities();
00347 }
00348 
00349 void RuleSet::update_probabilities()
00350 {
00351         for ( RuleList::iterator it=rules_.begin() ; it != rules_.end(); it++ ) {
00352                 (*it)->update_probabilities();
00353         }
00354 }
00355 
00356 bool RuleSet::needs_learning(std::string name, std::string preconditions)
00357 {
00358         return needs_learning(name, get_predicates_from_string(preconditions));
00359 }


iri_rule_learner
Author(s): dmartinez
autogenerated on Fri Dec 6 2013 20:43:48