callback.py
Go to the documentation of this file.
00001 # --------------------------------------------------------
00002 # Deformable Convolutional Networks
00003 # Copyright (c) 2016 by Contributors
00004 # Copyright (c) 2017 Microsoft
00005 # Licensed under The Apache-2.0 License [see LICENSE for details]
00006 # Modified by Yuwen Xiong
00007 # --------------------------------------------------------
00008 
00009 import time
00010 import logging
00011 import mxnet as mx
00012 
00013 
class Speedometer(object):
    """Batch-end callback that periodically reports training throughput.

    Once every ``frequent`` batches the callback logs (and prints) how many
    samples per second were processed since the previous report, followed by
    the current evaluation-metric values when a metric is supplied.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch; converts a batch count into samples.
    frequent : int, optional
        Emit one report every ``frequent`` batches. Defaults to 50.
    """

    def __init__(self, batch_size, frequent=50):
        self.batch_size = batch_size
        self.frequent = frequent
        self.init = False    # True once the timer has been started
        self.tic = 0         # wall-clock time of the last timer (re)start
        self.last_count = 0  # batch index seen on the previous call

    def __call__(self, param):
        """Callback to Show speed.

        Parameters
        ----------
        param : object
            Must expose ``nbatch``, ``epoch`` and ``eval_metric``
            (presumably an MXNet ``BatchEndParam`` -- confirm against caller).
            ``eval_metric`` may be ``None``; otherwise its ``get()`` must
            return parallel iterables of names and values.
        """
        count = param.nbatch
        # A batch index lower than the previous one means the counter was
        # reset (e.g. a new epoch started), so the timer must be restarted.
        if self.last_count > count:
            self.init = False
        self.last_count = count

        if not self.init:
            # First call after a (re)start: only start the clock.
            self.init = True
            self.tic = time.time()
            return

        if count % self.frequent != 0:
            return

        try:
            speed = self.frequent * self.batch_size / (time.time() - self.tic)
        except ZeroDivisionError:
            # Two calls landed within one clock tick; report an unbounded
            # speed instead of aborting the training loop.
            speed = float('inf')

        if param.eval_metric is not None:
            name, value = param.eval_metric.get()
            s = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-" % (param.epoch, count, speed)
            s += ''.join("%s=%f,\t" % (n, v) for n, v in zip(name, value))
        else:
            s = "Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (param.epoch, count, speed)

        logging.info(s)
        print(s)
        self.tic = time.time()
00047 
00048 
def do_checkpoint(prefix, means, stds):
    """Build an epoch-end callback that saves a checkpoint with rescaled bbox parameters.

    Before saving, the callback derives ``rfcn_bbox_weight_test`` and
    ``rfcn_bbox_bias_test`` from ``rfcn_bbox_weight`` / ``rfcn_bbox_bias`` by
    scaling with ``stds`` and shifting the bias by ``means`` (presumably
    undoing bbox-target normalization for inference -- confirm against the
    training code). The derived entries exist only for the duration of the
    save and are removed from ``arg`` afterwards.

    Parameters
    ----------
    prefix : str
        Checkpoint file prefix passed to ``mx.model.save_checkpoint``.
    means, stds : array-like
        Per-target offsets and scales. ``means`` must expose ``shape``; both
        are repeated so their length matches ``rfcn_bbox_bias``.

    Returns
    -------
    callable
        ``_callback(iter_no, sym, arg, aux)``; saves as epoch ``iter_no + 1``.
    """
    def _callback(iter_no, sym, arg, aux):
        weight = arg['rfcn_bbox_weight']
        bias = arg['rfcn_bbox_bias']
        # Floor division: `/` yields a float on Python 3, but `repeats`
        # must be an integer.
        repeat = bias.shape[0] // means.shape[0]

        tiled_stds = mx.nd.repeat(mx.nd.array(stds), repeats=repeat)
        tiled_means = mx.nd.repeat(mx.nd.array(means), repeats=repeat)

        arg['rfcn_bbox_weight_test'] = weight * tiled_stds.reshape((bias.shape[0], 1, 1, 1))
        arg['rfcn_bbox_bias_test'] = bias * tiled_stds + tiled_means
        try:
            mx.model.save_checkpoint(prefix, iter_no + 1, sym, arg, aux)
        finally:
            # Always drop the temporary entries so a failed save does not
            # leave them behind in the caller's parameter dict.
            arg.pop('rfcn_bbox_weight_test')
            arg.pop('rfcn_bbox_bias_test')
    return _callback


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:29