test_discrete_bn.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 # Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>
00004 
00005 from __future__ import print_function
00006 import os
00007 import sys
00008 import unittest
00009 import rospy
00010 import rospkg
00011 from pgm_learner.srv import (DiscreteParameterEstimation,
00012                              DiscreteParameterEstimationRequest,
00013                              DiscreteQuery,
00014                              DiscreteQueryRequest,
00015                              DiscreteStructureEstimation,
00016                              DiscreteStructureEstimationRequest,
00017                              )
00018 from pgm_learner.msg import GraphEdge, DiscreteGraphState, DiscreteNodeState
00019 import pgm_learner.msg_utils as U
00020 
00021 from libpgm.graphskeleton import GraphSkeleton
00022 from libpgm.nodedata import NodeData
00023 from libpgm.discretebayesiannetwork import DiscreteBayesianNetwork
00024 
00025 
class TestDiscreteBNLearnerNode(unittest.TestCase):
    """Integration tests for the pgm_learner discrete Bayesian network services.

    Each test builds a request, sends it to one of the three services
    (parameter estimation, query, structure estimation) through a
    ``rospy.ServiceProxy`` and checks the shape of the response.
    The reference network is read from ``test/graph-test.txt`` inside the
    ``pgm_learner`` package.
    """

    def __init__(self, arg):
        """Create the service proxies and wait until every service is up.

        Args:
            arg: name of the test method to run, forwarded to
                ``unittest.TestCase.__init__``.
        """
        # Name the class explicitly: ``super(self.__class__, self)`` resolves
        # to the subclass when this test case is subclassed and then recurses
        # forever.
        super(TestDiscreteBNLearnerNode, self).__init__(arg)
        self.pkg = rospkg.RosPack()
        self.data_path = os.path.join(
            self.pkg.get_path("pgm_learner"), "test", "graph-test.txt")
        # The same file is used both as graph skeleton and as teacher node data.
        self.teacher_data_path = self.data_path
        self.param_estimate = rospy.ServiceProxy(
            "pgm_learner/discrete/parameter_estimation",
            DiscreteParameterEstimation)
        self.query = rospy.ServiceProxy(
            "pgm_learner/discrete/query", DiscreteQuery)
        self.struct_estimate = rospy.ServiceProxy(
            "pgm_learner/discrete/structure_estimation",
            DiscreteStructureEstimation)
        # Wait up to 30 seconds for each service before any test runs.
        self.param_estimate.wait_for_service(timeout=30)
        self.query.wait_for_service(timeout=30)
        self.struct_estimate.wait_for_service(timeout=30)

    def _load_skeleton(self):
        """Return a GraphSkeleton loaded from ``self.data_path`` (not yet ordered)."""
        skel = GraphSkeleton()
        skel.load(self.data_path)
        return skel

    def _sample_states(self, skel, num_samples):
        """Draw random samples from the teacher network as ROS graph states.

        Args:
            skel: GraphSkeleton on which ``toporder()`` has already been called.
            num_samples: number of samples to draw.

        Returns:
            A list with one ``DiscreteGraphState`` per sample; each holds one
            ``DiscreteNodeState`` per (node, state) pair of that sample.
        """
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        bn = DiscreteBayesianNetwork(skel, teacher_nd)
        states = []
        for sample in bn.randomsample(num_samples):
            gs = DiscreteGraphState()
            gs.node_states = [DiscreteNodeState(node=name, state=value)
                              for name, value in sample.items()]
            states.append(gs)
        return states

    def test_param_estimation(self):
        """Parameter estimation on 200 samples returns five nodes."""
        req = DiscreteParameterEstimationRequest()

        # load graph structure
        skel = self._load_skeleton()
        req.graph.nodes = skel.V
        req.graph.edges = [GraphEdge(k, v) for k, v in skel.E]
        skel.toporder()

        # generate trial data
        req.states = self._sample_states(skel, 200)

        self.assertEqual(len(self.param_estimate(req).nodes), 5)

    def test_query(self):
        """Querying "Grade" given Letter == "weak" returns that single node."""
        teacher_nd = NodeData()
        teacher_nd.load(self.teacher_data_path)
        req = DiscreteQueryRequest()
        req.nodes = U.discrete_nodes_to_ros(teacher_nd.Vdata)
        req.evidence = [DiscreteNodeState("Letter", "weak")]
        req.query = ["Grade"]
        res = self.query(req)
        self.assertEqual(len(res.nodes), 1)
        n = res.nodes[0]
        self.assertEqual(n.name, "Grade")
        self.assertListEqual(['A', 'B', 'C'], n.outcomes)

    def test_structure_estimation(self):
        """Structure estimation on 8000 samples yields five nodes and some edges."""
        req = DiscreteStructureEstimationRequest()

        skel = self._load_skeleton()
        skel.toporder()
        req.states = self._sample_states(skel, 8000)

        res = self.struct_estimate(req)
        self.assertIsNotNone(res.graph)
        self.assertEqual(len(res.graph.nodes), 5)
        self.assertGreater(len(res.graph.edges), 0)
00094 
if __name__ == '__main__':
    # The same identifier is used as the ROS node name and as the rostest
    # test name.
    node_name = "test_discrete_bayesian_network_learner"
    rospy.init_node(node_name)

    # rostest is only needed when the file is executed as a script.
    import rostest
    rostest.rosrun("pgm_learner", node_name, TestDiscreteBNLearnerNode)
00100 


pgm_learner
Author(s): Yuki Furuta
autogenerated on Wed Jul 10 2019 03:24:11