00001
00002
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006
00007 import os
00008 import os.path as osp
00009 import pickle as pkl
00010 import sys
00011
00012 import numpy as np
00013 import PIL.Image
00014 import yaml
00015
00016 import cv_bridge
00017 import dynamic_reconfigure.server
00018 import genpy
00019 from jsk_topic_tools.log_utils import jsk_logfatal
00020 import message_filters
00021 import roslib.message
00022 import rospy
00023 from std_srvs.srv import Trigger
00024 from std_srvs.srv import TriggerResponse
00025
00026 from jsk_data.cfg import DataCollectionServerConfig
00027
00028
def dump_ndarray(filename, arr):
    """Write an array to ``filename`` in a format chosen by its extension.

    Supported extensions:

    - ``.pkl``: pickled with :mod:`pickle`.
    - ``.npz``: compressed NumPy archive (``np.savez_compressed``).
    - ``.png`` / ``.jpg``: image written through ``PIL.Image``.

    Args:
        filename: Destination path; its extension selects the format.
        arr: Array to store.

    Raises:
        ValueError: If the extension is not one of the supported ones.
    """
    ext = osp.splitext(filename)[1]
    if ext == '.pkl':
        # Use a context manager so the file handle is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(filename, 'wb') as f:
            pkl.dump(arr, f)
    elif ext == '.npz':
        np.savez_compressed(filename, arr)
    elif ext in ('.png', '.jpg'):
        PIL.Image.fromarray(arr).save(filename)
    else:
        raise ValueError(
            "Unsupported file extension '{}' for: {}".format(ext, filename))
