load_bag.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # Software License Agreement (BSD License)
00003 #
00004 #  Copyright (c) 2010, UC Regents
00005 #  All rights reserved.
00006 #
00007 #  Redistribution and use in source and binary forms, with or without
00008 #  modification, are permitted provided that the following conditions
00009 #  are met:
00010 #
00011 #   * Redistributions of source code must retain the above copyright
00012 #     notice, this list of conditions and the following disclaimer.
00013 #   * Redistributions in binary form must reproduce the above
00014 #     copyright notice, this list of conditions and the following
00015 #     disclaimer in the documentation and/or other materials provided
00016 #     with the distribution.
00017 #   * Neither the name of the University of California nor the names of its
00018 #     contributors may be used to endorse or promote products derived
00019 #     from this software without specific prior written permission.
00020 #
00021 #  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00022 #  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00023 #  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00024 #  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00025 #  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00026 #  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00027 #  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00028 #  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00029 #  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00030 #  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00031 #  ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00032 #  POSSIBILITY OF SUCH DAMAGE.
00033 
00034 import roslib
00035 roslib.load_manifest('starmac_tools')
00036 import rosbag
00037 import rospy
00038 import sys
00039 import numpy as np
00040 import scipy.io as sio
00041 import gzip
00042 import cPickle
00043 import time
00044 
00045 _primitive_types = [ 'byte', 'int8', 'uint8',
00046                     'int16', 'uint16', 'int32', 'uint32',
00047                     'int64', 'uint64', 'float32', 'float64', 'bool']
00048 
def mangle(s):
    """
    Mangle a ROS-style name into an identifier-friendly string.

    Underscores are first doubled, then slashes are replaced by single
    underscores, e.g. '/topic/_time' -> '_topic___time'.

    The result is checked to round-trip through unmangle(). Some names cannot
    be mangled reversibly -- e.g. 'a_/b' and 'a/_b' both mangle to 'a___b',
    which unmangles to 'a/_b' -- and for those a ValueError is raised.
    """
    mangled = s.replace('_', '__').replace('/', '_')
    # Explicit check rather than `assert`, so it is not stripped by python -O
    # and an ambiguous name can never silently map to the wrong attribute.
    if unmangle(mangled) != s:
        raise ValueError('Name %r cannot be mangled reversibly (mangled form %r)' % (s, mangled))
    return mangled
00057 
def unmangle(s):
    """
    Invert mangle(): turn a mangled name back into a ROS-style name.

    Underscores are paired up starting from the right-hand end of the string;
    each pair ('__') stands for one literal underscore, and every remaining
    single underscore becomes a slash. Pairing from the right makes '___'
    decode as '/_' (slash followed by underscore), e.g.
    '_topic___time' -> '/topic/_time'.

    A literal '~' in the input is used as a temporary marker and therefore
    also comes out as '_'.
    """
    marker = '~'
    # str.rsplit scans from the right, so '__' pairs are matched right-to-left.
    pieces = s.rsplit('__')
    with_markers = marker.join(pieces)
    return with_markers.replace('_', '/').replace(marker, '_')
