tester.py
# --------------------------------------------------------
# Deformable Convolutional Networks
# Copyright (c) 2016 by Contributors
# Copyright (c) 2017 Microsoft
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Modified by Yuwen Xiong
# --------------------------------------------------------

import cPickle
import os
import time
import mxnet as mx
import numpy as np

from module import MutableModule
from utils import image
from bbox.bbox_transform import bbox_pred, clip_boxes
from nms.nms import py_nms_wrapper, cpu_nms_wrapper, gpu_nms_wrapper
from utils.PrefetchingIter import PrefetchingIter


class Predictor(object):
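    """
    Thin inference wrapper around MutableModule: binds the symbol with the
    given data/label shapes, loads the pretrained parameters, and exposes a
    per-batch predict() that returns one {output_name: NDArray} dict per
    device/context.
    """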
    def __init__(self, symbol, data_names, label_names,
                 context=mx.cpu(), max_data_shapes=None,
                 provide_data=None, provide_label=None,
                 arg_params=None, aux_params=None):
        self._mod = MutableModule(symbol, data_names, label_names,
                                  context=context, max_data_shapes=max_data_shapes)
        self._mod.bind(provide_data, provide_label, for_training=False)
        self._mod.init_params(arg_params=arg_params, aux_params=aux_params)

    def predict(self, data_batch):
        self._mod.forward(data_batch)
        return [dict(zip(self._mod.output_names, _)) for _ in zip(*self._mod.get_outputs(merge_multi_context=False))]


def im_proposal(predictor, data_batch, data_names, scales):
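    """
    Run the RPN on one batch and collect the generated proposals.
    :param predictor: Predictor
    :param data_batch: one mini-batch from the test iterator
    :param data_names: names of the data blobs
    :param scales: image resize scales, used to map boxes back to the original image size
    :return: scores_all, boxes_all, data_dict_all (one entry per image)
    """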
    output_all = predictor.predict(data_batch)

    data_dict_all = [dict(zip(data_names, data_batch.data[i])) for i in xrange(len(data_batch.data))]
    scores_all = []
    boxes_all = []

    for output, data_dict, scale in zip(output_all, data_dict_all, scales):
        # drop the batch index
        boxes = output['rois_output'].asnumpy()[:, 1:]
        scores = output['rois_score'].asnumpy()

        # transform to original scale
        boxes = boxes / scale
        scores_all.append(scores)
        boxes_all.append(boxes)

    return scores_all, boxes_all, data_dict_all


def generate_proposals(predictor, test_data, imdb, cfg, vis=False, thresh=0.):
    """
    Generate detection results using the RPN.
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffled
    :param imdb: image database
    :param cfg: config dict
    :param vis: controls visualization
    :param thresh: threshold for valid detections
    :return: list of detected boxes
    """
    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    idx = 0
    t = time.time()
    imdb_boxes = list()
    original_boxes = list()
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_proposal(predictor, data_batch, data_names, scales)
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict, scale) in enumerate(zip(scores_all, boxes_all, data_dict_all, scales)):
            # assemble proposals
            dets = np.hstack((boxes, scores))
            original_boxes.append(dets)

            # filter proposals
            keep = np.where(dets[:, 4:] > thresh)[0]
            dets = dets[keep, :]
            imdb_boxes.append(dets)

            if vis:
                vis_all_detection(data_dict['data'].asnumpy(), [dets], ['obj'], scale, cfg)

            print 'generating %d/%d' % (idx + 1, imdb.num_images), 'proposal %d' % (dets.shape[0]), \
                'data %.4fs net %.4fs' % (t1, t2 / test_data.batch_size)
            idx += 1

    assert len(imdb_boxes) == imdb.num_images, 'calculations not complete'

    # save results
    rpn_folder = os.path.join(imdb.result_path, 'rpn_data')
    if not os.path.exists(rpn_folder):
        os.mkdir(rpn_folder)

    rpn_file = os.path.join(rpn_folder, imdb.name + '_rpn.pkl')
    with open(rpn_file, 'wb') as f:
        cPickle.dump(imdb_boxes, f, cPickle.HIGHEST_PROTOCOL)

    if thresh > 0:
        full_rpn_file = os.path.join(rpn_folder, imdb.name + '_full_rpn.pkl')
        with open(full_rpn_file, 'wb') as f:
            cPickle.dump(original_boxes, f, cPickle.HIGHEST_PROTOCOL)

    print 'wrote rpn proposals to {}'.format(rpn_file)
    return imdb_boxes


def im_detect(predictor, data_batch, data_names, scales, cfg):
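    """
    Run the detection network on one batch and decode the box predictions.
    :param predictor: Predictor
    :param data_batch: one mini-batch from the test iterator
    :param data_names: names of the data blobs
    :param scales: image resize scales, used to map boxes back to the original image size
    :param cfg: config dict
    :return: scores_all, pred_boxes_all, data_dict_all (one entry per image)
    """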
    output_all = predictor.predict(data_batch)

    data_dict_all = [dict(zip(data_names, idata)) for idata in data_batch.data]
    scores_all = []
    pred_boxes_all = []
    for output, data_dict, scale in zip(output_all, data_dict_all, scales):
        if cfg.TEST.HAS_RPN:
            rois = output['rois_output'].asnumpy()[:, 1:]
        else:
            rois = data_dict['rois'].asnumpy().reshape((-1, 5))[:, 1:]
        im_shape = data_dict['data'].shape

        # save output
        scores = output['cls_prob_reshape_output'].asnumpy()[0]
        bbox_deltas = output['bbox_pred_reshape_output'].asnumpy()[0]

        # post processing
        pred_boxes = bbox_pred(rois, bbox_deltas)
        pred_boxes = clip_boxes(pred_boxes, im_shape[-2:])

        # the network was trained on scaled images and RoIs, so map the
        # predictions back to the original image scale
        pred_boxes = pred_boxes / scale

        scores_all.append(scores)
        pred_boxes_all.append(pred_boxes)
    return scores_all, pred_boxes_all, data_dict_all


def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Wrapper for offline validation, for faster data analysis.
    In this example, all thresholds are set by hand.
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffled
    :param imdb: image database
    :param cfg: config dict
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger for progress and evaluation output
    :param ignore_cache: if False, reuse cached detections when available
    :return:
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

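    # py_nms_wrapper returns a callable bound to the configured overlap threshold:
    # given an (N, 5) array of [x1, y1, x2, y2, score] it returns the indices to keep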
    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(predictor, data_batch, data_names, scales, cfg)

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(zip(scores_all, boxes_all, data_dict_all)):
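            # per-class post-processing: threshold the class probabilities, pick this
            # class' box columns (the shared columns 4:8 when CLASS_AGNOSTIC), then run NMS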
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                cls_boxes = boxes[indexes, 4:8] if cfg.CLASS_AGNOSTIC else boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                keep = nms(cls_dets)
                all_boxes[j][idx+delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][idx+delta][:, -1]
                                          for j in range(1, imdb.num_classes)])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(all_boxes[j][idx+delta][:, -1] >= image_thresh)[0]
                        all_boxes[j][idx+delta] = all_boxes[j][idx+delta][keep, :]

            if vis:
                boxes_this_image = [[]] + [all_boxes[j][idx+delta] for j in range(1, imdb.num_classes)]
                vis_all_detection(data_dict['data'].asnumpy(), boxes_this_image, imdb.classes, scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        msg = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size, post_time / idx * test_data.batch_size)
        print msg
        if logger:
            logger.info(msg)

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))


def vis_all_detection(im_array, detections, class_names, scale, cfg, threshold=1e-3):
    """
    Visualize all detections in one image with matplotlib.
    :param im_array: [b=1 c h w] in rgb
    :param detections: [ numpy.ndarray([[x1 y1 x2 y2 score]]) for j in classes ]
    :param class_names: list of names in imdb
    :param scale: visualize the scaled image
    :param cfg: config dict, provides network.PIXEL_MEANS
    :param threshold: minimum score for a box to be drawn
    :return:
    """
    import matplotlib.pyplot as plt
    import random
    im = image.transform_inverse(im_array, cfg.network.PIXEL_MEANS)
    plt.imshow(im)
    for j, name in enumerate(class_names):
        if name == '__background__':
            continue
        color = (random.random(), random.random(), random.random())  # generate a random color
        dets = detections[j]
        for det in dets:
            bbox = det[:4] * scale
            score = det[-1]
            if score < threshold:
                continue
            rect = plt.Rectangle((bbox[0], bbox[1]),
                                 bbox[2] - bbox[0],
                                 bbox[3] - bbox[1], fill=False,
                                 edgecolor=color, linewidth=3.5)
            plt.gca().add_patch(rect)
            plt.gca().text(bbox[0], bbox[1] - 2,
                           '{:s} {:.3f}'.format(name, score),
                           bbox=dict(facecolor=color, alpha=0.5), fontsize=12, color='white')
    plt.show()


def draw_all_detection(im_array, detections, class_names, scale, cfg, threshold=1e-1):
    """
    Draw all detections on one image and return it as a BGR array (no display).
    :param im_array: [b=1 c h w] in rgb
    :param detections: [ numpy.ndarray([[x1 y1 x2 y2 score]]) for j in classes ]
    :param class_names: list of names in imdb
    :param scale: visualize the scaled image
    :param cfg: config dict, provides network.PIXEL_MEANS
    :param threshold: minimum score for a box to be drawn
    :return: image with boxes drawn, as a BGR numpy array
    """
    import cv2
    import random
    color_white = (255, 255, 255)
    im = image.transform_inverse(im_array, cfg.network.PIXEL_MEANS)
    # change to bgr
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    for j, name in enumerate(class_names):
        if name == '__background__':
            continue
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))  # generate a random color
        dets = detections[j]
        for det in dets:
            bbox = det[:4] * scale
            score = det[-1]
            if score < threshold:
                continue
            bbox = map(int, bbox)
            cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color=color, thickness=2)
            cv2.putText(im, '%s %.3f' % (class_names[j], score), (bbox[0], bbox[1] + 10),
                        color=color_white, fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5)
    return im
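

# A minimal usage sketch (not part of the original file): the symbol `sym`, the
# test iterator `test_data`, the `imdb`/`cfg` objects and the parameter dicts
# below are illustrative assumptions, shown only to indicate how Predictor and
# pred_eval are typically wired together:
#
#   predictor = Predictor(sym, data_names, label_names, context=[mx.gpu(0)],
#                         max_data_shapes=max_data_shape,
#                         provide_data=test_data.provide_data,
#                         provide_label=test_data.provide_label,
#                         arg_params=arg_params, aux_params=aux_params)
#   pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3)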

