9 from pgm_learner.srv
import (LinearGaussianParameterEstimation,
10 LinearGaussianParameterEstimationRequest,
11 LinearGaussianStructureEstimation,
12 LinearGaussianStructureEstimationRequest,
14 from pgm_learner.msg
import GraphEdge, LinearGaussianGraphState, LinearGaussianNodeState
16 from libpgm.graphskeleton
import GraphSkeleton
17 from libpgm.nodedata
import NodeData
18 from libpgm.lgbayesiannetwork
import LGBayesianNetwork
23 super(self.__class__, self).
__init__(arg)
24 self.
pkg = rospkg.RosPack()
25 self.
data_path = os.path.join(self.pkg.get_path(
"pgm_learner"),
"test",
"graph-test.txt")
26 self.
teacher_data_path = os.path.join(self.pkg.get_path(
"pgm_learner"),
"test",
"graph-lg-test.txt")
27 self.
param_estimate = rospy.ServiceProxy(
"pgm_learner/linear_gaussian/parameter_estimation", LinearGaussianParameterEstimation)
28 self.
struct_estimate = rospy.ServiceProxy(
"pgm_learner/linear_gaussian/structure_estimation", LinearGaussianStructureEstimation)
29 self.param_estimate.wait_for_service(timeout=30)
30 self.struct_estimate.wait_for_service(timeout=30)
33 req = LinearGaussianParameterEstimationRequest()
36 skel = GraphSkeleton()
38 req.graph.nodes = skel.V
39 req.graph.edges = [GraphEdge(k, v)
for k,v
in skel.E]
43 teacher_nd = NodeData()
45 bn = LGBayesianNetwork(skel, teacher_nd)
46 data = bn.randomsample(200)
48 gs = LinearGaussianGraphState()
49 for k_s, v_s
in v.items():
50 gs.node_states.append(LinearGaussianNodeState(node=k_s, state=v_s))
56 req = LinearGaussianStructureEstimationRequest()
59 skel = GraphSkeleton()
62 teacher_nd = NodeData()
64 bn = LGBayesianNetwork(skel, teacher_nd)
65 data = bn.randomsample(8000)
67 gs = LinearGaussianGraphState()
68 for k_s, v_s
in v.items():
69 gs.node_states.append(LinearGaussianNodeState(node=k_s, state=v_s))
73 self.assertIsNotNone(res.graph)
74 self.assertEqual(len(res.graph.nodes), 5)
75 self.assertEqual(len(res.graph.edges), 4)
if __name__ == '__main__':
    # Script entry point: executed only when this file is run directly,
    # not when it is imported.
    # The test name doubles as the ROS node name and the rostest test name.
    test_name = "test_linear_gaussian_bayesian_network_learner"
    # Register a ROS node before running the tests. The test class creates
    # rospy.ServiceProxy objects and waits on the pgm_learner services.
    rospy.init_node(test_name)
    # Hand the TestLGBNLearnerNode test case to rostest under the
    # "pgm_learner" package.
    rostest.rosrun("pgm_learner", test_name, TestLGBNLearnerNode)
def test_structure_estimation(self)
def test_param_estimation(self)