00001
00002
00003
00004
00005
00006
00007
00008
00009 import cPickle
00010 import os
00011 import time
00012 import mxnet as mx
00013 import numpy as np
00014
00015 from module import MutableModule
00016 from utils import image
00017 from bbox.bbox_transform import bbox_pred, clip_boxes
00018 from nms.nms import py_nms_wrapper, cpu_nms_wrapper, gpu_nms_wrapper
00019 from utils.PrefetchingIter import PrefetchingIter
00020
00021
class Predictor(object):
    """
    Inference-only wrapper around a MutableModule.

    The module is created, bound with for_training=False and given its
    parameters once in the constructor; predict() then runs a forward pass
    on a batch and repackages the outputs as dictionaries keyed by output name.
    """

    def __init__(self, symbol, data_names, label_names,
                 context=mx.cpu(), max_data_shapes=None,
                 provide_data=None, provide_label=None,
                 arg_params=None, aux_params=None):
        """
        :param symbol: network symbol forwarded to MutableModule
        :param data_names: data input names forwarded to MutableModule
        :param label_names: label input names forwarded to MutableModule
        :param context: context(s) forwarded to MutableModule (defaults to CPU)
        :param max_data_shapes: forwarded to MutableModule
        :param provide_data: data descriptors used when binding the module
        :param provide_label: label descriptors used when binding the module
        :param arg_params: argument parameters handed to init_params
        :param aux_params: auxiliary parameters handed to init_params
        """
        self._mod = MutableModule(
            symbol,
            data_names,
            label_names,
            context=context,
            max_data_shapes=max_data_shapes,
        )
        module = self._mod
        module.bind(provide_data, provide_label, for_training=False)
        module.init_params(arg_params=arg_params, aux_params=aux_params)

    def predict(self, data_batch):
        """
        Run a forward pass and return the outputs as a list of dicts.
        :param data_batch: batch passed unchanged to the module's forward()
        :return: list of {output_name: output} dicts, one per tuple obtained by
                 transposing get_outputs(merge_multi_context=False)
                 (presumably one per device/context -- confirm against MutableModule)
        """
        self._mod.forward(data_batch)

        # get_outputs is grouped per output; zip(*...) regroups it so that each
        # tuple holds one value for every output name.
        grouped_outputs = self._mod.get_outputs(merge_multi_context=False)
        results = []
        for one_slice in zip(*grouped_outputs):
            results.append(dict(zip(self._mod.output_names, one_slice)))
        return results
00036
00037
def im_proposal(predictor, data_batch, data_names, scales):
    """
    Run the predictor on one batch and collect the proposals it emits.
    :param predictor: object whose predict(data_batch) returns a list of output
                      dicts containing 'rois_output' and 'rois_score'
    :param data_batch: batch whose ``data`` attribute holds one group of input
                       arrays per entry
    :param data_names: names zipped against every group in data_batch.data
    :param scales: one factor per entry; the boxes are divided by it
                   (presumably the image resize scale -- confirm with caller)
    :return: (scores_all, boxes_all, data_dict_all), three parallel lists
    """
    output_all = predictor.predict(data_batch)

    # Iterate the groups directly (same form as im_detect) instead of indexing
    # through xrange(len(...)), which only exists on Python 2.
    data_dict_all = [dict(zip(data_names, idata)) for idata in data_batch.data]
    scores_all = []
    boxes_all = []

    # data_dict_all stays in the zip only so the result lists are truncated to
    # the shortest of the three inputs, as before; its items are not used here.
    for output, _unused_data_dict, scale in zip(output_all, data_dict_all, scales):
        # drop the first column of the rois, keep the remaining box columns
        boxes = output['rois_output'].asnumpy()[:, 1:]
        scores = output['rois_score'].asnumpy()

        # undo the scale factor
        boxes = boxes / scale
        scores_all.append(scores)
        boxes_all.append(boxes)

    return scores_all, boxes_all, data_dict_all