00060 
class BagLoader(object):
    """
    Class that allows one to load data from a ROS .bag file.

    For the impatient:
    >>> mybag = BagLoader('some_bagfile.bag')
    >>> t = mybag._grey_innerloop___time
    >>> V = mybag._grey_innerloop_battVoltage
    >>> from pylab import *
    >>> plot(t, V)

    Data is stored in the instance private member dictionary _data, with keys
    based on the ROS 'name' of each time-series variable: the topic name as
    recorded in the bag file joined with the field name by a slash, i.e.
    /namespace/topic/field (fields of nested messages extend the path, e.g.
    /topic/pose/position/x). The variables are also available as instance
    members, with names mangled by mangle() (each '_' doubled, then each '/'
    turned into '_'); this is handy with tab completion as in IPython.

    How fields are stored (N = number of messages on the topic):
      * primitive scalars               -> numpy array of shape (N,)
      * fixed-length primitive arrays   -> numpy array of shape (N, size)
      * variable-length primitive arrays-> list with one entry per message
                                           (uint8 arrays as numpy uint8 arrays)
      * strings                         -> list of strings
      * time / duration fields          -> integer arrays <field>/secs, <field>/nsecs
      * nested (non-array) messages     -> flattened recursively
      * arrays of nested messages       -> list of the raw message objects
                                           (not flattened)

    Each topic also gets the derived field '_time': the time at which the
    message was recorded by rosbag. If the message has a top-level 'header'
    field of type Header, its stamp is additionally stored as '_header_time'.

    Another useful feature is the ability to save the data to a Matlab-readable .mat file:
    >>> mybag.save_mat('test.mat')

    Note - the minimum recorded timestamp in the bag file is taken to be time
    zero (self.start_stamp). '_time' and '_header_time' are computed by
    subtracting start_stamp from the ROS time objects and converting the
    difference with to_sec(), i.e. they are float64 values in seconds relative
    to start_stamp. This can lead to large negative times when a header stamp
    is zero (which may happen); it is the user's responsibility to handle such
    a case appropriately.

    Topics of type sensor_msgs/Image are skipped unless skip_images=False.
    """
    PICKLED_EXT = '.sbag.gz'
    # ROS time-like field types and the numpy dtype used for their secs/nsecs
    # parts. time uses uint32 (as before); durations may be negative, so the
    # signed type is used for them.
    _STAMP_DTYPES = {'time': np.uint32, 'duration': np.int32}
    # Type strings accepted for a top-level 'header' field. 'Header' is what
    # this code always matched; 'std_msgs/Header' is also accepted in case the
    # message generator reports fully-qualified names -- NOTE(review): confirm.
    _HEADER_TYPES = ('Header', 'std_msgs/Header')

    def __init__(self, filename=None, verbose=False, skip_images=True):
        """
        Create a loader and, if `filename` is given, load it right away.

        filename    -- a .mat file written by save_mat(), a PICKLED_EXT file
                       written by _save_pickled(), or otherwise a ROS .bag
                       file. None creates an empty loader.
        verbose     -- print progress / per-field information while loading.
        skip_images -- when True (the default), topics of type
                       sensor_msgs/Image are not loaded from .bag files.
        """
        self.verbose = verbose
        self.skip_images = skip_images
        self._data = {}  # this is where everything will go: ROS name -> data
        self._dict_names = {}  # mangled attribute name -> ROS name
        self._ros_types = {}  # ROS name -> (t_stripped, is_primitive, is_array, array_size) or None
        self.topic_msg_counts = {}  # topic -> number of messages
        self.start_stamp = None  # time zero; only known after loading a .bag or pickle
        if filename is not None:
            lowered = filename.lower()
            if lowered.endswith('.mat'):
                self.load_mat(filename)
            elif lowered.endswith(BagLoader.PICKLED_EXT):
                self._load_pickled(filename)
            else:
                self._load(filename)
        self.filename = filename

    def _save_pickled(self, filename):
        """
        Save a pickled and gzipped representation of the loaded data.

        The highest available pickle protocol is used: the default protocol 0
        is a text format that is very slow and bulky for large numpy arrays.
        """
        if not filename.endswith(BagLoader.PICKLED_EXT):
            print('Warning: Recommended filename extension is %s' % BagLoader.PICKLED_EXT)
        stuff = {'___VERSION_MAJOR__': 0,  # increment this whenever format changes and breaks backwards compatibility
                 '___VERSION_MINOR__': 1,  # increment this whenever format changes
                 '_data': self._data,
                 '_dict_names': self._dict_names,
                 '_ros_types': self._ros_types,
                 'topic_msg_counts': self.topic_msg_counts,
                 'start_stamp': self.start_stamp,
                 }
        fh = gzip.open(filename, 'wb')
        try:
            cPickle.dump(stuff, fh, cPickle.HIGHEST_PROTOCOL)
        finally:
            fh.close()

    def _load_pickled(self, filename):
        """
        Load data written by _save_pickled() into this instance.

        SECURITY: unpickling can execute arbitrary code -- only load files
        from trusted sources. Files pickled with protocol 0 load slowly.

        Raises ValueError if the file's major format version is not 0.
        """
        if not filename.endswith(BagLoader.PICKLED_EXT):
            print('Warning: Recommended filename extension is %s' % BagLoader.PICKLED_EXT)
        fh = gzip.open(filename, 'rb')
        try:
            stuff = cPickle.load(fh)
        finally:
            fh.close()
        major = stuff.get('___VERSION_MAJOR__')
        if major != 0:
            raise ValueError('Unsupported %s format version %r in %s'
                             % (BagLoader.PICKLED_EXT, major, filename))
        self._dict_names = stuff['_dict_names']
        self._ros_types = stuff['_ros_types']
        self.topic_msg_counts = stuff['topic_msg_counts']
        self.start_stamp = stuff['start_stamp']
        self._data = stuff['_data']
        # Re-expose every series under its mangled attribute name.
        for dict_name, name in self._dict_names.items():
            self.__dict__[dict_name] = self._data[name]

    def _load(self, filename):
        """
        Load all (non-skipped) topics from the ROS .bag file `filename`.

        Sets self.b (the rosbag.Bag, which is left open), self.start_stamp
        (earliest record time over all connections) and self.topic_msg_counts,
        then calls _load_process_topic() for each topic. Relies on rosbag's
        private index API (_connection_indexes, _get_connections, _get_entries).
        """
        print('Loading data from %s' % filename)
        self.b = rosbag.Bag(filename)
        # Time zero: the earliest record time of any connection's first index entry.
        first_stamps = [index[0].time for index in self.b._connection_indexes.values() if index]
        if not first_stamps:
            print('Warning: no messages found in %s' % filename)
            return
        self.start_stamp = min(first_stamps)
        topic_datatypes = {}
        self.topic_msg_counts = {}
        for topic in set(c.topic for c in self.b._get_connections()):
            connections = list(self.b._get_connections(topic))
            topic_datatypes[topic] = connections[0].datatype
            self.topic_msg_counts[topic] = sum(1 for _ in self.b._get_entries(connections))
        for topic in sorted(topic_datatypes):
            datatype = topic_datatypes[topic]
            skip = self.skip_images and datatype == 'sensor_msgs/Image'
            if self.verbose:
                print('Loading topic %s, type %s%s' % (topic, datatype, ' .. Skipping' if skip else ''))
            if not skip:
                self._load_process_topic(topic)

    def _new_data(self, name, init_data, ros_type=None):
        """
        Register a new series under the ROS-style `name` (topic + field).

        Stores `init_data` in self._data[name], exposes the same object as the
        instance attribute mangle(name), records the mangled -> ROS name
        mapping, and stores `ros_type`: a tuple
        (t_stripped, is_primitive, is_array, array_size), or None for derived
        series (_time, _header_time) and data loaded from .mat files.
        """
        mangled = mangle(name)
        self._data[name] = init_data
        self.__dict__[mangled] = init_data
        self._dict_names[mangled] = name
        self._ros_types[name] = ros_type

    def _load_process_topic(self, topic):
        """
        Read every message on `topic` and fill in its per-field storage.

        Also fills topic + '/_time' (record time) and, when the message has a
        top-level 'header' field of a Header type, topic + '/_header_time'
        (header.stamp); both are seconds relative to self.start_stamp.
        """
        N = self.topic_msg_counts[topic]
        time_name = topic + '/_time'
        header_name = topic + '/_header_time'
        has_header = False
        for i, (_, msg, stamp) in enumerate(self.b.read_messages(topic)):
            if i == 0:
                self._new_data(time_name, np.zeros(shape=(N,), dtype=np.float64))
                has_header = any(fieldname == 'header' and ftype in BagLoader._HEADER_TYPES
                                 for fieldname, ftype in zip(msg.__slots__, msg._get_types()))
                if has_header:
                    self._new_data(header_name, np.zeros(shape=(N,), dtype=np.float64))
            header_time = self._sub_load_init_array(i, 0, topic, msg, stamp, N, get_header=has_header)
            self._data[time_name][i] = (stamp - self.start_stamp).to_sec()
            if has_header:
                self._data[header_name][i] = (header_time - self.start_stamp).to_sec()

    def _sub_load_init_array(self, i, depth, base, msg, time, N, get_header=False):
        """
        Store the fields of `msg` (message number `i` of `N`) under `base`.

        On the first message (i == 0) storage for each field is created via
        _new_data(); on every message the field values are written/appended.
        Nested (non-array) messages are handled recursively.

        i          -- index of this message within its topic
        depth      -- recursion depth (only used to indent verbose output)
        base       -- ROS-style name prefix, e.g. '/topic' or '/topic/pose'
        msg        -- the (sub)message whose fields are stored
        time       -- record time; only passed on to recursive calls
        N          -- total number of messages on the topic
        get_header -- if True, return msg.header.stamp

        Returns msg.header.stamp if get_header is True, else None.
        """
        header_time = msg.header.stamp if get_header else None
        for fieldname, ftype in zip(msg.__slots__, msg._get_types()):
            value = getattr(msg, fieldname)
            full_name = base + '/' + fieldname
            t_stripped = ftype.split('[')[0]
            is_array = '[' in ftype and ']' in ftype
            if ftype in BagLoader._STAMP_DTYPES:
                self._store_stamp_field(i, depth, full_name, value, BagLoader._STAMP_DTYPES[ftype], N)
            elif t_stripped in _primitive_types:
                self._store_primitive_field(i, depth, full_name, value, ftype, t_stripped, is_array, N)
            elif t_stripped == 'string':
                if i == 0:
                    self._new_data(full_name, [], (t_stripped, False, is_array, None))
                    self._report(depth, full_name, ftype)
                self._data[full_name].append(value)
            elif is_array:
                if i == 0:
                    self._new_data(full_name, [], (t_stripped, False, True, None))
                    self._report(depth, full_name, ftype)
                    print("WARNING: Arrays of non-primitive types not yet fully supported (%s : %s)" % (fieldname, ftype))
                # Keep the raw message objects -- for every message, including the first.
                self._data[full_name].append(value)
            else:
                if i == 0:
                    self._report(depth, full_name, ftype + ' *')
                self._sub_load_init_array(i, depth + 1, full_name, value, time, N)
        return header_time

    def _store_stamp_field(self, i, depth, full_name, value, dtype, N):
        """Store a time/duration field as two integer series, <full_name>/secs and <full_name>/nsecs."""
        for part in ('secs', 'nsecs'):
            name = full_name + '/' + part
            if i == 0:
                # Tag with the numeric dtype so save_mat treats it as primitive data.
                self._new_data(name, np.zeros(shape=(N,), dtype=dtype),
                               (np.dtype(dtype).name, True, False, None))
                self._report(depth, name, self._describe(self._data[name]))
            self._data[name][i] = getattr(value, part)

    def _store_primitive_field(self, i, depth, full_name, value, ftype, t_stripped, is_array, N):
        """Store a primitive scalar, fixed-length array or variable-length array field."""
        if not is_array:
            if i == 0:
                self._new_data(full_name, np.zeros(shape=(N,), dtype=t_stripped),
                               (t_stripped, True, False, None))
                self._report(depth, full_name, self._describe(self._data[full_name]))
            self._data[full_name][i] = value
            return
        if t_stripped == 'uint8':
            # uint8 array fields may arrive as byte strings (previously decoded
            # with np.fromstring); convert to a numeric array in all cases.
            value = self._uint8_to_array(value)
        size_str = ftype.split('[')[1].split(']')[0]
        if size_str:  # fixed-length array, e.g. float64[9]
            array_size = int(size_str)
            if i == 0:
                self._new_data(full_name, np.zeros(shape=(N, array_size), dtype=t_stripped),
                               (t_stripped, True, True, array_size))
                self._report(depth, full_name, self._describe(self._data[full_name]))
            self._data[full_name][i, :] = value
        else:  # variable-length array: one list entry per message
            if i == 0:
                self._new_data(full_name, [], (t_stripped, True, True, None))
                self._report(depth, full_name, ftype)
            self._data[full_name].append(value)

    @staticmethod
    def _uint8_to_array(value):
        """Return a uint8 field value (byte string or sequence of ints) as a new numpy uint8 array."""
        if isinstance(value, bytes):
            return np.array(bytearray(value), dtype=np.uint8)
        return np.asarray(value, dtype=np.uint8)

    @staticmethod
    def _describe(arr):
        """Return 'dtype shape' of a numpy array, for verbose output."""
        return '%s %s' % (arr.dtype, arr.shape)

    def _report(self, depth, name, description):
        """Print one indented 'name : description' line when self.verbose is set."""
        if self.verbose:
            print('%s%s : %s' % (' ' * depth, name, description))

    def save_mat(self, mat_filename):
        """
        Save data to a MATLAB (version 5) .mat file.

        Not all data can be saved in this format. Primitive series, arrays of
        primitive types and derived series (_time, _header_time) are saved
        as-is, and non-array string fields are saved as cell arrays. Any other
        field produces a warning and is skipped.

        Each series is stored as the variable 'v' + mangle(name), and the topic
        message counts are stored as info_topic_msg_counts.
        """
        mdict = {}
        for k, v in self._data.items():
            if v is None:
                continue
            rtypes = self._ros_types.get(k)
            if rtypes is None:
                # derived series (e.g. header time) or data loaded from a .mat file
                is_primitive, is_string, is_array = True, False, False
            else:
                t_stripped, is_primitive, is_array, _ = rtypes
                is_string = (t_stripped == 'string')
            var_name = 'v' + mangle(k)  # Matlab variable names must start with a letter
            if is_primitive:
                mdict[var_name] = v
            elif is_string and not is_array:
                mdict[var_name] = self._strlist_to_cellarray(v)
            else:
                print("Warning: cannot save field %s to matlab (not a primitive type)" % k)
        stuff = {'info_VERSION_MAJOR__': 0,  # increment this whenever format changes and breaks backwards compatibility
                 'info_VERSION_MINOR__': 1,  # increment this whenever format changes
                 'info_topic_msg_counts': self._dict_to_cellarray(self.topic_msg_counts)
                 }
        mdict.update(stuff)
        sio.savemat(mat_filename, mdict, long_field_names=True, oned_as='column')

    def load_mat(self, mat_filename):
        """
        Load data from a MATLAB (version 5) .mat file written by save_mat().

        ROS type information is not stored in the .mat file, so it is lost.
        Raises ValueError if the file's major format version is not 0.
        """
        mdict = sio.loadmat(mat_filename, squeeze_me=True, struct_as_record=False)
        major = mdict.get('info_VERSION_MAJOR__')
        if major is None or major != 0:
            raise ValueError('Unsupported .mat format version %r in %s' % (major, mat_filename))
        for k, v in mdict.items():
            if k.startswith('v'):
                name = unmangle(k[1:])
                strings = self._cellarray_to_strlist(v)
                if strings is not None:
                    self._new_data(name, strings)
                else:
                    self._new_data(name, v)  # note that ros type is lost..
            elif k == 'info_topic_msg_counts':
                self.topic_msg_counts = self._cellarray_to_dict(v)

    @staticmethod
    def _cellarray_to_strlist(v):
        """
        Helper function for load_mat: return `v` as a list of str if it holds
        string data (a string list written by save_mat), otherwise None.

        Handles the squeezed forms loadmat(squeeze_me=True) may return: an
        object array of per-string cells, or a single string when the list had
        one element. Empty cells become ''.
        """
        arr = np.asarray(v)
        if arr.dtype.kind == 'U':
            return [str(s) for s in arr.flat]
        if arr.dtype == object and arr.size > 0:
            cells = [np.asarray(cell) for cell in arr.flat]
            if any(cell.dtype.kind == 'U' for cell in cells):
                return [str(cell.flat[0]) if cell.size else '' for cell in cells]
        return None

    def _dict_to_cellarray(self, d):
        """
        Helper function for save_mat: pack dict `d` into a 1-D object array of
        {'key': k, 'value': v} dicts (written by savemat as a cell array of structs).
        """
        obj_arr = np.zeros((len(d),), dtype=object)
        for i, (k, v) in enumerate(d.items()):
            obj_arr[i] = {'key': k, 'value': v}
        return obj_arr

    def _cellarray_to_dict(self, c):
        """
        Helper function for load_mat: inverse of _dict_to_cellarray.

        `c` is an array of struct-like objects with .key and .value attributes,
        or a single such object (loadmat squeezes one-element cell arrays).
        """
        d = {}
        for cell in np.atleast_1d(c):
            # .value may come back as a plain Python scalar after squeezing.
            d[str(cell.key)] = np.asarray(cell.value).flat[0]
        return d

    def _strlist_to_cellarray(self, sl):
        """
        Helper function for save_mat: copy the sequence `sl` into a 1-D object
        array (written by savemat as a cell array).
        """
        ca = np.zeros((len(sl),), dtype=object)
        for i, item in enumerate(sl):
            ca[i] = item
        return ca

    def get_topics(self):
        """Return a sorted list of the topic names that have message counts."""
        return sorted(self.topic_msg_counts)
