data_collection_server.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 
7 import os
8 import os.path as osp
9 import pickle as pkl
10 import sys
11 
12 import numpy as np
13 import PIL.Image
14 import yaml
15 
16 import cv_bridge
17 import dynamic_reconfigure.server
18 import genpy
19 from jsk_topic_tools.log_utils import jsk_logfatal
20 import message_filters
21 import roslib.message
22 import rospy
23 from std_srvs.srv import Trigger
24 from std_srvs.srv import TriggerResponse
25 
26 from jsk_data.cfg import DataCollectionServerConfig
27 
28 
def dump_ndarray(filename, arr):
    """Save a numpy array to ``filename``; the format is chosen by extension.

    Supported extensions:
      - ``.pkl``: pickled with :mod:`pickle`.
      - ``.npz``: written with :func:`numpy.savez_compressed` (the array is
        stored under numpy's default positional key ``arr_0``).
      - ``.png`` / ``.jpg``: written as an image via ``PIL.Image.fromarray``.

    Raises:
        ValueError: if the extension is not one of the above.
    """
    ext = osp.splitext(filename)[1]
    if ext == '.pkl':
        # Close (and flush) the file deterministically instead of leaving
        # the handle to the garbage collector.
        with open(filename, 'wb') as f:
            pkl.dump(arr, f)
    elif ext == '.npz':
        np.savez_compressed(filename, arr)
    elif ext in ('.png', '.jpg'):
        PIL.Image.fromarray(arr).save(filename)
    else:
        raise ValueError(
            'Unsupported file extension {!r} for {}; expected one of '
            '.pkl, .npz, .png, .jpg'.format(ext, filename))
