import logging
from mxnet.lr_scheduler import LRScheduler

class WarmupMultiFactorScheduler(LRScheduler):
    """Reduce the learning rate by a factor at the steps specified in a list,
    optionally using a constant warmup learning rate for the first updates.

    Assume the weight has been updated n times; the learning rate is then

        base_lr * factor^k    # k = number of entries s in step with s < n

    Parameters
    ----------
    step: list of int
        update counts after which the learning rate is reduced
    factor: float
        the factor by which the learning rate is multiplied at each step
    warmup: bool
        whether to use a constant warmup learning rate at the start of training
    warmup_lr: float
        the learning rate returned during warmup
    warmup_step: int
        the number of updates for which warmup_lr is returned
    """
    def __init__(self, step, factor=1, warmup=False, warmup_lr=0, warmup_step=0):
        super(WarmupMultiFactorScheduler, self).__init__()
        assert isinstance(step, list) and len(step) >= 1
        for i, _step in enumerate(step):
            if i != 0 and step[i] <= step[i - 1]:
                raise ValueError("Schedule step must be an increasing integer list")
            if _step < 1:
                raise ValueError("Schedule step must be greater than or equal to 1")
        if factor > 1.0:
            raise ValueError("Factor must be no more than 1 to make lr reduce")
        self.step = step
        self.cur_step_ind = 0
        self.factor = factor
        self.count = 0
        self.warmup = warmup
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_step

    def __call__(self, num_update):
        """Return the learning rate scheduled for the given update count.

        Parameters
        ----------
        num_update: int
            the maximal number of updates applied to a weight so far
        """
        # During warmup, return the constant warmup learning rate.
        if self.warmup and num_update < self.warmup_step:
            return self.warmup_lr
        # Advance past every schedule step that num_update has crossed,
        # multiplying base_lr by factor once per crossed step.
        while self.cur_step_ind <= len(self.step) - 1:
            if num_update > self.step[self.cur_step_ind]:
                self.count = self.step[self.cur_step_ind]
                self.cur_step_ind += 1
                self.base_lr *= self.factor
                logging.info("Update[%d]: Change learning rate to %0.5e",
                             num_update, self.base_lr)
            else:
                return self.base_lr
        return self.base_lr
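
# Minimal usage sketch (illustrative values, not part of the scheduler itself).
# MXNet's optimizer normally calls the scheduler with the running update count;
# invoking it directly shows the same behavior. Note the scheduler is stateful,
# so num_update should be non-decreasing across calls.
if __name__ == '__main__':
    scheduler = WarmupMultiFactorScheduler(step=[60000, 80000], factor=0.1,
                                           warmup=True, warmup_lr=0.0001,
                                           warmup_step=500)
    scheduler.base_lr = 0.001  # attribute inherited from mxnet.lr_scheduler.LRScheduler
    for n in (100, 1000, 60001, 80001):
        print(n, scheduler(n))  # -> roughly 0.0001 (warmup), 0.001, 0.0001, 1e-05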