00056
00057
def generate_proposals(predictor, test_data, imdb, cfg, vis=False, thresh=0.):
    """
    Generate detections results using RPN.
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffled
    :param imdb: image database; num_images, result_path and name are read
    :param cfg: configuration object, only forwarded to the visualisation
    :param vis: controls visualization
    :param thresh: thresh for valid detections
    :return: list with one array of kept proposals (boxes + score columns) per image
    """
    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    idx = 0
    t = time.time()
    imdb_boxes = list()
    original_boxes = list()
    for im_info, data_batch in test_data:
        # t1: time spent waiting for the batch
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_proposal(predictor, data_batch, data_names, scales)
        # t2: time spent in the network for the whole batch
        t2 = time.time() - t
        t = time.time()
        for scores, boxes, data_dict, scale in zip(scores_all, boxes_all, data_dict_all, scales):
            # assemble proposals: box columns followed by the score column(s)
            dets = np.hstack((boxes, scores))
            original_boxes.append(dets)

            # filter proposals by score
            keep = np.where(dets[:, 4:] > thresh)[0]
            dets = dets[keep, :]
            imdb_boxes.append(dets)

            if vis:
                vis_all_detection(data_dict['data'].asnumpy(), [dets], ['obj'], scale, cfg)

            # Single-argument print call: same text as before, and valid both as
            # a Python 2 print statement and as a Python 3 function call.
            print('generating %d/%d proposal %d data %.4fs net %.4fs'
                  % (idx + 1, imdb.num_images, dets.shape[0], t1, t2 / test_data.batch_size))
            idx += 1

    assert len(imdb_boxes) == imdb.num_images, 'calculations not complete'

    # save results
    rpn_folder = os.path.join(imdb.result_path, 'rpn_data')
    # makedirs also creates imdb.result_path when it is missing (os.mkdir would
    # raise), and isdir does not mistake a plain file for the folder.
    if not os.path.isdir(rpn_folder):
        os.makedirs(rpn_folder)

    rpn_file = os.path.join(rpn_folder, imdb.name + '_rpn.pkl')
    with open(rpn_file, 'wb') as f:
        cPickle.dump(imdb_boxes, f, cPickle.HIGHEST_PROTOCOL)

    # the unfiltered proposals only differ from the saved ones when a positive
    # threshold was applied, so they are written only in that case
    if thresh > 0:
        full_rpn_file = os.path.join(rpn_folder, imdb.name + '_full_rpn.pkl')
        with open(full_rpn_file, 'wb') as f:
            cPickle.dump(original_boxes, f, cPickle.HIGHEST_PROTOCOL)

    print('wrote rpn proposals to {}'.format(rpn_file))
    return imdb_boxes
00122
00123
def im_detect(predictor, data_batch, data_names, scales, cfg):
    """
    Run the detector on one batch and return class scores and predicted boxes.
    :param predictor: object whose predict(data_batch) returns a list of output dicts
    :param data_batch: batch whose ``data`` attribute holds one group of input
                       arrays per entry
    :param data_names: names zipped against every group in data_batch.data
    :param scales: one factor per entry; predicted boxes are divided by it
    :param cfg: configuration; cfg.TEST.HAS_RPN selects where the rois come from
    :return: (scores_all, pred_boxes_all, data_dict_all), three lists
    """
    output_all = predictor.predict(data_batch)

    data_dict_all = []
    for idata in data_batch.data:
        data_dict_all.append(dict(zip(data_names, idata)))

    scores_all, pred_boxes_all = [], []
    for net_out, inputs, im_scale in zip(output_all, data_dict_all, scales):
        # rois come from the network output when an RPN is part of the model,
        # otherwise from the 'rois' input reshaped to 5 columns
        if not cfg.TEST.HAS_RPN:
            rois_full = inputs['rois'].asnumpy().reshape((-1, 5))
        else:
            rois_full = net_out['rois_output'].asnumpy()
        rois = rois_full[:, 1:]  # first column is dropped
        im_hw = inputs['data'].shape[-2:]  # last two dims of the data input

        # only the first element along the leading axis of both outputs is used
        cls_prob = net_out['cls_prob_reshape_output'].asnumpy()[0]
        deltas = net_out['bbox_pred_reshape_output'].asnumpy()[0]

        # bbox_pred / clip_boxes are project helpers; presumably they apply the
        # predicted deltas to the rois and clip the result to the image extent
        decoded = clip_boxes(bbox_pred(rois, deltas), im_hw)

        # undo the scale factor
        rescaled = decoded / im_scale

        scores_all.append(cls_prob)
        pred_boxes_all.append(rescaled)
    return scores_all, pred_boxes_all, data_dict_all