39 
40 
class DataCollectionServer(object):

    """Server to collect data.

    Subscribes to the topics listed in ``~topics`` and writes the received
    messages, together with the rosparams listed in ``~params``, into a
    directory under ``save_dir`` (set through dynamic_reconfigure).

    How a save is triggered depends on ``~method`` (default ``request``)
    and ``~message_filters`` (default false):

    - ``request``: save when the ``~save_request`` Trigger service is called.
    - ``timer``: save every ``1 / ~hz`` seconds (``~hz`` default 1.0)
      between calls to ``~start_request`` and ``~end_request``.
    - anything else (e.g. the deprecated ``~method: message_filters`` or
      ``~with_request: false``): requires ``~message_filters: true`` and
      saves every synchronized set of messages as it arrives.

    With ``~message_filters: true`` the topics are synchronized with
    ``message_filters`` (``~queue_size`` default 10, ``~approximate_sync``
    default false, using ``~slop``). Without it, the latest message of each
    topic is kept and, on save, every topic must have a message stamped
    within ``~slop`` seconds (default 0.1) of the save time.

    When ``~timestamp_save_dir`` is true (default), each save goes into
    ``save_dir/<stamp in nanoseconds>``; otherwise directly into
    ``save_dir``.

    <rosparam>
    save_dir: ~/.ros
    topics:
      - name: /camera/rgb/image_raw
        msg_class: sensor_msgs/Image
        fname: image.png
        savetype: ColorImage
      - name: /camera/depth/image_raw
        msg_class: sensor_msgs/Image
        fname: depth.pkl
        savetype: DepthImage
    params:
      - key: /in_hand_data_collection_main/object
        fname: label.txt
        savetype: Text
    </rosparam>
    """

    def __init__(self):
        # reconfig_cb sets self.save_dir. Keep a reference to the reconfigure
        # server so its lifetime is tied to this object.
        self._reconfig_server = dynamic_reconfigure.server.Server(
            DataCollectionServerConfig, self.reconfig_cb)
        # Latest message per topic name: {'stamp': ..., 'msg': ...}.
        self.msg = {}

        self.topics = rospy.get_param('~topics', [])
        self._validate_entries(
            self.topics, ['name', 'msg_class', 'fname', 'savetype'], 'topic')
        self.params = rospy.get_param('~params', [])
        self._validate_entries(
            self.params, ['key', 'fname', 'savetype'], 'param')
        self.slop = rospy.get_param('~slop', 0.1)

        method = rospy.get_param('~method', 'request')
        use_message_filters = rospy.get_param('~message_filters', False)
        self.timestamp_save_dir = rospy.get_param('~timestamp_save_dir', True)

        # Backward compatibility with deprecated parameters.
        if rospy.has_param('~with_request'):
            rospy.logwarn('Deprecated param: ~with_request, Use ~method')
            if not rospy.get_param('~with_request'):
                use_message_filters = True
                method = None
        if method == 'message_filters':
            rospy.logwarn(
                'Deprecated param: ~method: message_filters, '
                'Use ~message_filters: true')
            use_message_filters = True
            method = None

        # One subscriber per topic, in the same order as self.topics; the
        # synchronizer below relies on this order (see sync_sub_cb).
        self.subs = [self._subscribe(topic, use_message_filters)
                     for topic in self.topics]

        if use_message_filters:
            queue_size = rospy.get_param('~queue_size', 10)
            if rospy.get_param('~approximate_sync', False):
                self.sync = message_filters.ApproximateTimeSynchronizer(
                    self.subs, queue_size=queue_size, slop=self.slop)
            else:
                self.sync = message_filters.TimeSynchronizer(
                    self.subs, queue_size=queue_size)

        # Select how saving is triggered.
        if method == 'request':
            if use_message_filters:
                self.sync.registerCallback(self.sync_sub_cb)
                self.server = rospy.Service(
                    '~save_request', Trigger, self.sync_service_cb)
            else:
                self.server = rospy.Service(
                    '~save_request', Trigger, self.service_cb)
        elif method == 'timer':
            duration = rospy.Duration(1.0 / rospy.get_param('~hz', 1.0))
            self.start = False
            self.start_server = rospy.Service(
                '~start_request', Trigger, self.start_service_cb)
            self.end_server = rospy.Service(
                '~end_request', Trigger, self.end_service_cb)
            if use_message_filters:
                self.sync.registerCallback(self.sync_sub_cb)
                self.timer = rospy.Timer(duration, self.sync_timer_cb)
            else:
                self.timer = rospy.Timer(duration, self.timer_cb)
        elif use_message_filters:
            self.sync.registerCallback(self.sync_sub_and_save_cb)
        else:
            rospy.logerr(
                '~message_filters: False, ~method: {} is not supported'
                .format(method))
            sys.exit(1)

    @staticmethod
    def _missing_fields(entry, required_fields):
        """Return the names in ``required_fields`` that ``entry`` lacks."""
        return [field for field in required_fields if field not in entry]

    def _validate_entries(self, entries, required_fields, kind):
        """Exit the process if any entry lacks a required field.

        ``kind`` ('topic' or 'param') is only used in the fatal log message.
        """
        for entry in entries:
            missing = self._missing_fields(entry, required_fields)
            if missing:
                jsk_logfatal("Required field '{}' for {} is missing"
                             .format(missing[0], kind))
                sys.exit(1)

    def _subscribe(self, topic, use_message_filters):
        """Create the subscriber for one ``~topics`` entry."""
        msg_class = roslib.message.get_message_class(topic['msg_class'])
        if msg_class is None:
            # Fail early with a clear message instead of an obscure error
            # from the Subscriber constructor.
            jsk_logfatal("Failed to load message class '{}' for topic '{}'"
                         .format(topic['msg_class'], topic['name']))
            sys.exit(1)
        if use_message_filters:
            return message_filters.Subscriber(topic['name'], msg_class)
        return rospy.Subscriber(topic['name'], msg_class, self.sub_cb,
                                callback_args=topic['name'])

    def reconfig_cb(self, config, level):
        """dynamic_reconfigure callback: update and create ``save_dir``."""
        self.save_dir = osp.expanduser(config['save_dir'])
        if not osp.exists(self.save_dir):
            os.makedirs(self.save_dir)
        return config

    def __del__(self):
        # __init__ may call sys.exit before self.subs has been assigned.
        for sub in getattr(self, 'subs', []):
            sub.unregister()

    def sync_sub_cb(self, *msgs):
        """Store a synchronized set of messages, ordered like ``self.topics``."""
        for topic, msg in zip(self.topics, msgs):
            self.msg[topic['name']] = {
                'stamp': msg.header.stamp,
                'msg': msg
            }

    def sync_sub_and_save_cb(self, *msgs):
        """Store a synchronized set of messages and save it immediately."""
        self.sync_sub_cb(*msgs)
        self._sync_save()

    def sub_cb(self, msg, topic_name):
        """Store the latest message of ``topic_name``.

        Messages without a header are stamped with the receive time.
        """
        self.msg[topic_name] = {
            'stamp': msg.header.stamp if msg._has_header else rospy.Time.now(),
            'msg': msg
        }

    def save_topic(self, topic, msg, savetype, filename):
        """Write one topic message to ``filename`` according to ``savetype``.

        ``ColorImage`` is converted to rgb8; ``DepthImage`` and ``LabelImage``
        keep their encoding. All three are written with ``dump_ndarray``.
        ``YAML`` writes the message's string representation. ``topic`` is
        unused and kept for interface compatibility.

        Raises:
            ValueError: for an unknown ``savetype``.
        """
        if savetype == 'ColorImage':
            bridge = cv_bridge.CvBridge()
            dump_ndarray(filename, bridge.imgmsg_to_cv2(msg, 'rgb8'))
        elif savetype in ('DepthImage', 'LabelImage'):
            bridge = cv_bridge.CvBridge()
            dump_ndarray(filename, bridge.imgmsg_to_cv2(msg))
        elif savetype == 'YAML':
            msg_yaml = genpy.message.strify_message(msg)
            with open(filename, 'w') as f:
                f.write(msg_yaml)
        else:
            rospy.logerr('Unexpected savetype for topic: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for topic: {}'.format(savetype))

    def save_param(self, param, savetype, filename):
        """Write the value of rosparam ``param`` to ``filename``.

        ``Text`` writes ``str(value)``; ``YAML`` writes it with
        ``yaml.safe_dump``.

        Raises:
            ValueError: for an unknown ``savetype``.
        """
        value = rospy.get_param(param)
        if savetype == 'Text':
            with open(filename, 'w') as f:
                f.write(str(value))
        elif savetype == 'YAML':
            content = yaml.safe_dump(value, allow_unicode=True,
                                     default_flow_style=False)
            with open(filename, 'w') as f:
                f.write(content)
        else:
            rospy.logerr('Unexpected savetype for param: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for param: {}'.format(savetype))

    def _make_save_dir(self, stamp):
        """Return the directory for a save stamped ``stamp``."""
        if self.timestamp_save_dir:
            return osp.join(self.save_dir, str(stamp.to_nsec()))
        return self.save_dir

    def _write(self, save_dir, msgs):
        """Write every topic message in ``msgs`` and every param to ``save_dir``.

        ``msgs`` maps topic name to message and must cover all topics.

        Returns:
            (True, str): success flag and a log message.
        """
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        for topic in self.topics:
            filename = osp.join(save_dir, topic['fname'])
            self.save_topic(topic['name'], msgs[topic['name']],
                            topic['savetype'], filename)
        for param in self.params:
            filename = osp.join(save_dir, param['fname'])
            self.save_param(param['key'], param['savetype'], filename)
        msg = 'Saved data to {}'.format(save_dir)
        rospy.loginfo(msg)
        return True, msg

    def _sync_save(self):
        """Save the latest synchronized messages and the configured params.

        The directory is named after the first topic's stamp (or the current
        time when no topics are configured).

        Returns:
            (bool, str): success flag and a human-readable message.
        """
        for topic in self.topics:
            if topic['name'] not in self.msg:
                msg = 'no message received yet for topic [{}]'.format(
                    topic['name'])
                rospy.logwarn(msg)
                return False, msg
        if self.topics:
            stamp = self.msg[self.topics[0]['name']]['stamp']
        else:
            stamp = rospy.Time.now()
        msgs = dict((topic['name'], self.msg[topic['name']]['msg'])
                    for topic in self.topics)
        return self._write(self._make_save_dir(stamp), msgs)

    @staticmethod
    def _classify_stamp(now, stamp, slop):
        """Decide what to do with a message stamped ``stamp`` at save time ``now``.

        Returns 'accept' when ``|now - stamp| < slop``. Returns 'timeout'
        when ``stamp`` is already later than ``now`` by at least ``slop``;
        assuming stamps only increase, no later message can match. Returns
        'wait' otherwise (the message is too old; wait for a newer one).
        """
        if abs(now - stamp) < slop:
            return 'accept'
        if now < stamp:
            return 'timeout'
        return 'wait'

    def _save(self):
        """Wait for a message per topic within ``~slop`` of now, then save.

        Returns:
            (bool, str): success flag and a human-readable message.
        """
        now = rospy.Time.now()
        slop = rospy.Duration(self.slop)
        saving_msgs = {}
        while len(saving_msgs) < len(self.topics):
            for topic in self.topics:
                name = topic['name']
                if name in saving_msgs or name not in self.msg:
                    continue
                # Read the entry once so stamp and msg come from the same
                # message even if the subscriber callback replaces it.
                entry = self.msg[name]
                state = self._classify_stamp(now, entry['stamp'], slop)
                if state == 'accept':
                    saving_msgs[name] = entry['msg']
                elif state == 'timeout':
                    msg = 'timeout for topic [{}]. try bigger slop'.format(
                        name)
                    rospy.logerr(msg)
                    return False, msg
            rospy.sleep(0.01)
        return self._write(self._make_save_dir(now), saving_msgs)

    def start_service_cb(self, req):
        """Enable timer-driven saving."""
        self.start = True
        return TriggerResponse(success=True)

    def end_service_cb(self, req):
        """Disable timer-driven saving."""
        self.start = False
        return TriggerResponse(success=True)

    def service_cb(self, req):
        """``~save_request`` handler without message_filters."""
        result, msg = self._save()
        return TriggerResponse(success=result, message=msg)

    def sync_service_cb(self, req):
        """``~save_request`` handler with message_filters."""
        result, msg = self._sync_save()
        return TriggerResponse(success=result, message=msg)

    def timer_cb(self, event):
        """Timer handler without message_filters; saves only when started."""
        if self.start:
            self._save()

    def sync_timer_cb(self, event):
        """Timer handler with message_filters; saves only when started."""
        if self.start:
            self._sync_save()
302 
303 
if __name__ == '__main__':
    rospy.init_node('data_collection_server')
    # Keep a reference so the server, its subscribers and its services stay
    # alive while the node spins.
    server = DataCollectionServer()
    rospy.spin()
def save_topic(self, topic, msg, savetype, filename)
def save_param(self, param, savetype, filename)
def dump_ndarray(filename, arr)


jsk_data
Author(s):
autogenerated on Tue Feb 6 2018 03:45:36