rule_learner_alg.cpp
Go to the documentation of this file.
00001 #include "rule_learner_alg.h"
00002 
00003 RuleLearnerAlgorithm::RuleLearnerAlgorithm(void)
00004 {
00005         this->needs_learning_ = false;
00006 }
00007 
00008 RuleLearnerAlgorithm::~RuleLearnerAlgorithm(void)
00009 {
00010 }
00011 
00012 void RuleLearnerAlgorithm::config_update(Config& new_cfg, uint32_t level)
00013 {
00014         this->lock();
00015         
00016         if (this->learning_m_ != new_cfg.learning_m) {
00017                 this->learning_m_ = new_cfg.learning_m;
00018                 this->rules_.set_learning_m(this->learning_m_);
00019         }
00020         
00021         if (this->rules_path_.compare(new_cfg.rules_path) != 0){
00022                 this->rules_path_ = new_cfg.rules_path;
00023                 // reload rules
00024                 this->rules_.read_rules_from_file(this->rules_path_);
00025         }
00026         
00027         if (this->restore_transitions_path_ != new_cfg.restore_transitions_path) {
00028                 this->restore_transitions_path_ = new_cfg.restore_transitions_path;
00029                 this->transitions_ = read_transitions_from_file(this->restore_transitions_path_);
00030                 for ( TransitionList::iterator transition=this->transitions_.begin() ; transition != this->transitions_.end(); ++transition ) {
00031                         this->rules_.add_execution(*transition, true);
00032                 }
00033                 ROS_INFO_STREAM("transitions\n " << this->transitions_);
00034         }
00035         
00036         this->pasula_rules_path_ = new_cfg.pasula_rules_path;
00037         this->updated_rules_path_ = new_cfg.updated_rules_path;
00038         this->debugging_rules_path_ = new_cfg.debugging_rules_path;
00039         this->transitions_path_ = new_cfg.transitions_path;
00040         this->backup_transitions_path_ = new_cfg.backup_transitions_path;
00041         
00042         // save the current configuration
00043         this->config_=new_cfg;
00044         
00045         this->unlock();
00046 }
00047 
00048 bool RuleLearnerAlgorithm::add_state(std::string state)
00049 {
00050         bool rules_updated;
00051         
00052         if (this->needs_learning_) {
00053                 this->add_execution(state);
00054                 
00055                 this->needs_learning_ = false;
00056                 rules_updated = true;
00057                 
00058                 // debugging
00059                 this->rules_.write_debugging_rules_to_file(debugging_rules_path_);
00060         }
00061         else {
00062                 rules_updated = false;
00063         }
00064         
00065         this->last_state_ = state;
00066         
00067         return rules_updated;
00068 }
00069 
00070 bool RuleLearnerAlgorithm::add_action(std::string action_name, std::vector<uint> params)
00071 {
00072         std::stringstream ss;
00073         ss << action_name << "(";
00074         for (size_t i = 0; i < params.size(); i++) {
00075                 ss << (params[i] + 60);
00076                 if ( (i + 1) < params.size() ) {
00077                         ss << ",";
00078                 }
00079         }
00080         ss << ")";
00081         this->last_action_ = get_symbol_from_string(ss.str());
00082         
00083         if (this->last_action_.exists()) { // (req.action_movements_successful && this->alg_.rules_.needs_learning(req.action, representation_to_string_srv_.response.state_string)) { // needs learning
00084                 this->needs_learning_ = true;
00085                 ROS_DEBUG_STREAM("Needs learning");
00086         }
00087         else {
00088                 this->needs_learning_ = false;
00089                 ROS_DEBUG_STREAM("Doesn't need learning");
00090         }
00091         
00092         return this->needs_learning_;
00093 }
00094 
00095 // RuleLearnerAlgorithm Public API
00096 void RuleLearnerAlgorithm::add_execution(std::string state)
00097 {
00098         ROS_INFO_STREAM("Learning\nAction " << this->last_action_ << "\nPrevious state is " << this->last_state_ << "\nCurrent state is " << state);
00099 
00100         // add execution
00101         this->lock();
00102         ROS_INFO_STREAM("transitions\n " << transitions_);
00103         Transition transition(this->last_state_, this->last_action_, state);
00104         transitions_.push_back(transition);
00105         write_transitions_to_file(transitions_, this->backup_transitions_path_);
00106         new_transitions_.push_back(transition);
00107         rules_.add_execution(transition);
00108         this->unlock();
00109 
00110         // write results
00111         ROS_DEBUG_STREAM("New rules are " << rules_);
00112         rules_.write_rules_to_file(this->updated_rules_path_);
00113 }
00114 
00115 
00116 std::vector<int> RuleLearnerAlgorithm::update_with_pasula_rules(std::vector<std::string> actions, TransitionList transitions)
00117 {       
00118         std::vector<int> res; // contains number of pasula rules, number of learned rules, total number of rules
00119         ROS_INFO_STREAM("Updating with pasula rules");
00120         
00121         // read pasula rules
00122         RuleSet pasula_rules(this->pasula_rules_path_, this->learning_m_);
00123         pasula_rules.transform_rules_to_pasula();
00124         res.push_back(pasula_rules.rules_.size()); // number of pasula rules
00125         
00126         // remove learned rules from standard rules
00127         for ( RuleList::iterator rule=pasula_rules.rules_.begin() ; rule != pasula_rules.rules_.end(); ++rule ) {
00128                 rules_.remove_action_rules((*rule)->name_);
00129         }
00130         
00131         // add transitions_
00132         for ( TransitionList::iterator transition=transitions.begin() ; transition != transitions.end(); ++transition ) {
00133                 pasula_rules.add_execution(*transition, false);
00134         }
00135         
00136         // remove not pasula and add to rules
00137 //         RuleSet not_learned_rules = RuleSet(rules_.m_);
00138 //         not_learned_rules.add(pasula_rules.extract_pasula_not_learned_rules());
00139         res.push_back(pasula_rules.rules_.size()); // number of learned rules
00140         
00141         // add new transitions obtained while learning
00142         for ( TransitionList::iterator transition=new_transitions_.begin() ; transition != new_transitions_.end(); ++transition ) {
00143                 pasula_rules.add_execution(*transition);
00144         }
00145         
00146 //         for ( TransitionList::iterator transition=this->transitions_.begin() ; transition != this->transitions_.end(); ++transition ) {
00147 //                 not_learned_rules.add_execution(*transition);
00148 //         }
00149         
00150 //         rules_.add(not_learned_rules.rules_);
00151         rules_.add(pasula_rules.rules_);
00152         rules_.update_probabilities();
00153         
00154         ROS_INFO_STREAM("Updated rules with Pasula are " << rules_);
00155         rules_.write_rules_to_file(this->updated_rules_path_);
00156         res.push_back(rules_.rules_.size()); // number of total rules
00157         return res;
00158 }
00159 
00160 std::vector<std::string> RuleLearnerAlgorithm::actions_needing_pasula_learning() {
00161         std::vector<std::string> result;
00162         
00163         result = rules_.actions_needing_pasula_learning();
00164         
00165         // remove duplicated
00166         remove_duplicates(result);
00167         
00168         return result;
00169 }
00170 
00171 TransitionList RuleLearnerAlgorithm::get_transitions_with_actions(std::vector<std::string> actions) {
00172         TransitionList important_transitions;
00173         for ( TransitionList::iterator transition=transitions_.begin() ; transition != transitions_.end(); ++transition ) {
00174                 bool transition_added = false;
00175                 std::vector<std::string>::iterator action=actions.begin();
00176                 while ((action != actions.end()) && !transition_added) {
00177                         if (transition->action_.has_name(*action)) {
00178                                 important_transitions.push_back(Transition(*transition));
00179                                 transition_added=true;
00180                         }
00181                         ++action;
00182                 }
00183         }
00184         
00185         return important_transitions;
00186 }
00187 
00188 void RuleLearnerAlgorithm::write_transitions(TransitionList transitions) {
00189         write_transitions_to_file(transitions, this->transitions_path_);
00190 }
00191 
00192 void RuleLearnerAlgorithm::reset_new_transitions() {
00193         new_transitions_.clear();
00194 }


iri_rule_learner
Author(s): dmartinez
autogenerated on Fri Dec 6 2013 20:43:48