00151
00152
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param cfg: configuration; TEST.NMS, TEST.max_per_image and CLASS_AGNOSTIC are read
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives progress and evaluation messages
    :param ignore_cache: when False, an existing detection file is evaluated
                         instead of running the predictor again
    :return: None; detections are written to
             <imdb.result_path>/<imdb.name>_detections.pkl and evaluated
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling can execute arbitrary code -- only reuse cache files
        # that were produced by a trusted run of this function.
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #     all_boxes[cls][image] = array with 4 box columns + 1 score column
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        # t1: time spent waiting for the batch
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(predictor, data_batch, data_names, scales, cfg)

        # t2: time spent in the network for the whole batch
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(zip(scores_all, boxes_all, data_dict_all)):
            # class 0 is skipped in every per-class loop
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                # class-agnostic models share the box in columns 4:8
                cls_boxes = boxes[indexes, 4:8] if cfg.CLASS_AGNOSTIC else boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                keep = nms(cls_dets)
                all_boxes[j][idx+delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][idx+delta][:, -1]
                                          for j in range(1, imdb.num_classes)])
                if len(image_scores) > max_per_image:
                    # score of the max_per_image-th best detection of this image
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(all_boxes[j][idx+delta][:, -1] >= image_thresh)[0]
                        all_boxes[j][idx+delta] = all_boxes[j][idx+delta][keep, :]

            if vis:
                boxes_this_image = [[]] + [all_boxes[j][idx+delta] for j in range(1, imdb.num_classes)]
                vis_all_detection(data_dict['data'].asnumpy(), boxes_this_image, imdb.classes, scales[delta], cfg)

        idx += test_data.batch_size
        # t3: post-processing time for the whole batch
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3

        # Build the progress line once and reuse it for stdout and the logger.
        # The single-argument print call is valid on Python 2 and Python 3.
        batch_size = test_data.batch_size
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images,
            data_time / idx * batch_size,
            net_time / idx * batch_size,
            post_time / idx * batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
00242
00243
def vis_all_detection(im_array, detections, class_names, scale, cfg, threshold=1e-3):
    """
    show one image with all of its detections in a matplotlib window
    :param im_array: image array handed to image.transform_inverse
                     (expected as [b=1 c h w] in rgb)
    :param detections: [ numpy.ndarray([[x1 y1 x2 y2 score]]) for j in classes ]
    :param class_names: list of names in imdb, index-aligned with detections
    :param scale: factor the box coordinates are multiplied by before drawing
    :param cfg: configuration; cfg.network.PIXEL_MEANS is used to restore the image
    :param threshold: detections with a score below this value are not drawn
    :return: None
    """
    import matplotlib.pyplot as plt
    import random
    picture = image.transform_inverse(im_array, cfg.network.PIXEL_MEANS)
    plt.imshow(picture)
    for cls_idx, cls_name in enumerate(class_names):
        if cls_name == '__background__':
            continue
        # one random colour per class
        rgb = (random.random(), random.random(), random.random())
        for det in detections[cls_idx]:
            corners = det[:4] * scale
            score = det[-1]
            if score < threshold:
                continue
            left, top = corners[0], corners[1]
            axes = plt.gca()
            outline = plt.Rectangle((left, top),
                                    corners[2] - left,
                                    corners[3] - top, fill=False,
                                    edgecolor=rgb, linewidth=3.5)
            axes.add_patch(outline)
            # label placed two units above the box's first corner
            caption = '{:s} {:.3f}'.format(cls_name, score)
            axes.text(left, top - 2, caption,
                      bbox=dict(facecolor=rgb, alpha=0.5), fontsize=12, color='white')
    plt.show()
00276
00277
def draw_all_detection(im_array, detections, class_names, scale, cfg, threshold=1e-1):
    """
    draw all detections of one image onto the image and return it
    :param im_array: image array handed to image.transform_inverse
                     (expected as [b=1 c h w] in rgb)
    :param detections: [ numpy.ndarray([[x1 y1 x2 y2 score]]) for j in classes ]
    :param class_names: list of names in imdb, index-aligned with detections
    :param scale: factor the box coordinates are multiplied by before drawing
    :param cfg: configuration; cfg.network.PIXEL_MEANS is used to restore the image
    :param threshold: detections with a score below this value are not drawn
    :return: the image converted to BGR channel order, with boxes and labels drawn
    """
    import cv2
    import random
    color_white = (255, 255, 255)
    im = image.transform_inverse(im_array, cfg.network.PIXEL_MEANS)
    # change to bgr
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    for j, name in enumerate(class_names):
        if name == '__background__':
            continue
        # random.randint includes both end points, so the upper bound must be
        # 255 to stay inside the 8-bit channel range (256 was possible before)
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        dets = detections[j]
        for det in dets:
            bbox = det[:4] * scale
            score = det[-1]
            if score < threshold:
                continue
            # a real list, so indexing also works where map() returns an iterator
            bbox = [int(coord) for coord in bbox]
            cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color=color, thickness=2)
            cv2.putText(im, '%s %.3f' % (name, score), (bbox[0], bbox[1] + 10),
                        color=color_white, fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5)
    return im