test_discrete_bn.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>
4 
5 from __future__ import print_function
6 import os
7 import sys
8 import unittest
9 import rospy
10 import rospkg
11 from pgm_learner.srv import (DiscreteParameterEstimation,
12  DiscreteParameterEstimationRequest,
13  DiscreteQuery,
14  DiscreteQueryRequest,
15  DiscreteStructureEstimation,
16  DiscreteStructureEstimationRequest,
17  )
18 from pgm_learner.msg import GraphEdge, DiscreteGraphState, DiscreteNodeState
19 import pgm_learner.msg_utils as U
20 
21 from libpgm.graphskeleton import GraphSkeleton
22 from libpgm.nodedata import NodeData
23 from libpgm.discretebayesiannetwork import DiscreteBayesianNetwork
24 
25 
class TestDiscreteBNLearnerNode(unittest.TestCase):
    """Integration tests for the pgm_learner discrete Bayesian-network services.

    The tests call the ``pgm_learner/discrete/{parameter_estimation,query,
    structure_estimation}`` ROS services, so those services must become
    available within the 30 s timeout used in ``__init__``.
    """

    def __init__(self, arg):
        """Create service proxies and wait for the learner services.

        :param arg: test method name, forwarded to ``unittest.TestCase``.
        """
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(TestDiscreteBNLearnerNode, self).__init__(arg)
        self.pkg = rospkg.RosPack()
        test_dir = os.path.join(self.pkg.get_path("pgm_learner"), "test")
        # Graph skeleton loaded via GraphSkeleton.load (nodes "V", edges "E").
        self.data_path = os.path.join(test_dir, "graph-test.txt")
        # NOTE(review): the original definition of teacher_data_path was lost
        # when this file was extracted from its documentation page. It is
        # assumed here to be the same file as data_path (libpgm files can hold
        # the node data alongside the skeleton) -- confirm against the
        # package's test directory.
        self.teacher_data_path = os.path.join(test_dir, "graph-test.txt")

        self.param_estimate = rospy.ServiceProxy("pgm_learner/discrete/parameter_estimation", DiscreteParameterEstimation)
        self.query = rospy.ServiceProxy("pgm_learner/discrete/query", DiscreteQuery)
        self.struct_estimate = rospy.ServiceProxy("pgm_learner/discrete/structure_estimation", DiscreteStructureEstimation)
        # unittest builds one instance per test method, so this wait runs
        # before every test; each proxy gets up to 30 s.
        self.param_estimate.wait_for_service(timeout=30)
        self.query.wait_for_service(timeout=30)
        self.struct_estimate.wait_for_service(timeout=30)

    def _load_skeleton(self):
        """Return a ``GraphSkeleton`` loaded from ``data_path``."""
        skel = GraphSkeleton()
        skel.load(self.data_path)
        return skel

    def _load_teacher_nodedata(self):
        """Return the reference ``NodeData`` loaded from ``teacher_data_path``."""
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        return teacher_nd

    def _sample_graph_states(self, skel, count):
        """Sample ``count`` observations from the teacher network.

        :param skel: skeleton of the teacher network, already ``toporder()``-ed.
        :param count: number of samples to draw.
        :returns: list of ``DiscreteGraphState``, one per sample, each holding
            one ``DiscreteNodeState`` per (node, state) pair of the sample.
        """
        bn = DiscreteBayesianNetwork(skel, self._load_teacher_nodedata())
        return [
            DiscreteGraphState(node_states=[
                DiscreteNodeState(node=node, state=state)
                for node, state in sample.items()
            ])
            for sample in bn.randomsample(count)
        ]

    def test_param_estimation(self):
        """Parameter estimation over 200 samples returns 5 nodes."""
        req = DiscreteParameterEstimationRequest()

        # load graph structure into the request (before toporder(), as the
        # original test did)
        skel = self._load_skeleton()
        req.graph.nodes = skel.V
        req.graph.edges = [GraphEdge(k, v) for k, v in skel.E]
        skel.toporder()

        # generate trial data from the teacher network
        req.states.extend(self._sample_graph_states(skel, 200))

        self.assertEqual(len(self.param_estimate(req).nodes), 5)

    def test_query(self):
        """Querying "Grade" given Letter=weak returns only the Grade node."""
        teacher_nd = self._load_teacher_nodedata()
        req = DiscreteQueryRequest()
        req.nodes = U.discrete_nodes_to_ros(teacher_nd.Vdata)
        req.evidence = [DiscreteNodeState("Letter", "weak")]
        req.query = ["Grade"]
        res = self.query(req)
        self.assertEqual(len(res.nodes), 1)
        n = res.nodes[0]
        self.assertEqual(n.name, "Grade")
        self.assertListEqual(['A', 'B', 'C'], n.outcomes)

    def test_structure_estimation(self):
        """Structure estimation over 8000 samples yields 5 nodes and >0 edges."""
        req = DiscreteStructureEstimationRequest()

        skel = self._load_skeleton()
        skel.toporder()
        req.states.extend(self._sample_graph_states(skel, 8000))

        res = self.struct_estimate(req)
        self.assertIsNotNone(res.graph)
        self.assertEqual(len(res.graph.nodes), 5)
        self.assertGreater(len(res.graph.edges), 0)
94 
if __name__ == '__main__':
    # Register this process as a ROS node named after the test, then hand
    # the TestCase class to rostest for execution and result reporting.
    node_name = "test_discrete_bayesian_network_learner"
    rospy.init_node(node_name)

    import rostest
    rostest.rosrun("pgm_learner", node_name, TestDiscreteBNLearnerNode)
100 


pgm_learner
Author(s): Yuki Furuta
autogenerated on Tue May 11 2021 02:55:44