00039
00040
class DataCollectionServer(object):

    """Server to collect data.

    Each sample is written into its own directory
    ``<save_dir>/<stamp in nanoseconds>/`` and contains one file per
    configured topic and one file per configured parameter.

    Two modes are selected by the ``~with_request`` parameter:

    - true (default): the latest message of every topic is cached and a
      sample is written when the ``~save_request`` (``std_srvs/Trigger``)
      service is called. Messages must be stamped within ``~slop`` seconds
      (default 0.1) of the request time.
    - false: topics are synchronized with
      ``message_filters.TimeSynchronizer`` (queue size ``~queue_size``,
      default 10) and every synchronized set is written.

    <rosparam>
      save_dir: ~/.ros
      topics:
        - name: /camera/rgb/image_raw
          msg_class: sensor_msgs/Image
          fname: image.png
          savetype: ColorImage
        - name: /camera/depth/image_raw
          msg_class: sensor_msgs/Image
          fname: depth.pkl
          savetype: DepthImage
      params:
        - key: /in_hand_data_collection_main/object
          fname: label.txt
          savetype: Text
    </rosparam>
    """

    def __init__(self):
        # Set first so that __del__ is safe even if validation below exits.
        self.subs = []
        self._reconfig_server = dynamic_reconfigure.server.Server(
            DataCollectionServerConfig, self.reconfig_cb)
        # topic name -> {'stamp': ..., 'msg': ...} of the latest message.
        self.msg = {}

        self.topics = rospy.get_param('~topics', [])
        self._require_fields(
            self.topics, ['name', 'msg_class', 'fname', 'savetype'], 'topic')

        self.params = rospy.get_param('~params', [])
        self.slop = rospy.get_param('~slop', 0.1)
        self._require_fields(
            self.params, ['key', 'fname', 'savetype'], 'param')

        if rospy.get_param('~with_request', True):
            for topic in self.topics:
                msg_class = roslib.message.get_message_class(
                    topic['msg_class'])
                sub = rospy.Subscriber(topic['name'], msg_class, self.sub_cb,
                                       callback_args=topic['name'])
                self.subs.append(sub)
            self.server = rospy.Service('~save_request', Trigger,
                                        self.service_cb)
        else:
            for topic in self.topics:
                msg_class = roslib.message.get_message_class(
                    topic['msg_class'])
                sub = message_filters.Subscriber(topic['name'], msg_class)
                self.subs.append(sub)
            self.sync = message_filters.TimeSynchronizer(
                self.subs, queue_size=rospy.get_param('~queue_size', 10))
            self.sync.registerCallback(self.sync_sub_cb)

    @staticmethod
    def _require_fields(entries, required_fields, kind):
        """Exit the process if any entry lacks one of ``required_fields``.

        ``kind`` ('topic' or 'param') is only used in the fatal log message.
        """
        for entry in entries:
            for field in required_fields:
                if field not in entry:
                    jsk_logfatal("Required field '{}' for {} is missing"
                                 .format(field, kind))
                    sys.exit(1)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and parents), tolerating that it already exists."""
        try:
            os.makedirs(path)
        except OSError:
            # Another thread/process may have created it in the meantime;
            # only propagate genuine failures.
            if not osp.isdir(path):
                raise

    def _save_sample(self, stamp, msgs):
        """Write one sample and return the directory it was written to.

        Args:
            stamp: Time used to name the sample directory (``to_nsec()``).
            msgs: Messages ordered like ``self.topics``.
        """
        save_dir = osp.join(self.save_dir, str(stamp.to_nsec()))
        self._ensure_dir(save_dir)
        for topic, msg in zip(self.topics, msgs):
            filename = osp.join(save_dir, topic['fname'])
            self.save_topic(topic['name'], msg, topic['savetype'], filename)
        for param in self.params:
            filename = osp.join(save_dir, param['fname'])
            self.save_param(param['key'], param['savetype'], filename)
        return save_dir

    @staticmethod
    def _timeout_response(topic_name):
        """Log and build the failure response for a topic that timed out."""
        msg = 'timeout for topic [{}]. try bigger slop'.format(topic_name)
        rospy.logerr(msg)
        return TriggerResponse(success=False, message=msg)

    def reconfig_cb(self, config, level):
        """Dynamic-reconfigure callback: update (and create) ``save_dir``."""
        self.save_dir = osp.expanduser(config['save_dir'])
        self._ensure_dir(self.save_dir)
        return config

    def sync_sub_cb(self, *msgs):
        """Save one time-synchronized set of messages.

        ``msgs`` are ordered like ``self.topics``; the header stamp of the
        first message names the sample directory.
        """
        save_dir = self._save_sample(msgs[0].header.stamp, msgs)
        rospy.loginfo('Saved data to %s' % save_dir)

    def __del__(self):
        # ``subs`` is missing when the object was never fully initialized.
        for sub in getattr(self, 'subs', ()):
            sub.unregister()

    def sub_cb(self, msg, topic_name):
        """Cache the latest message of ``topic_name`` with its stamp.

        Messages without a header are stamped with the reception time.
        """
        self.msg[topic_name] = {
            'stamp': msg.header.stamp if msg._has_header else rospy.Time.now(),
            'msg': msg
        }

    def save_topic(self, topic, msg, savetype, filename):
        """Write ``msg`` to ``filename`` according to ``savetype``.

        Args:
            topic: Name of the topic the message came from (unused here).
            msg: Message to save.
            savetype: 'ColorImage' (converted to rgb8), 'DepthImage' or
                'LabelImage' (converted without changing the encoding), or
                'YAML' (string form of the message).
            filename: Destination path.

        Raises:
            ValueError: If ``savetype`` is not supported.
        """
        if savetype == 'ColorImage':
            img = cv_bridge.CvBridge().imgmsg_to_cv2(msg, 'rgb8')
            dump_ndarray(filename, img)
        elif savetype in ('DepthImage', 'LabelImage'):
            arr = cv_bridge.CvBridge().imgmsg_to_cv2(msg)
            dump_ndarray(filename, arr)
        elif savetype == 'YAML':
            msg_yaml = genpy.message.strify_message(msg)
            with open(filename, 'w') as f:
                f.write(msg_yaml)
        else:
            rospy.logerr('Unexpected savetype for topic: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for topic: {}'.format(savetype))

    def save_param(self, param, savetype, filename):
        """Write the current value of ROS parameter ``param`` to ``filename``.

        Args:
            param: Parameter key passed to ``rospy.get_param``.
            savetype: 'Text' (``str(value)``) or 'YAML' (``yaml.safe_dump``).
            filename: Destination path.

        Raises:
            ValueError: If ``savetype`` is not supported.
        """
        value = rospy.get_param(param)
        if savetype == 'Text':
            content = str(value)
        elif savetype == 'YAML':
            content = yaml.safe_dump(value, allow_unicode=True,
                                     default_flow_style=False)
        else:
            rospy.logerr('Unexpected savetype for param: {}'.format(savetype))
            raise ValueError(
                'Unexpected savetype for param: {}'.format(savetype))
        with open(filename, 'w') as f:
            f.write(content)

    def service_cb(self, req):
        """Handle ``~save_request``: save the messages around the request time.

        For every topic, waits (polling every 0.01 s) for a cached message
        whose stamp is within ``self.slop`` seconds of the request time.
        Fails with a timeout response when the cached message of a topic is
        already newer than the request time without being inside the slop
        window, or when a topic has delivered no message at all by the time
        the slop window has elapsed.

        Returns:
            TriggerResponse with ``success`` and a human-readable ``message``.
        """
        now = rospy.Time.now()
        slop = rospy.Duration(self.slop)
        saving_msgs = {}
        while len(saving_msgs) < len(self.topics):
            for topic in self.topics:
                name = topic['name']
                if name in saving_msgs:
                    continue
                # Take one snapshot of the cache entry: the subscriber thread
                # may replace it at any time, and the stamp that is checked
                # must belong to the message that is saved.
                entry = self.msg.get(name)
                if entry is None:
                    # Nothing received on this topic yet; keep waiting, but
                    # not forever.
                    if rospy.Time.now() - now > slop:
                        return self._timeout_response(name)
                    continue
                stamp = entry['stamp']
                if abs(now - stamp) < slop:
                    saving_msgs[name] = entry['msg']
                elif now < stamp:
                    # Messages have already moved past the request time
                    # without any of them falling inside the slop window.
                    return self._timeout_response(name)
            rospy.sleep(0.01)
        msgs = [saving_msgs[topic['name']] for topic in self.topics]
        save_dir = self._save_sample(now, msgs)
        message = 'Saved data to {}'.format(save_dir)
        rospy.loginfo(message)
        return TriggerResponse(success=True, message=message)
00200
00201
if __name__ == '__main__':
    # Script entry point: register the node, build the server, then block
    # until shutdown while callbacks are processed.
    rospy.init_node('data_collection_server')
    # Keep a reference so the server outlives the spin loop.
    server = DataCollectionServer()
    rospy.spin()