00371             
if __name__ == "__main__":
    # Demo / smoke test: load a bag, plot altitude, and round-trip the data
    # through the .mat and pickled (.sbag.gz) formats, timing each step.
    if len(sys.argv) < 2:
        print('Usage: %s <bagfile>' % sys.argv[0])
        sys.exit(1)

    start = time.time()
    bl = BagLoader(sys.argv[1])
    end = time.time()
    print("Loading from .bag took %f seconds" % (end - start))

    # Plain module import (instead of `from pylab import *`) so that no
    # module-level names such as `time` can be shadowed.
    import pylab
    try:
        t = bl.estimator_output___time
        h = -bl.estimator_output_pose_pose_position_z
    except AttributeError:  # kludge: the estimator topic may live under downlink/
        t = bl.downlink_estimator_output___time
        h = -bl.downlink_estimator_output_pose_pose_position_z
    pylab.plot(t, h)
    pylab.title('Altitude vs. Time')
    pylab.xlabel('Time [s]')
    pylab.ylabel('Altitude [m]')

    print("Testing save to .mat ...")
    start = time.time()
    bl.save_mat('test.mat')
    end = time.time()
    print("Saving to .mat took %f seconds" % (end - start))

    print("Testing load from .mat ...")
    start = time.time()
    bl2 = BagLoader()
    bl2.load_mat('test.mat')
    end = time.time()
    print(bl.topic_msg_counts)
    print(bl2.topic_msg_counts)
    print("Loading from .mat took %f seconds" % (end - start))

    print("Testing save to .sbag.gz ...")
    start = time.time()
    bl._save_pickled('test.sbag.gz')
    end = time.time()
    print("Saving to .sbag.gz took %f seconds" % (end - start))

    print("Testing load from .sbag.gz ...")
    bl3 = BagLoader()
    start = time.time()
    bl3._load_pickled('test.sbag.gz')
    end = time.time()
    print("Loading from .sbag.gz took %f seconds" % (end - start))

    pylab.show()
00419 


starmac_tools
Author(s): bouffard
autogenerated on Sun Jan 5 2014 11:38:35