test_lg_bn.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 # Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>
00004 
00005 import os
00006 import unittest
00007 import rospy
00008 import rospkg
00009 from pgm_learner.srv import (LinearGaussianParameterEstimation,
00010                              LinearGaussianParameterEstimationRequest,
00011                              LinearGaussianStructureEstimation,
00012                              LinearGaussianStructureEstimationRequest,
00013                              )
00014 from pgm_learner.msg import GraphEdge, LinearGaussianGraphState, LinearGaussianNodeState
00015 
00016 from libpgm.graphskeleton import GraphSkeleton
00017 from libpgm.nodedata import NodeData
00018 from libpgm.lgbayesiannetwork import LGBayesianNetwork
00019 
00020 
class TestLGBNLearnerNode(unittest.TestCase):
    """Integration tests for the pgm_learner linear Gaussian BN services.

    Each test draws random samples from a teacher linear Gaussian Bayesian
    network (structure loaded from ``test/graph-test.txt``, node data from
    ``test/graph-lg-test.txt`` in the ``pgm_learner`` package) and sends them
    to the learner node's parameter- or structure-estimation service.
    """

    def setUp(self):
        """Resolve fixture paths and connect to both learner services.

        Done per test in setUp rather than in ``__init__``: unittest builds
        every TestCase instance while *loading* the suite, so an exception
        raised there (e.g. a ``wait_for_service`` timeout) would abort loading
        instead of being reported as an error of this test.
        """
        self.pkg = rospkg.RosPack()
        test_dir = os.path.join(self.pkg.get_path("pgm_learner"), "test")
        self.data_path = os.path.join(test_dir, "graph-test.txt")
        self.teacher_data_path = os.path.join(test_dir, "graph-lg-test.txt")
        self.param_estimate = rospy.ServiceProxy(
            "pgm_learner/linear_gaussian/parameter_estimation",
            LinearGaussianParameterEstimation)
        self.struct_estimate = rospy.ServiceProxy(
            "pgm_learner/linear_gaussian/structure_estimation",
            LinearGaussianStructureEstimation)
        self.param_estimate.wait_for_service(timeout=30)
        self.struct_estimate.wait_for_service(timeout=30)

    def _load_skeleton(self):
        """Return a GraphSkeleton loaded from the teacher structure file."""
        skel = GraphSkeleton()
        skel.load(self.data_path)
        return skel

    def _sample_states(self, skel, num_samples):
        """Sample the teacher network and convert the samples to ROS messages.

        Calls ``skel.toporder()`` (mutating *skel*) before building the
        LGBayesianNetwork from *skel* and the teacher NodeData file.

        Returns a list with one LinearGaussianGraphState per sample, each
        holding one LinearGaussianNodeState per (node, value) pair of the
        sample dict returned by ``randomsample``.
        """
        skel.toporder()
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        bn = LGBayesianNetwork(skel, teacher_nd)
        return [
            LinearGaussianGraphState(node_states=[
                LinearGaussianNodeState(node=node, state=value)
                for node, value in sample.items()])
            for sample in bn.randomsample(num_samples)]

    def test_param_estimation(self):
        """Parameter estimation returns one result per vertex of the graph."""
        req = LinearGaussianParameterEstimationRequest()

        # Send the known structure together with the sampled data.
        skel = self._load_skeleton()
        req.graph.nodes = skel.V
        req.graph.edges = [GraphEdge(k, v) for k, v in skel.E]
        req.states = self._sample_states(skel, 200)

        # Expected count comes from the fixture rather than a magic number.
        self.assertEqual(len(self.param_estimate(req).nodes), len(skel.V))

    def test_structure_estimation(self):
        """Structure estimation recovers as many nodes/edges as the teacher."""
        req = LinearGaussianStructureEstimationRequest()

        skel = self._load_skeleton()
        # Many more samples than the parameter test (8000 vs 200), presumably
        # so the structure search recovers the teacher graph reliably.
        req.states = self._sample_states(skel, 8000)

        res = self.struct_estimate(req)
        self.assertIsNotNone(res.graph)
        self.assertEqual(len(res.graph.nodes), len(skel.V))
        self.assertEqual(len(res.graph.edges), len(skel.E))
00076 
00077 
if __name__ == '__main__':
    # Register this process as a ROS node, then hand the TestCase to rostest.
    node_name = "test_linear_gaussian_bayesian_network_learner"
    rospy.init_node(node_name)
    import rostest
    rostest.rosrun("pgm_learner", node_name, TestLGBNLearnerNode)
00083 


pgm_learner
Author(s): Yuki Furuta
autogenerated on Wed Jul 10 2019 03:24:11