rpn.py
Go to the documentation of this file.
"""
RPN:
data =
    {'data': [num_images, c, h, w],
     'im_info': [num_images, 4] (optional)}
label =
    {'gt_boxes': [num_boxes, 5] (optional),
     'label': [batch_size, num_anchors * feat_height * feat_width]
              (flattened from [batch_size, num_anchors, feat_height, feat_width]),
     'bbox_target': [batch_size, num_anchors * 4, feat_height, feat_width],
     'bbox_weight': [batch_size, num_anchors * 4, feat_height, feat_width]}
"""
00012 
00013 import numpy as np
00014 import numpy.random as npr
00015 
00016 from utils.image import get_image, tensor_vstack
00017 from generate_anchor import generate_anchors
00018 from bbox.bbox_transform import bbox_overlaps, bbox_transform
00019 
00020 
def get_rpn_testbatch(roidb, cfg):
    """
    Build RPN test inputs, one entry per image.

    :param roidb: list of roidb records handed to get_image (e.g. 'image', 'flipped')
    :param cfg: config object forwarded to get_image
    :return: (data, label, im_info) where
        data    -- list with one dict {'data': image array, 'im_info': im_info entry} per image
        label   -- empty dict (no targets are produced at test time)
        im_info -- list of float32 arrays, each wrapping one record's 'im_info' in a leading axis
    """
    images, records = get_image(roidb, cfg)
    im_info = [np.array([record['im_info']], dtype=np.float32) for record in records]
    # pair every image with the im_info array built for the same record (same object is shared)
    data = [{'data': images[idx], 'im_info': info} for idx, info in enumerate(im_info)]
    return data, {}, im_info
00037 
00038 
def get_rpn_batch(roidb, cfg):
    """
    Build a single-image RPN training batch: image data, im_info and gt boxes.

    :param roidb: list holding exactly one roidb record; this function reads its
        'gt_classes' and 'boxes' fields (get_image consumes e.g. 'image', 'flipped')
    :param cfg: config object forwarded to get_image
    :return: (data, label) where
        data  = {'data': the image array returned by get_image,
                 'im_info': float32 array wrapping the record's 'im_info' in a leading axis}
        label = {'gt_boxes': float32 array of shape (num_kept_gt, 5) holding
                 (x1, y1, x2, y2, cls) for every entry whose class is non-zero}
    """
    assert len(roidb) == 1, 'Single batch only'
    imgs, roidb = get_image(roidb, cfg)
    record = roidb[0]
    im_array = imgs[0]
    im_info = np.array([record['im_info']], dtype=np.float32)

    # gt boxes: (x1, y1, x2, y2, cls); entries with class 0 are left out
    gt_classes = record['gt_classes']
    if gt_classes.size > 0:
        gt_inds = np.where(gt_classes != 0)[0]
        # Size by the number of kept boxes, not by record['boxes'].shape[0]: the two
        # differ as soon as one class-0 entry exists, and the row assignments below
        # would then fail with a broadcast error.
        gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
        gt_boxes[:, 0:4] = record['boxes'][gt_inds, :]
        gt_boxes[:, 4] = gt_classes[gt_inds]
    else:
        gt_boxes = np.empty((0, 5), dtype=np.float32)

    data = {'data': im_array,
            'im_info': im_info}
    label = {'gt_boxes': gt_boxes}

    return data, label
