00001 #include "pasula_learner_alg.h" 00002 00003 PasulaLearnerAlgorithm::PasulaLearnerAlgorithm(void) 00004 { 00005 } 00006 00007 PasulaLearnerAlgorithm::~PasulaLearnerAlgorithm(void) 00008 { 00009 } 00010 00011 void PasulaLearnerAlgorithm::config_update(Config& new_cfg, uint32_t level) 00012 { 00013 this->lock(); 00014 00015 this->rules_file = new_cfg.rules_file; 00016 this->transitions_file = new_cfg.transitions_file; 00017 this->symbols_file = new_cfg.symbols_file; 00018 00019 // Regularizer 00020 this->alpha_pen = new_cfg.alpha_pen; 00021 // Lower bounds for probabilities of states in case of noise outcome 00022 this->noise_lower_bound = new_cfg.noise_lower_bound; 00023 // ... same, only for noisy default rule 00024 this->noise_lower_bound_default_rule = new_cfg.noise_lower_bound_default_rule; 00025 00026 // save the current configuration 00027 this->config_=new_cfg; 00028 00029 this->unlock(); 00030 } 00031 00032 // PasulaLearnerAlgorithm Public API 00033 bool PasulaLearnerAlgorithm::learn_rules() 00034 { 00035 // Rule learning algorithm is heuristic and makes some random choices. 00036 srand (time(NULL)); 00037 rnd.seed(rand()); 00038 00039 // ------------------------------------- 00040 // PARAMETERS 00041 // ------------------------------------- 00042 00043 // Log-file 00044 MT::String logfile("learn.log"); 00045 00046 // Symbols 00047 relational::SymL symbols; 00048 relational::ArgumentTypeL types; 00049 relational::readSymbolsAndTypes(symbols, types, MT::String(this->symbols_file.c_str() )); 00050 00051 // Data 00052 relational::StateTransitionL transitions = relational::StateTransition::read_SAS_SAS(MT::String(this->transitions_file.c_str() )); 00053 PRINT(transitions.N); 00054 // write(transitions); 00055 00056 00057 // ------------------------------------- 00058 // LEARN 00059 // ------------------------------------- 00060 00061 relational::learn::set_penalty(alpha_pen); 00062 relational::learn::set_p_min(noise_lower_bound, noise_lower_bound_default_rule); 00063 relational::RuleSetContainer rulesC; 00064 ROS_INFO("Starting learning"); 00065 relational::learn::learn_rules(rulesC, transitions); 00066 00067 relational::write(rulesC.rules, MT::String(this->rules_file.c_str() )); 00068 ROS_INFO_STREAM("Rules learned"); 00069 00070 return true; 00071 }