Go to the documentation of this file.00001
00002
00003
00004
00005 from __future__ import print_function
00006 import os
00007 import sys
00008 import unittest
00009 import rospy
00010 import rospkg
00011 from pgm_learner.srv import (DiscreteParameterEstimation,
00012 DiscreteParameterEstimationRequest,
00013 DiscreteQuery,
00014 DiscreteQueryRequest,
00015 DiscreteStructureEstimation,
00016 DiscreteStructureEstimationRequest,
00017 )
00018 from pgm_learner.msg import GraphEdge, DiscreteGraphState, DiscreteNodeState
00019 import pgm_learner.msg_utils as U
00020
00021 from libpgm.graphskeleton import GraphSkeleton
00022 from libpgm.nodedata import NodeData
00023 from libpgm.discretebayesiannetwork import DiscreteBayesianNetwork
00024
00025
class TestDiscreteBNLearnerNode(unittest.TestCase):
    """Integration tests for the pgm_learner discrete Bayesian network services.

    Calls the ``pgm_learner/discrete/parameter_estimation``,
    ``pgm_learner/discrete/query`` and ``pgm_learner/discrete/structure_estimation``
    services, using the graph/node data in ``test/graph-test.txt`` of the
    ``pgm_learner`` package as the "teacher" network.
    """

    # Seconds to wait for each service to become available.
    SERVICE_TIMEOUT = 30

    @classmethod
    def setUpClass(cls):
        """Resolve the test data path and connect to the learner services once.

        This used to live in ``__init__``. unittest constructs one instance per
        test method, so that setup was repeated for every test, and a missing
        service raised while tests were being loaded instead of being reported
        as a test error. ``__init__`` also used ``super(self.__class__, self)``,
        which recurses forever if this class is subclassed.
        """
        cls.pkg = rospkg.RosPack()
        cls.data_path = os.path.join(
            cls.pkg.get_path("pgm_learner"), "test", "graph-test.txt")
        # The same file supplies both the graph skeleton and the node data.
        cls.teacher_data_path = cls.data_path
        cls.param_estimate = rospy.ServiceProxy(
            "pgm_learner/discrete/parameter_estimation",
            DiscreteParameterEstimation)
        cls.query = rospy.ServiceProxy(
            "pgm_learner/discrete/query", DiscreteQuery)
        cls.struct_estimate = rospy.ServiceProxy(
            "pgm_learner/discrete/structure_estimation",
            DiscreteStructureEstimation)
        for proxy in (cls.param_estimate, cls.query, cls.struct_estimate):
            proxy.wait_for_service(timeout=cls.SERVICE_TIMEOUT)

    def _load_skeleton(self):
        """Return a GraphSkeleton loaded from the test data file."""
        skel = GraphSkeleton()
        skel.load(self.data_path)
        return skel

    def _sample_graph_states(self, skel, count):
        """Sample ``count`` observations from the teacher network.

        Calls ``skel.toporder()`` (mutating ``skel``) before building the
        DiscreteBayesianNetwork from ``skel`` and the teacher node data.

        Args:
            skel: GraphSkeleton loaded from the test data file.
            count: number of samples to draw.

        Returns:
            A list of DiscreteGraphState messages, one per sample, each holding
            one DiscreteNodeState per (node, state) pair in the sample.
        """
        # NOTE(review): presumably libpgm's sampler needs the vertices in
        # topological order; the original tests always sorted before building.
        skel.toporder()
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        bn = DiscreteBayesianNetwork(skel, teacher_nd)
        states = []
        for sample in bn.randomsample(count):
            gs = DiscreteGraphState()
            gs.node_states = [DiscreteNodeState(node=name, state=value)
                              for name, value in sample.items()]
            states.append(gs)
        return states

    def test_param_estimation(self):
        """Parameter estimation returns one node per vertex of the input graph."""
        skel = self._load_skeleton()
        req = DiscreteParameterEstimationRequest()
        # Captured before _sample_graph_states() reorders skel.V.
        req.graph.nodes = skel.V
        req.graph.edges = [GraphEdge(k, v) for k, v in skel.E]
        expected_nodes = len(skel.V)
        req.states = self._sample_graph_states(skel, 200)

        res = self.param_estimate(req)
        self.assertEqual(len(res.nodes), expected_nodes)

    def test_query(self):
        """Querying "Grade" given Letter=weak returns the Grade node and outcomes."""
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        req = DiscreteQueryRequest()
        req.nodes = U.discrete_nodes_to_ros(teacher_nd.Vdata)
        req.evidence = [DiscreteNodeState(node="Letter", state="weak")]
        req.query = ["Grade"]

        res = self.query(req)
        self.assertEqual(len(res.nodes), 1)
        n = res.nodes[0]
        self.assertEqual(n.name, "Grade")
        self.assertListEqual(['A', 'B', 'C'], n.outcomes)

    def test_structure_estimation(self):
        """Structure estimation recovers a graph over all vertices with some edges."""
        skel = self._load_skeleton()
        expected_nodes = len(skel.V)
        req = DiscreteStructureEstimationRequest()
        # Larger sample than test_param_estimation; presumably so the learner
        # has enough data to recover edges.
        req.states = self._sample_graph_states(skel, 8000)

        res = self.struct_estimate(req)
        self.assertIsNotNone(res.graph)
        self.assertEqual(len(res.graph.nodes), expected_nodes)
        self.assertGreater(len(res.graph.edges), 0)
00094
if __name__ == '__main__':
    # One name serves both as the ROS node name and as rostest's test name.
    node_name = "test_discrete_bayesian_network_learner"
    rospy.init_node(node_name)
    # rostest is only needed when this file is executed directly.
    import rostest
    rostest.rosrun("pgm_learner", node_name, TestDiscreteBNLearnerNode)
00100