00001 """
00002 RPN:
00003 data =
00004 {'data': [num_images, c, h, w],
00005 'im_info': [num_images, 4] (optional)}
00006 label =
00007 {'gt_boxes': [num_boxes, 5] (optional),
00008 'label': [batch_size, 1] <- [batch_size, num_anchors, feat_height, feat_width],
00009 'bbox_target': [batch_size, num_anchors, feat_height, feat_width],
00010 'bbox_weight': [batch_size, num_anchors, feat_height, feat_width]}
00011 """
00012
00013 import numpy as np
00014 import numpy.random as npr
00015
00016 from utils.image import get_image, tensor_vstack
00017 from generate_anchor import generate_anchors
00018 from bbox.bbox_transform import bbox_overlaps, bbox_transform
00019
00020
def get_rpn_testbatch(roidb, cfg):
    """
    Build the per-image inputs for RPN testing.

    :param roidb: list of records, each with at least ['image', 'flipped']
    :param cfg: config object, forwarded unchanged to get_image
    :return: (data, label, im_info)
        data: list with one {'data': ..., 'im_info': ...} dict per record
        label: an empty dict (no labels at test time)
        im_info: list of float32 arrays, one per record, each wrapping that
            record's 'im_info' in an extra leading axis
    """
    im_array, roidb = get_image(roidb, cfg)

    # One (1, len(record['im_info'])) float32 array per image record.
    im_info = [np.array([record['im_info']], dtype=np.float32) for record in roidb]

    data = [{'data': im_array[idx], 'im_info': info}
            for idx, info in enumerate(im_info)]

    return data, {}, im_info
00037
00038
def get_rpn_batch(roidb, cfg):
    """
    prototype for rpn batch: data, im_info, gt_boxes
    :param roidb: single-element list; the record needs ['image', 'flipped'] +
        ['gt_boxes', 'boxes', 'gt_classes']
    :param cfg: config object, forwarded unchanged to get_image
    :return: data, label
        data: {'data': image array, 'im_info': float32 array wrapping the
               record's 'im_info' in an extra leading axis}
        label: {'gt_boxes': float32 array of shape (num_kept, 5); each row is
                the 4 box coordinates from 'boxes' followed by the class id,
                for every entry whose gt_class is non-zero}
    """
    assert len(roidb) == 1, 'Single batch only'
    imgs, roidb = get_image(roidb, cfg)
    im_array = imgs[0]
    im_info = np.array([roidb[0]['im_info']], dtype=np.float32)

    gt_classes = roidb[0]['gt_classes']
    if gt_classes.size > 0:
        # Keep only entries with a non-zero class. The output must be sized by
        # the number of kept entries, not by the total number of boxes.
        # Otherwise, records that contain class-0 boxes make the assignment
        # below fail (or silently broadcast a single row).
        gt_inds = np.where(gt_classes != 0)[0]
        gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
        gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :]
        gt_boxes[:, 4] = gt_classes[gt_inds]
    else:
        gt_boxes = np.empty((0, 5), dtype=np.float32)

    data = {'data': im_array,
            'im_info': im_info}
    label = {'gt_boxes': gt_boxes}

    return data, label
