5 from __future__
import print_function
11 from pgm_learner.srv
import (DiscreteParameterEstimation,
12 DiscreteParameterEstimationRequest,
15 DiscreteStructureEstimation,
16 DiscreteStructureEstimationRequest,
18 from pgm_learner.msg
import GraphEdge, DiscreteGraphState, DiscreteNodeState
21 from libpgm.graphskeleton
import GraphSkeleton
22 from libpgm.nodedata
import NodeData
23 from libpgm.discretebayesiannetwork
import DiscreteBayesianNetwork
28 super(self.__class__, self).
__init__(arg)
29 self.
pkg = rospkg.RosPack()
30 self.
data_path = os.path.join(self.pkg.get_path(
"pgm_learner"),
"test",
"graph-test.txt")
32 self.
param_estimate = rospy.ServiceProxy(
"pgm_learner/discrete/parameter_estimation", DiscreteParameterEstimation)
33 self.
query = rospy.ServiceProxy(
"pgm_learner/discrete/query", DiscreteQuery)
34 self.
struct_estimate = rospy.ServiceProxy(
"pgm_learner/discrete/structure_estimation", DiscreteStructureEstimation)
35 self.param_estimate.wait_for_service(timeout=30)
36 self.query.wait_for_service(timeout=30)
37 self.struct_estimate.wait_for_service(timeout=30)
39 req = DiscreteParameterEstimationRequest()
42 skel = GraphSkeleton()
44 req.graph.nodes = skel.V
45 req.graph.edges = [GraphEdge(k, v)
for k,v
in skel.E]
49 teacher_nd = NodeData()
51 bn = DiscreteBayesianNetwork(skel, teacher_nd)
52 data = bn.randomsample(200)
54 gs = DiscreteGraphState()
55 for k_s, v_s
in v.items():
56 gs.node_states.append(DiscreteNodeState(node=k_s, state=v_s))
62 teacher_nd = NodeData()
64 req = DiscreteQueryRequest()
65 req.nodes = U.discrete_nodes_to_ros(teacher_nd.Vdata)
66 req.evidence = [DiscreteNodeState(
"Letter",
"weak")]
69 self.assertEqual(len(res.nodes), 1)
71 self.assertEqual(n.name,
"Grade")
72 self.assertListEqual([
'A',
'B',
'C'], n.outcomes)
75 req = DiscreteStructureEstimationRequest()
77 skel = GraphSkeleton()
80 teacher_nd = NodeData()
82 bn = DiscreteBayesianNetwork(skel, teacher_nd)
83 data = bn.randomsample(8000)
85 gs = DiscreteGraphState()
86 for k_s, v_s
in v.items():
87 gs.node_states.append(DiscreteNodeState(node=k_s, state=v_s))
91 self.assertIsNotNone(res.graph)
92 self.assertEqual(len(res.graph.nodes), 5)
93 self.assertGreater(len(res.graph.edges), 0)
95 if __name__ ==
'__main__':
96 test_name =
"test_discrete_bayesian_network_learner" 97 rospy.init_node(test_name)
99 rostest.rosrun(
"pgm_learner", test_name, TestDiscreteBNLearnerNode)
def test_param_estimation(self)
def test_structure_estimation(self)