test_lg_bn.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>
4 
5 import os
6 import unittest
7 import rospy
8 import rospkg
9 from pgm_learner.srv import (LinearGaussianParameterEstimation,
10  LinearGaussianParameterEstimationRequest,
11  LinearGaussianStructureEstimation,
12  LinearGaussianStructureEstimationRequest,
13  )
14 from pgm_learner.msg import GraphEdge, LinearGaussianGraphState, LinearGaussianNodeState
15 
16 from libpgm.graphskeleton import GraphSkeleton
17 from libpgm.nodedata import NodeData
18 from libpgm.lgbayesiannetwork import LGBayesianNetwork
19 
20 
class TestLGBNLearnerNode(unittest.TestCase):
    """Integration tests for the pgm_learner linear-Gaussian learning services.

    Trial data is generated by sampling a "teacher" linear-Gaussian Bayesian
    network (structure from ``test/graph-test.txt``, parameters from
    ``test/graph-lg-test.txt``). That data is then sent to the parameter- and
    structure-estimation services.
    """

    def __init__(self, arg):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this test case is subclassed.
        super(TestLGBNLearnerNode, self).__init__(arg)
        self.pkg = rospkg.RosPack()
        test_dir = os.path.join(self.pkg.get_path("pgm_learner"), "test")
        # Graph structure (nodes/edges) loaded via GraphSkeleton.
        self.data_path = os.path.join(test_dir, "graph-test.txt")
        # Teacher network parameters loaded via NodeData.
        self.teacher_data_path = os.path.join(test_dir, "graph-lg-test.txt")
        self.param_estimate = rospy.ServiceProxy(
            "pgm_learner/linear_gaussian/parameter_estimation",
            LinearGaussianParameterEstimation)
        self.struct_estimate = rospy.ServiceProxy(
            "pgm_learner/linear_gaussian/structure_estimation",
            LinearGaussianStructureEstimation)
        # Block (up to 30 s each) until the learner node advertises both services.
        self.param_estimate.wait_for_service(timeout=30)
        self.struct_estimate.wait_for_service(timeout=30)

    def _load_skeleton(self):
        """Return a GraphSkeleton loaded from ``self.data_path``."""
        skel = GraphSkeleton()
        skel.load(self.data_path)
        return skel

    def _sample_states(self, skel, num_samples):
        """Sample the teacher network and convert the draws to state messages.

        Args:
            skel: GraphSkeleton of the teacher graph. ``toporder()`` is called
                on it here, which libpgm requires before sampling.
            num_samples: number of joint samples to draw.

        Returns:
            list of LinearGaussianGraphState, one per sample.
        """
        skel.toporder()
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        bn = LGBayesianNetwork(skel, teacher_nd)
        states = []
        for sample in bn.randomsample(num_samples):
            gs = LinearGaussianGraphState()
            for node, value in sample.items():
                gs.node_states.append(
                    LinearGaussianNodeState(node=node, state=value))
            states.append(gs)
        return states

    def test_param_estimation(self):
        """Parameter estimation on a known structure returns one entry per node."""
        req = LinearGaussianParameterEstimationRequest()

        # Describe the known graph structure in the request. This is done before
        # _sample_states() calls toporder(), matching the original ordering.
        skel = self._load_skeleton()
        req.graph.nodes = skel.V
        req.graph.edges = [GraphEdge(k, v) for k, v in skel.E]

        req.states.extend(self._sample_states(skel, 200))

        # NOTE(review): 5 presumably equals the node count in graph-test.txt.
        self.assertEqual(len(self.param_estimate(req).nodes), 5)

    def test_struct_estimation(self):
        """Structure estimation recovers a graph of the expected size."""
        req = LinearGaussianStructureEstimationRequest()

        # Structure learning needs many more samples than parameter estimation.
        req.states.extend(self._sample_states(self._load_skeleton(), 8000))

        res = self.struct_estimate(req)
        self.assertIsNotNone(res.graph)
        # NOTE(review): 5 nodes / 4 edges presumably mirror graph-test.txt.
        self.assertEqual(len(res.graph.nodes), 5)
        self.assertEqual(len(res.graph.edges), 4)
76 
77 
if __name__ == '__main__':
    # The same string serves as both the ROS node name and the rostest test name.
    name = "test_linear_gaussian_bayesian_network_learner"
    # Register this process as a ROS node, then hand the test case to rostest.
    rospy.init_node(name)
    import rostest
    rostest.rosrun("pgm_learner", name, TestLGBNLearnerNode)
83 


pgm_learner
Author(s): Yuki Furuta
autogenerated on Tue May 11 2021 02:55:44