Go to the documentation of this file.00001
00002 import time
00003 import sys, os, copy
00004 import numpy as np, math
00005 import scipy.ndimage as ni
00006
00007 class occupancy_grid_3d():
00008
00009
00010
00011
00012 def __init__(self, center, size, resolution, data,
00013 occupancy_threshold, to_binary = True):
00014 self.grid_shape = size/resolution
00015 tlb = center + size/2
00016 brf = center + size/2
00017
00018 self.size = size
00019 self.grid = np.reshape(data, self.grid_shape)
00020 self.grid_shape = np.matrix(self.grid.shape).T
00021 self.resolution = resolution
00022 self.center = center
00023
00024 if to_binary:
00025 self.to_binary(occupancy_threshold)
00026
00027
00028
00029 def to_binary(self, occupancy_threshold):
00030 filled = (self.grid >= occupancy_threshold)
00031 self.grid[np.where(filled==True)] = 1
00032 self.grid[np.where(filled==False)] = 0
00033
00034
00035
00036
00037 def grid_to_points(self, array=None):
00038 if array == None:
00039 array = self.grid
00040
00041 idxs = np.where(array == 1)
00042 x_idx = idxs[0]
00043 y_idx = idxs[1]
00044 z_idx = idxs[2]
00045
00046 x = x_idx * self.resolution[0,0] + self.center[0,0] - self.size[0,0]/2
00047 y = y_idx * self.resolution[1,0] + self.center[1,0] - self.size[1,0]/2
00048 z = z_idx * self.resolution[2,0] + self.center[2,0] - self.size[2,0]/2
00049
00050 return np.matrix(np.row_stack([x,y,z]))
00051
00052
00053
00054 def connected_comonents(self, threshold):
00055 connect_structure = np.ones((3,3,3), dtype='int')
00056 grid = self.grid
00057 labeled_arr, n_labels = ni.label(grid, connect_structure)
00058
00059 if n_labels == 0:
00060 return labeled_arr, n_labels
00061
00062 labels_list = range(1,n_labels+1)
00063 count_objects = ni.sum(grid, labeled_arr, labels_list)
00064 if n_labels == 1:
00065 count_objects = [count_objects]
00066
00067 t0 = time.time()
00068 new_labels_list = []
00069
00070 for c,l in zip(count_objects, labels_list):
00071 if c > threshold:
00072 new_labels_list.append(l)
00073 else:
00074 labeled_arr[np.where(labeled_arr == l)] = 0
00075
00076
00077 for nl,l in enumerate(new_labels_list):
00078 labeled_arr[np.where(labeled_arr == l)] = nl+1
00079 n_labels = len(new_labels_list)
00080 t1 = time.time()
00081 print 'time:', t1-t0
00082 return labeled_arr, n_labels
00083
00084
00085
00086
# Script entry point: only emits a greeting when the file is run directly.
if __name__ == '__main__':
    # Single-argument call form prints the same text on Python 2 and 3.
    print('Hello World')
00089