00064 
00065 
def assign_anchor(feat_shape, gt_boxes, im_info, cfg, feat_stride=16,
                  scales=(8, 16, 32), ratios=(0.5, 1, 2), allowed_border=0):
    """
    Assign ground-truth boxes to anchor positions and build RPN training targets.

    :param feat_shape: shape of the RPN output; only the last two entries
        (feat_height, feat_width) are used
    :param gt_boxes: array whose columns 0:4 are the box corners -- presumably
        (x1, y1, x2, y2, cls) rows as built by get_rpn_batch; an empty array means
        "no ground truth" and every inside anchor becomes background
    :param im_info: indexable whose first row holds the image size; im_info[0][0] bounds
        the anchors' y2 and im_info[0][1] bounds their x2
    :param cfg: config; reads cfg.TRAIN.RPN_CLOBBER_POSITIVES, RPN_NEGATIVE_OVERLAP,
        RPN_POSITIVE_OVERLAP, RPN_FG_FRACTION, RPN_BATCH_SIZE and RPN_BBOX_WEIGHTS
    :param feat_stride: pixel step between neighbouring feature-map cells; also the
        base_size passed to generate_anchors
    :param scales: anchor scales passed to generate_anchors
    :param ratios: anchor aspect ratios passed to generate_anchors
    :param allowed_border: anchors reaching more than this many pixels outside the image
        are ignored (label -1)
    :return: dict of label, with A = number of anchors per location:
        'label': float32 (1, A * feat_height * feat_width); 1 = fg, 0 = bg, -1 = ignore
        'bbox_target': float32 (1, A * 4, feat_height, feat_width) regression targets
        'bbox_weight': float32 (1, A * 4, feat_height, feat_width); cfg.TRAIN.RPN_BBOX_WEIGHTS
            on foreground anchors, 0 elsewhere

    At most int(RPN_FG_FRACTION * RPN_BATCH_SIZE) anchors stay foreground and at most
    RPN_BATCH_SIZE stay labelled in total; the excess is dropped at random using the
    global numpy RNG (numpy.random.choice).
    """
    DEBUG = False

    def _dprint(*parts):
        """Print parts space-separated; one-argument print() behaves the same on py2 and py3."""
        print(' '.join(str(part) for part in parts))

    def _unmap(data, count, inds, fill=0):
        """Scatter rows of `data` to positions `inds` of a float32 array of `count` rows pre-filled with `fill`."""
        ret = np.empty((count,) + data.shape[1:], dtype=np.float32)
        ret.fill(fill)
        ret[inds] = data
        return ret

    def _subsample(labels, value, max_keep):
        """In place: reset randomly chosen entries equal to `value` to -1 so at most `max_keep` remain."""
        inds = np.where(labels == value)[0]
        if len(inds) > max_keep:
            num_drop = len(inds) - max_keep
            disable_inds = npr.choice(inds, size=num_drop, replace=False)
            if DEBUG:
                # deterministic pick so debug runs are reproducible
                disable_inds = inds[:num_drop]
            labels[disable_inds] = -1

    im_info = im_info[0]
    scales = np.array(scales, dtype=np.float32)
    base_anchors = generate_anchors(base_size=feat_stride, ratios=list(ratios), scales=scales)
    A = base_anchors.shape[0]
    feat_height, feat_width = feat_shape[-2:]

    if DEBUG:
        _dprint('anchors:')
        _dprint(base_anchors)
        _dprint('anchor shapes:')
        _dprint(np.hstack((base_anchors[:, 2::4] - base_anchors[:, 0::4],
                           base_anchors[:, 3::4] - base_anchors[:, 1::4])))
        _dprint('im_info', im_info)
        _dprint('height', feat_height, 'width', feat_width)
        _dprint('gt_boxes shape', gt_boxes.shape)
        _dprint('gt_boxes', gt_boxes)

    # 1. shift the A base anchors (1, A, 4) by the K cell offsets (K, 1, 4)
    #    -> (K, A, 4) -> (K * A, 4); cell k = y * feat_width + x (row-major meshgrid)
    shift_x = np.arange(0, feat_width) * feat_stride
    shift_y = np.arange(0, feat_height) * feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel())).transpose()
    K = shifts.shape[0]
    all_anchors = base_anchors.reshape((1, A, 4)) + shifts.reshape((K, 1, 4))
    all_anchors = all_anchors.reshape((K * A, 4))
    total_anchors = int(K * A)

    # 2. keep only anchors inside the image (up to allowed_border pixels outside)
    inds_inside = np.where((all_anchors[:, 0] >= -allowed_border) &
                           (all_anchors[:, 1] >= -allowed_border) &
                           (all_anchors[:, 2] < im_info[1] + allowed_border) &
                           (all_anchors[:, 3] < im_info[0] + allowed_border))[0]
    anchors = all_anchors[inds_inside, :]
    num_inside = len(inds_inside)
    if DEBUG:
        _dprint('total_anchors', total_anchors)
        _dprint('inds_inside', num_inside)
        _dprint('anchors shape', anchors.shape)

    # label: 1 is positive, 0 is negative, -1 is dont care
    labels = np.empty((num_inside,), dtype=np.float32)
    labels.fill(-1)

    # Overlaps need at least one inside anchor: argmax over an empty axis raises.
    has_gt = gt_boxes.size > 0 and num_inside > 0
    if has_gt:
        # overlaps (anchor, gt); np.float64 is what the removed np.float alias meant
        overlaps = bbox_overlaps(anchors.astype(np.float64), gt_boxes.astype(np.float64))
        argmax_overlaps = overlaps.argmax(axis=1)
        max_overlaps = overlaps[np.arange(num_inside), argmax_overlaps]
        gt_argmax_overlaps = overlaps.argmax(axis=0)
        gt_max_overlaps = overlaps[gt_argmax_overlaps, np.arange(overlaps.shape[1])]
        # Every anchor tying for a gt's best overlap is a fg candidate. gts whose best
        # overlap is 0 are skipped, otherwise every anchor with zero overlap would tie.
        best_for_gt = (overlaps == gt_max_overlaps) & (gt_max_overlaps > 0)
        gt_argmax_overlaps = np.where(best_for_gt)[0]

        if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            # assign bg labels first so that positive labels can clobber them
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

        # fg label: for each gt, anchor(s) with highest overlap
        labels[gt_argmax_overlaps] = 1

        # fg label: above threshold IoU
        labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1

        if cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            # assign bg labels last so that negative labels can clobber positives
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0
    else:
        labels[:] = 0

    # 3. subsample: cap fg first, then let bg fill the rest of the batch
    num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCH_SIZE)
    _subsample(labels, 1, num_fg)
    num_bg = cfg.TRAIN.RPN_BATCH_SIZE - np.sum(labels == 1)
    _subsample(labels, 0, num_bg)

    # 4. regression targets against each anchor's best-matching gt; weights only on fg
    bbox_targets = np.zeros((num_inside, 4), dtype=np.float32)
    if has_gt:
        bbox_targets[:] = bbox_transform(anchors, gt_boxes[argmax_overlaps, :4])

    bbox_weights = np.zeros((num_inside, 4), dtype=np.float32)
    bbox_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_WEIGHTS)

    if DEBUG:
        fg_targets = bbox_targets[labels == 1, :]
        fg_count = np.sum(labels == 1) + 1e-14  # epsilon avoids 0/0 when there is no fg
        means = fg_targets.sum(axis=0) / fg_count
        stds = np.sqrt((fg_targets ** 2).sum(axis=0) / fg_count - means ** 2)
        _dprint('means', means)
        _dprint('stdevs', stds)

    # map up to original set of anchors; outside anchors are ignored (-1) with zero targets
    labels = _unmap(labels, total_anchors, inds_inside, fill=-1)
    bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0)
    bbox_weights = _unmap(bbox_weights, total_anchors, inds_inside, fill=0)

    if DEBUG:
        if has_gt:
            _dprint('rpn: max max_overlaps', np.max(max_overlaps))
        _dprint('rpn: num_positives', np.sum(labels == 1))
        _dprint('rpn: num_negatives', np.sum(labels == 0))

    # (K * A) rows are ordered (height, width, anchor); move the anchor axis in front
    labels = labels.reshape((1, feat_height, feat_width, A)).transpose(0, 3, 1, 2)
    labels = labels.reshape((1, A * feat_height * feat_width))
    bbox_targets = bbox_targets.reshape((1, feat_height, feat_width, A * 4)).transpose(0, 3, 1, 2)
    bbox_weights = bbox_weights.reshape((1, feat_height, feat_width, A * 4)).transpose((0, 3, 1, 2))

    label = {'label': labels,
             'bbox_target': bbox_targets,
             'bbox_weight': bbox_weights}
    return label


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30