00064
00065
def assign_anchor(feat_shape, gt_boxes, im_info, cfg, feat_stride=16,
                  scales=(8, 16, 32), ratios=(0.5, 1, 2), allowed_border=0):
    """
    assign ground truth boxes to anchor positions
    :param feat_shape: shape of the feature map; only the last two entries
        (feat_height, feat_width) are used
    :param gt_boxes: ground truth boxes, one row per box; columns 0:4 are the
        regression targets (the full array is passed to bbox_overlaps)
    :param im_info: only im_info[0] is used; its element 0 is compared against
        anchor y-coordinates (height) and element 1 against x-coordinates
        (width) to drop anchors that cross the image border
    :param cfg: config object; reads cfg.TRAIN.RPN_CLOBBER_POSITIVES,
        RPN_NEGATIVE_OVERLAP, RPN_POSITIVE_OVERLAP, RPN_FG_FRACTION,
        RPN_BATCH_SIZE and RPN_BBOX_WEIGHTS
    :param feat_stride: anchor position step in pixels; also the base anchor size
    :param scales: used to generate anchors, affects num_anchors (per location)
    :param ratios: aspect ratios of generated anchors
    :param allowed_border: anchors may extend at most this many pixels outside
        the image and still be kept
    :return: dict of label
        'label': shape (1, num_anchors * feat_height * feat_width);
            1 = foreground, 0 = background, -1 = not sampled
        'bbox_target': shape (1, num_anchors * 4, feat_height, feat_width)
        'bbox_weight': shape (1, num_anchors * 4, feat_height, feat_width);
            cfg.TRAIN.RPN_BBOX_WEIGHTS on foreground anchors, 0 elsewhere
    """
    def _unmap(data, count, inds, fill=0):
        """Scatter rows of `data` to positions `inds` of a new float32 array
        with `count` rows; every other row is set to `fill`."""
        ret = np.empty((count,) + data.shape[1:], dtype=np.float32)
        ret.fill(fill)
        ret[inds] = data
        return ret

    DEBUG = False
    im_info = im_info[0]
    scales = np.array(scales, dtype=np.float32)
    base_anchors = generate_anchors(base_size=feat_stride, ratios=list(ratios), scales=scales)
    num_anchors = base_anchors.shape[0]
    feat_height, feat_width = feat_shape[-2:]

    if DEBUG:
        print('anchors:')
        print(base_anchors)
        print('anchor shapes:')
        print(np.hstack((base_anchors[:, 2::4] - base_anchors[:, 0::4],
                         base_anchors[:, 3::4] - base_anchors[:, 1::4])))
        print('im_info %s' % (im_info,))
        print('height %s width %s' % (feat_height, feat_width))
        print('gt_boxes shape %s' % (gt_boxes.shape,))
        print('gt_boxes %s' % (gt_boxes,))

    # One shift per feature-map cell, in row-major (y, x) order.
    shift_x = np.arange(0, feat_width) * feat_stride
    shift_y = np.arange(0, feat_height) * feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel())).transpose()

    # (1, A, 4) base anchors + (K, 1, 4) shifts -> (K, A, 4), flattened so that
    # anchor index = shift_index * A + anchor_type.
    A = num_anchors
    K = shifts.shape[0]
    all_anchors = base_anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))
    all_anchors = all_anchors.reshape((K * A, 4))
    total_anchors = int(K * A)

    # Keep only anchors inside the image (up to allowed_border pixels outside).
    inds_inside = np.where((all_anchors[:, 0] >= -allowed_border) &
                           (all_anchors[:, 1] >= -allowed_border) &
                           (all_anchors[:, 2] < im_info[1] + allowed_border) &
                           (all_anchors[:, 3] < im_info[0] + allowed_border))[0]
    if DEBUG:
        print('total_anchors %s' % (total_anchors,))
        print('inds_inside %s' % (len(inds_inside),))

    anchors = all_anchors[inds_inside, :]
    if DEBUG:
        print('anchors shape %s' % (anchors.shape,))

    # Start with every kept anchor unassigned (-1).
    labels = np.empty((len(inds_inside),), dtype=np.float32)
    labels.fill(-1)

    # Overlaps are only meaningful with at least one gt box AND one kept
    # anchor; argmax/max over an empty axis would raise otherwise.
    has_overlaps = gt_boxes.size > 0 and len(inds_inside) > 0
    if has_overlaps:
        # overlaps: rows are kept anchors, columns are gt boxes.
        overlaps = bbox_overlaps(anchors.astype(np.float64), gt_boxes.astype(np.float64))
        argmax_overlaps = overlaps.argmax(axis=1)
        max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps]
        gt_max_overlaps = overlaps.max(axis=0)
        # Anchors reaching (ties included) the best overlap of some gt box.
        # Gt boxes that overlap no anchor at all are skipped. Otherwise every
        # zero-overlap anchor would tie with their max of 0 and become positive.
        gt_argmax_overlaps = np.where((overlaps == gt_max_overlaps) &
                                      (gt_max_overlaps > 0))[0]

        if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            # Assign negatives first so the positive rules below can override them.
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

        # Positive: the best anchor(s) for each gt box.
        labels[gt_argmax_overlaps] = 1

        # Positive: overlap above the positive threshold.
        labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1

        if cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            # Assign negatives last so they override positives.
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0
    else:
        labels[:] = 0

    # Subsample positives down to RPN_FG_FRACTION of the batch.
    num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCH_SIZE)
    fg_inds = np.where(labels == 1)[0]
    if len(fg_inds) > num_fg:
        disable_inds = npr.choice(fg_inds, size=(len(fg_inds) - num_fg), replace=False)
        if DEBUG:
            # Deterministic selection for reproducible debugging.
            disable_inds = fg_inds[:(len(fg_inds) - num_fg)]
        labels[disable_inds] = -1

    # Fill the rest of the batch with negatives.
    num_bg = cfg.TRAIN.RPN_BATCH_SIZE - np.sum(labels == 1)
    bg_inds = np.where(labels == 0)[0]
    if len(bg_inds) > num_bg:
        disable_inds = npr.choice(bg_inds, size=(len(bg_inds) - num_bg), replace=False)
        if DEBUG:
            disable_inds = bg_inds[:(len(bg_inds) - num_bg)]
        labels[disable_inds] = -1

    # Regression targets toward each anchor's best-overlapping gt box.
    bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
    if has_overlaps:
        bbox_targets[:] = bbox_transform(anchors, gt_boxes[argmax_overlaps, :4])

    # Only positive anchors contribute to the regression loss.
    bbox_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
    bbox_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_WEIGHTS)

    if DEBUG:
        _sums = bbox_targets[labels == 1, :].sum(axis=0)
        _squared_sums = (bbox_targets[labels == 1, :] ** 2).sum(axis=0)
        _counts = np.sum(labels == 1)
        means = _sums / (_counts + 1e-14)
        stds = np.sqrt(_squared_sums / (_counts + 1e-14) - means ** 2)
        print('means %s' % (means,))
        print('stdevs %s' % (stds,))

    # Map the per-kept-anchor arrays back onto the full anchor set.
    labels = _unmap(labels, total_anchors, inds_inside, fill=-1)
    bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0)
    bbox_weights = _unmap(bbox_weights, total_anchors, inds_inside, fill=0)

    if DEBUG:
        if has_overlaps:
            print('rpn: max max_overlaps %s' % (np.max(max_overlaps),))
        print('rpn: num_positives %s' % (np.sum(labels == 1),))
        print('rpn: num_negatives %s' % (np.sum(labels == 0),))

    # (K*A,) ordered as (H, W, A) -> (1, A, H, W) -> flattened (1, A*H*W).
    labels = labels.reshape((1, feat_height, feat_width, A)).transpose(0, 3, 1, 2)
    labels = labels.reshape((1, A * feat_height * feat_width))
    bbox_targets = bbox_targets.reshape((1, feat_height, feat_width, A * 4)).transpose(0, 3, 1, 2)
    bbox_weights = bbox_weights.reshape((1, feat_height, feat_width, A * 4)).transpose((0, 3, 1, 2))

    label = {'label': labels,
             'bbox_target': bbox_targets,
             'bbox_weight': bbox_weights}
    return label