lr_scheduler.py
Go to the documentation of this file.
00001 # --------------------------------------------------------
00002 # Deformable Convolutional Networks
00003 # Copyright (c) 2017 Microsoft
00004 # Licensed under The Apache-2.0 License [see LICENSE for details]
00005 # Written by Yuwen Xiong
00006 # --------------------------------------------------------
00007 
00008 
00009 import logging
00010 from mxnet.lr_scheduler import LRScheduler
00011 
class WarmupMultiFactorScheduler(LRScheduler):
    """Step-wise learning-rate decay with an optional constant warmup phase.

    Every time the update counter passes (is strictly greater than) one of
    the boundaries listed in ``step``, the current ``base_lr`` is multiplied
    by ``factor``.  While warmup is enabled and the update counter is still
    below ``warmup_step``, ``warmup_lr`` is returned instead and no decay is
    applied.

    NOTE(review): ``base_lr`` is never assigned in this class; it is
    presumably initialised by ``LRScheduler.__init__`` (or set by the
    optimizer afterwards) -- confirm against the MXNet base class.

    Parameters
    ----------
    step: list of int
        strictly increasing update counts (each >= 1) after which the
        learning rate is reduced
    factor: float
        multiplier applied to the learning rate at each boundary; must not
        exceed 1
    warmup: bool
        whether the constant warmup phase is active
    warmup_lr: float
        learning rate returned during the warmup phase
    warmup_step: int
        number of updates the warmup phase lasts
    """
    def __init__(self, step, factor=1, warmup=False, warmup_lr=0, warmup_step=0):
        super(WarmupMultiFactorScheduler, self).__init__()
        assert isinstance(step, list) and len(step) >= 1

        # Validate the schedule: each boundary must exceed the one before it
        # and be at least 1.
        previous = None
        for boundary in step:
            if previous is not None and boundary <= previous:
                raise ValueError("Schedule step must be an increasing integer list")
            if boundary < 1:
                raise ValueError("Schedule step must be greater or equal than 1 round")
            previous = boundary

        if factor > 1.0:
            raise ValueError("Factor must be no more than 1 to make lr reduce")

        # Decay configuration.
        self.step = step
        self.factor = factor

        # Warmup configuration.
        self.warmup = warmup
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_step

        # Progress through the schedule: index of the next boundary to cross
        # and the value of the last boundary that was crossed.
        self.cur_step_ind = 0
        self.count = 0

    def __call__(self, num_update):
        """
        Return the learning rate to use for the given update count

        Has a side effect: every boundary crossed by this call permanently
        scales ``self.base_lr`` by ``self.factor`` and advances
        ``self.cur_step_ind``.

        Parameters
        ----------
        num_update: int
            the maximal number of updates applied to a weight.
        """
        # Warmup takes precedence; the decay state is left untouched.
        in_warmup = self.warmup and num_update < self.warmup_step
        if in_warmup:
            return self.warmup_lr

        # NOTE: loop (rather than a single check) so that several boundaries
        # can be consumed in one call, e.g. when continuing training via
        # load_epoch.
        total_boundaries = len(self.step)
        while self.cur_step_ind < total_boundaries:
            boundary = self.step[self.cur_step_ind]
            if not num_update > boundary:
                # Next boundary not passed yet; nothing more to decay.
                break
            self.count = boundary
            self.cur_step_ind += 1
            self.base_lr *= self.factor
            logging.info("Update[%d]: Change learning rate to %0.5e",
                         num_update, self.base_lr)

        return self.base_lr


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30