data_collection_server.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006 
00007 import os
00008 import os.path as osp
00009 import pickle as pkl
00010 import sys
00011 
00012 import numpy as np
00013 import PIL.Image
00014 import yaml
00015 
00016 import cv_bridge
00017 import dynamic_reconfigure.server
00018 import genpy
00019 from jsk_topic_tools.log_utils import jsk_logfatal
00020 import message_filters
00021 import roslib.message
00022 import rospy
00023 from std_srvs.srv import Trigger
00024 from std_srvs.srv import TriggerResponse
00025 
00026 from jsk_data.cfg import DataCollectionServerConfig
00027 
00028 
def dump_ndarray(filename, arr):
    """Save ``arr`` to ``filename`` in a format chosen by the file extension.

    Supported extensions:

    - ``.pkl``: pickled with :mod:`pickle`.
    - ``.npz``: :func:`numpy.savez_compressed` (stored under the key
      ``arr_0``).
    - ``.png`` / ``.jpg``: written as an image via ``PIL.Image.fromarray``.

    Args:
        filename: Destination path; its extension selects the format.
        arr: The array to save.

    Raises:
        ValueError: If the extension is not one of the supported ones.
    """
    ext = osp.splitext(filename)[1]
    if ext == '.pkl':
        # Context manager so the file is flushed and closed right away
        # instead of whenever the file object happens to be collected.
        with open(filename, 'wb') as f:
            pkl.dump(arr, f)
    elif ext == '.npz':
        np.savez_compressed(filename, arr)
    elif ext in ('.png', '.jpg'):
        PIL.Image.fromarray(arr).save(filename)
    else:
        raise ValueError(
            'Unsupported file extension {!r} for {!r}; expected one of '
            '.pkl, .npz, .png, .jpg'.format(ext, filename))
