00001
00002
00003
00004
00005
00006
00007
00008
00009 """
Proposal Target Operator selects foreground and background ROIs from RPN proposals and assigns classification labels and bbox regression targets to them.
00011 """
00012
00013 import mxnet as mx
00014 import numpy as np
00015 from distutils.util import strtobool
00016 from easydict import EasyDict as edict
00017 import cPickle
00018
00019
00020 from core.rcnn import sample_rois
00021
00022 DEBUG = False
00023
00024
class ProposalTargetOperator(mx.operator.CustomOp):
    """Sample foreground/background ROIs from RPN proposals plus ground-truth
    boxes, and assign each sampled ROI a class label, bbox regression target
    and bbox weight (via core.rcnn.sample_rois).

    Outputs: rois_output (N, 5), label (N,), bbox_target (N, 4*num_classes),
    bbox_weight (N, 4*num_classes).
    """

    def __init__(self, num_classes, batch_images, batch_rois, cfg, fg_fraction):
        super(ProposalTargetOperator, self).__init__()
        self._num_classes = num_classes
        self._batch_images = batch_images
        # batch_rois == -1 means "keep every proposal and every gt box"
        self._batch_rois = batch_rois
        self._cfg = cfg
        self._fg_fraction = fg_fraction

        if DEBUG:
            # running statistics accumulated across forward() calls
            self._count = 0
            self._fg_num = 0
            self._bg_num = 0

    def forward(self, is_train, req, in_data, out_data, aux):
        assert self._batch_rois == -1 or self._batch_rois % self._batch_images == 0, \
            'batch_images {} must divide batch_rois {}'.format(self._batch_images, self._batch_rois)
        all_rois = in_data[0].asnumpy()
        gt_boxes = in_data[1].asnumpy()

        if self._batch_rois == -1:
            # keep everything; the whole budget may be foreground
            rois_per_image = all_rois.shape[0] + gt_boxes.shape[0]
            fg_rois_per_image = rois_per_image
        else:
            # floor division keeps this an int on both Python 2 and Python 3
            rois_per_image = self._batch_rois // self._batch_images
            fg_rois_per_image = np.round(self._fg_fraction * rois_per_image).astype(int)

        # Append ground-truth boxes to the candidate ROIs: batch index 0 in
        # column 0, and the gt class column (last) dropped.
        zeros = np.zeros((gt_boxes.shape[0], 1), dtype=gt_boxes.dtype)
        all_rois = np.vstack((all_rois, np.hstack((zeros, gt_boxes[:, :-1]))))

        # Sanity check: the batch-index column must be all zeros.
        assert np.all(all_rois[:, 0] == 0), 'Only single item batches are supported'

        rois, labels, bbox_targets, bbox_weights = \
            sample_rois(all_rois, fg_rois_per_image, rois_per_image, self._num_classes, self._cfg, gt_boxes=gt_boxes)

        if DEBUG:
            # parenthesized single-argument prints work identically on Py2/Py3
            print('labels={}'.format(labels))
            print('num fg: {}'.format((labels > 0).sum()))
            print('num bg: {}'.format((labels == 0).sum()))
            self._count += 1
            self._fg_num += (labels > 0).sum()
            self._bg_num += (labels == 0).sum()
            print('self._count={}'.format(self._count))
            print('num fg avg: {}'.format(self._fg_num // self._count))
            print('num bg avg: {}'.format(self._bg_num // self._count))
            print('ratio: {:.3f}'.format(float(self._fg_num) / float(self._bg_num)))

        for ind, val in enumerate([rois, labels, bbox_targets, bbox_weights]):
            self.assign(out_data[ind], req[ind], val)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # No gradient flows to the proposals or the ground-truth boxes.
        self.assign(in_grad[0], req[0], 0)
        self.assign(in_grad[1], req[1], 0)
00080
00081
@mx.operator.register('proposal_target')
class ProposalTargetProp(mx.operator.CustomOpProp):
    """Symbol-level properties for the 'proposal_target' custom operator:
    declares inputs/outputs, infers shapes and builds the runtime operator."""

    def __init__(self, num_classes, batch_images, batch_rois, cfg, fg_fraction='0.25'):
        super(ProposalTargetProp, self).__init__(need_top_grad=False)
        # Custom-op keyword arguments always arrive as strings; parse them here.
        self._num_classes = int(num_classes)
        self._batch_images = int(batch_images)
        self._batch_rois = int(batch_rois)
        # NOTE(security): cfg is an unpickled blob — only feed configs that
        # come from a trusted source, since pickle can execute arbitrary code.
        self._cfg = cPickle.loads(cfg)
        self._fg_fraction = float(fg_fraction)

    def list_arguments(self):
        return ['rois', 'gt_boxes']

    def list_outputs(self):
        return ['rois_output', 'label', 'bbox_target', 'bbox_weight']

    def infer_shape(self, in_shape):
        rois_shape, gt_shape = in_shape

        # batch_rois == -1: keep every proposal plus every gt box.
        if self._batch_rois == -1:
            num_rois = rois_shape[0] + gt_shape[0]
        else:
            num_rois = self._batch_rois

        out_shapes = [
            (num_rois, 5),                       # rois_output
            (num_rois, ),                        # label
            (num_rois, self._num_classes * 4),   # bbox_target
            (num_rois, self._num_classes * 4),   # bbox_weight
        ]
        return [rois_shape, gt_shape], out_shapes

    def create_operator(self, ctx, shapes, dtypes):
        return ProposalTargetOperator(self._num_classes, self._batch_images,
                                      self._batch_rois, self._cfg, self._fg_fraction)

    def declare_backward_dependency(self, out_grad, in_data, out_data):
        # This op computes no gradient, so it depends on nothing.
        return []