00039 
00040 
class DataCollectionServer(object):

    """Server to collect data.

    Each save writes one file per configured topic and one file per
    configured parameter into ``<save_dir>/<stamp in nanoseconds>/``.

    The operating mode is selected by the private parameter
    ``~with_request`` (default True):

    * ``True``: every topic is subscribed independently and the latest
      message of each is cached. Calling the ``~save_request``
      (std_srvs/Trigger) service saves, for every topic, a cached message
      whose stamp is within ``~slop`` seconds (default 0.1) of the request
      time. Messages without a header are stamped with their arrival time.
    * ``False``: the topics are synchronized with
      ``message_filters.TimeSynchronizer`` (``~queue_size``, default 10).
      Every synchronized set is saved, in a directory named after the
      stamp of the first message.

    ``save_dir`` comes from the dynamic_reconfigure config
    (``DataCollectionServerConfig``). ``~`` in it is expanded, and the
    directory is created when missing.

    Supported ``savetype`` values:

    * topics: ``ColorImage`` (converted to rgb8), ``DepthImage``,
      ``LabelImage`` and ``YAML``;
    * params: ``Text`` and ``YAML``.

      <rosparam>
        save_dir: ~/.ros
        topics:
          - name: /camera/rgb/image_raw
            msg_class: sensor_msgs/Image
            fname: image.png
            savetype: ColorImage
          - name: /camera/depth/image_raw
            msg_class: sensor_msgs/Image
            fname: depth.pkl
            savetype: DepthImage
        params:
          - key: /in_hand_data_collection_main/object
            fname: label.txt
            savetype: Text
      </rosparam>
    """

    def __init__(self):
        # Keep a reference to the reconfigure server. reconfig_cb is the only
        # place that sets self.save_dir; dynamic_reconfigure is expected to
        # call it with the initial config when the server is constructed.
        self.reconfig_server = dynamic_reconfigure.server.Server(
            DataCollectionServerConfig, self.reconfig_cb)
        # topic name -> {'stamp': ..., 'msg': ...}, filled by sub_cb
        self.msg = {}
        self.topics = rospy.get_param('~topics', [])
        self._validate_fields(
            self.topics, ['name', 'msg_class', 'fname', 'savetype'], 'topic')
        self.params = rospy.get_param('~params', [])
        self.slop = rospy.get_param('~slop', 0.1)
        self._validate_fields(
            self.params, ['key', 'fname', 'savetype'], 'param')

        self.subs = []
        with_request = rospy.get_param('~with_request', True)
        for topic in self.topics:
            msg_class = roslib.message.get_message_class(topic['msg_class'])
            if msg_class is None:
                # Fail like the field validation above rather than handing
                # None to the subscriber constructors.
                jsk_logfatal("Unknown msg_class '{}' for topic '{}'"
                             .format(topic['msg_class'], topic['name']))
                sys.exit(1)
            if with_request:
                sub = rospy.Subscriber(topic['name'], msg_class, self.sub_cb,
                                       callback_args=topic['name'])
            else:
                sub = message_filters.Subscriber(topic['name'], msg_class)
            self.subs.append(sub)

        if with_request:
            self.server = rospy.Service('~save_request', Trigger,
                                        self.service_cb)
        else:
            self.sync = message_filters.TimeSynchronizer(
                self.subs, queue_size=rospy.get_param('~queue_size', 10))
            self.sync.registerCallback(self.sync_sub_cb)

    @staticmethod
    def _validate_fields(entries, required_fields, kind):
        """Log fatally and exit(1) if any entry lacks a required field.

        Args:
            entries: Sequence of dict-like entries (``~topics``/``~params``).
            required_fields: Keys every entry must contain.
            kind: Label used in the log message (``'topic'``/``'param'``).
        """
        for entry in entries:
            for field in required_fields:
                if field not in entry:
                    jsk_logfatal("Required field '{}' for {} is missing"
                                 .format(field, kind))
                    sys.exit(1)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (with parents) unless it already exists."""
        if osp.isdir(path):
            return
        try:
            os.makedirs(path)
        except OSError:
            # Tolerate the directory having been created concurrently.
            if not osp.isdir(path):
                raise

    def reconfig_cb(self, config, level):
        """dynamic_reconfigure callback: set (and create) ``self.save_dir``.

        Returns ``config`` unchanged.
        """
        self.save_dir = osp.expanduser(config['save_dir'])
        self._ensure_dir(self.save_dir)
        return config

    def _save_data(self, stamp, msgs):
        """Save topic messages and params under ``save_dir/<stamp nsec>``.

        Args:
            stamp: Time used to name the output directory.
            msgs: One message per entry of ``self.topics``, in the same order.

        Returns:
            The directory the data was written to.
        """
        save_dir = osp.join(self.save_dir, str(stamp.to_nsec()))
        self._ensure_dir(save_dir)
        for topic, msg in zip(self.topics, msgs):
            filename = osp.join(save_dir, topic['fname'])
            self.save_topic(topic['name'], msg, topic['savetype'], filename)
        for param in self.params:
            filename = osp.join(save_dir, param['fname'])
            self.save_param(param['key'], param['savetype'], filename)
        return save_dir

    def sync_sub_cb(self, *msgs):
        """TimeSynchronizer callback: save one synchronized set of messages.

        ``msgs`` arrive in the order of ``self.topics``; the output directory
        is named after the first message's header stamp.
        """
        save_dir = self._save_data(msgs[0].header.stamp, msgs)
        rospy.loginfo('Saved data to %s' % save_dir)

    def __del__(self):
        # __init__ may have exited (sys.exit during validation) before
        # self.subs was created.
        for sub in getattr(self, 'subs', []):
            sub.unregister()

    def sub_cb(self, msg, topic_name):
        """Cache the latest message of ``topic_name`` with its stamp.

        Headerless messages are stamped with the current time.
        """
        self.msg[topic_name] = {
            'stamp': msg.header.stamp if msg._has_header else rospy.Time.now(),
            'msg': msg
            }

    def save_topic(self, topic, msg, savetype, filename):
        """Write ``msg`` to ``filename`` according to ``savetype``.

        ``topic`` is not used by this implementation. Image savetypes are
        converted with cv_bridge and written by ``dump_ndarray``. ``YAML``
        writes the message's string form from
        ``genpy.message.strify_message``.

        Raises:
            ValueError: If ``savetype`` is not supported.
        """
        if savetype in ('ColorImage', 'DepthImage', 'LabelImage'):
            bridge = cv_bridge.CvBridge()
            if savetype == 'ColorImage':
                img = bridge.imgmsg_to_cv2(msg, 'rgb8')
            else:
                # No desired encoding: keep the message's own encoding.
                img = bridge.imgmsg_to_cv2(msg)
            dump_ndarray(filename, img)
        elif savetype == 'YAML':
            msg_yaml = genpy.message.strify_message(msg)
            with open(filename, 'w') as f:
                f.write(msg_yaml)
        else:
            rospy.logerr('Unexpected savetype for topic: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for topic: {}'.format(savetype))

    def save_param(self, param, savetype, filename):
        """Write the value of ROS parameter ``param`` to ``filename``.

        ``Text`` writes ``str(value)``; ``YAML`` writes a block-style
        ``yaml.safe_dump`` of the value.

        Raises:
            ValueError: If ``savetype`` is not supported.
        """
        value = rospy.get_param(param)
        if savetype == 'Text':
            with open(filename, 'w') as f:
                f.write(str(value))
        elif savetype == 'YAML':
            content = yaml.safe_dump(value, allow_unicode=True,
                                     default_flow_style=False)
            with open(filename, 'w') as f:
                f.write(content)
        else:
            rospy.logerr('Unexpected savetype for param: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for param: {}'.format(savetype))

    def service_cb(self, req):
        """``~save_request`` handler: save messages close to the request time.

        The handler blocks until every topic has a cached message stamped
        within ``self.slop`` seconds of the request time. It fails with
        ``success=False`` if a topic's cached message is already newer than
        that window, since a matching message can no longer arrive.
        """
        now = rospy.Time.now()
        slop = rospy.Duration(self.slop)
        saving_msgs = {}
        # Several entries may share a topic name (e.g. one image saved in two
        # formats), so wait for every *distinct* name, not len(self.topics).
        n_names = len(set(topic['name'] for topic in self.topics))
        while True:
            for topic in self.topics:
                name = topic['name']
                if name in saving_msgs:
                    continue
                # Read the entry once: sub_cb may replace it concurrently and
                # the stamp checked must belong to the message saved.
                entry = self.msg.get(name)
                if entry is None:
                    continue  # nothing received on this topic yet; keep waiting
                stamp = entry['stamp']
                if abs(now - stamp) < slop:
                    saving_msgs[name] = entry['msg']
                elif now < stamp:
                    msg = 'timeout for topic [{}]. try bigger slop'.format(
                        name)
                    rospy.logerr(msg)
                    return TriggerResponse(success=False, message=msg)
            if len(saving_msgs) >= n_names:
                break
            rospy.sleep(0.01)
        msgs = [saving_msgs[topic['name']] for topic in self.topics]
        save_dir = self._save_data(now, msgs)
        message = 'Saved data to {}'.format(save_dir)
        rospy.loginfo(message)
        return TriggerResponse(success=True, message=message)
00200 
00201 
if __name__ == '__main__':
    # Register the node first: DataCollectionServer.__init__ reads private
    # (~) parameters and creates subscribers and services.
    rospy.init_node('data_collection_server')
    # Hold the instance in a name so it stays referenced while spinning.
    server = DataCollectionServer()
    # Block and process callbacks until the node is shut down.
    rospy.spin()


jsk_data
Author(s):
autogenerated on Fri Sep 8 2017 03:39:16