"""A `MutableModule` implements the `BaseModule` API and allows the input shapes
to vary across training iterations. If the shapes vary, executors will rebind,
using shared arrays from the initial module bound with the maximum shapes.
"""

import time
import logging
import warnings

from mxnet import context as ctx
from mxnet.initializer import Uniform, InitDesc
from mxnet.module.base_module import BaseModule, _check_input_names, _parse_data_desc, _as_list
from mxnet.model import _create_kvstore, _initialize_kvstore, _update_params, \
    _update_params_on_kvstore, load_checkpoint, BatchEndParam
from mxnet import metric


from .DataParallelExecutorGroup import DataParallelExecutorGroup
from mxnet import ndarray as nd
from mxnet import optimizer as opt


class Module(BaseModule):
    """Module is a basic module that wraps a `Symbol`. It is functionally the same
    as the `FeedForward` model, except under the module API.

    Parameters
    ----------
    symbol : Symbol
    data_names : list of str
        Default is `('data',)` for a typical model used in image classification.
    label_names : list of str
        Default is `('softmax_label',)` for a typical model used in image
        classification.
    logger : Logger
        Default is `logging`.
    context : Context or list of Context
        Default is `cpu()`.
    work_load_list : list of number
        Default `None`, indicating uniform workload.
    fixed_param_names: list of str
        Default `None`, indicating no network parameters are fixed.
    state_names : list of str
        States are similar to data and label, but they are not provided by the data
        iterator. Instead they are initialized to 0 and can be set by `set_states()`.
    """
    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 fixed_param_names=None, state_names=None):
        super(Module, self).__init__(logger=logger)

        if isinstance(context, ctx.Context):
            context = [context]
        self._context = context
        if work_load_list is None:
            work_load_list = [1] * len(self._context)
        assert len(work_load_list) == len(self._context)
        self._work_load_list = work_load_list

        self._symbol = symbol

        data_names = list(data_names) if data_names is not None else []
        label_names = list(label_names) if label_names is not None else []
        state_names = list(state_names) if state_names is not None else []
        fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []

        _check_input_names(symbol, data_names, "data", True)
        _check_input_names(symbol, label_names, "label", False)
        _check_input_names(symbol, state_names, "state", True)
        _check_input_names(symbol, fixed_param_names, "fixed_param", True)

        arg_names = symbol.list_arguments()
        input_names = data_names + label_names + state_names
        self._param_names = [x for x in arg_names if x not in input_names]
        self._fixed_param_names = fixed_param_names
        self._aux_names = symbol.list_auxiliary_states()
        self._data_names = data_names
        self._label_names = label_names
        self._state_names = state_names
        self._output_names = symbol.list_outputs()

        self._arg_params = None
        self._aux_params = None
        self._params_dirty = False

        self._optimizer = None
        self._kvstore = None
        self._update_on_kvstore = None
        self._updater = None
        self._preload_opt_states = None
        self._grad_req = None

        self._exec_group = None
        self._data_shapes = None
        self._label_shapes = None

    @staticmethod
    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
        """Create a model from a previously saved checkpoint.

        Parameters
        ----------
        prefix : str
            Path prefix of the saved model files. You should have
            "prefix-symbol.json", "prefix-xxxx.params", and
            optionally "prefix-xxxx.states", where xxxx is the
            epoch number.
        epoch : int
            Epoch to load.
        load_optimizer_states : bool
            Whether to load optimizer states. The checkpoint needs
            to have been made with `save_optimizer_states=True`.
        data_names : list of str
            Default is `('data',)` for a typical model used in image classification.
        label_names : list of str
            Default is `('softmax_label',)` for a typical model used in image
            classification.
        logger : Logger
            Default is `logging`.
        context : Context or list of Context
            Default is `cpu()`.
        work_load_list : list of number
            Default `None`, indicating uniform workload.
        fixed_param_names: list of str
            Default `None`, indicating no network parameters are fixed.
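
        Examples
        --------
        An illustrative sketch; the prefix, epoch number and context keyword are
        placeholders, not values required by this method::
            >>> mod = Module.load('mymodel', 3, context=ctx.cpu())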
        """
        sym, args, auxs = load_checkpoint(prefix, epoch)
        mod = Module(symbol=sym, **kwargs)
        mod._arg_params = args
        mod._aux_params = auxs
        mod.params_initialized = True
        if load_optimizer_states:
            mod._preload_opt_states = '%s-%04d.states' % (prefix, epoch)
        return mod

    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
        """Save current progress to a checkpoint.
        Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training.

        Parameters
        ----------
        prefix : str
            The file prefix to checkpoint to.
        epoch : int
            The current epoch number.
        save_optimizer_states : bool
            Whether to save optimizer states to continue training.
        """
        self._symbol.save('%s-symbol.json' % prefix)
        param_name = '%s-%04d.params' % (prefix, epoch)
        self.save_params(param_name)
        logging.info('Saved checkpoint to \"%s\"', param_name)
        if save_optimizer_states:
            state_name = '%s-%04d.states' % (prefix, epoch)
            self.save_optimizer_states(state_name)
            logging.info('Saved optimizer state to \"%s\"', state_name)

    def _reset_bind(self):
        """Internal function to reset the bound state."""
        self.binded = False
        self._exec_group = None
        self._data_shapes = None
        self._label_shapes = None

    @property
    def data_names(self):
        """A list of names for data required by this module."""
        return self._data_names

    @property
    def label_names(self):
        """A list of names for labels required by this module."""
        return self._label_names

    @property
    def output_names(self):
        """A list of names for the outputs of this module."""
        return self._output_names

    @property
    def data_shapes(self):
        """Get data shapes.

        Returns
        -------
        A list of `(name, shape)` pairs.
        """
        assert self.binded
        return self._data_shapes

    @property
    def label_shapes(self):
        """Get label shapes.

        Returns
        -------
        A list of `(name, shape)` pairs. The return value could be `None` if
        the module does not need labels, or if the module is not bound for
        training (in this case, label information is not available).
        """
        assert self.binded
        return self._label_shapes

    @property
    def output_shapes(self):
        """Get output shapes.

        Returns
        -------
        A list of `(name, shape)` pairs.
        """
        assert self.binded
        return self._exec_group.get_output_shapes()

    def get_params(self):
        """Get current parameters.

        Returns
        -------
        `(arg_params, aux_params)`, each a dictionary of name to `NDArray`
        parameter mappings.
        """
        assert self.binded and self.params_initialized

        if self._params_dirty:
            self._sync_params_from_devices()
        return (self._arg_params, self._aux_params)

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        """Initialize the parameters and auxiliary states.

        Parameters
        ----------
        initializer : Initializer
            Called to initialize parameters if needed.
        arg_params : dict
            If not None, should be a dictionary of existing arg_params. Initialization
            will be copied from that.
        aux_params : dict
            If not None, should be a dictionary of existing aux_params. Initialization
            will be copied from that.
        allow_missing : bool
            If true, params could contain missing values, and the initializer will be
            called to fill those missing params.
        force_init : bool
            If true, will force re-initialization even if already initialized.
        """
        if self.params_initialized and not force_init:
            warnings.warn("Parameters already initialized and force_init=False. "
                          "init_params call ignored.", stacklevel=2)
            return
        assert self.binded, 'call bind before initializing the parameters'

        def _impl(name, arr, cache):
            """Internal helper for parameter initialization."""
            if cache is not None:
                if name in cache:
                    cache_arr = cache[name]
                    # just in case the cached array is just the target itself
                    if cache_arr is not arr:
                        cache_arr.copyto(arr)
                else:
                    if not allow_missing:
                        raise RuntimeError("%s is not presented" % name)
                    if initializer is not None:
                        initializer(name, arr)
            else:
                initializer(name, arr)

        attrs = self._symbol.attr_dict()
        for name, arr in self._arg_params.items():
            desc = InitDesc(name, attrs.get(name, None))
            _impl(desc, arr, arg_params)

        for name, arr in self._aux_params.items():
            desc = InitDesc(name, attrs.get(name, None))
            _impl(desc, arr, aux_params)

        self.params_initialized = True
        self._params_dirty = False

        # copy the initialized parameters to the devices
        self._exec_group.set_params(self._arg_params, self._aux_params)

    def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
        """Assign parameter and aux state values.

        Parameters
        ----------
        arg_params : dict
            Dictionary of name to value (`NDArray`) mapping.
        aux_params : dict
            Dictionary of name to value (`NDArray`) mapping.
        allow_missing : bool
            If true, params could contain missing values, and the initializer will be
            called to fill those missing params.
        force_init : bool
            If true, will force re-initialization even if already initialized.

        Examples
        --------
        An example of setting module parameters::
            >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load)
            >>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
        """
        if not allow_missing:
            self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
                             allow_missing=allow_missing, force_init=force_init)
            return

        if self.params_initialized and not force_init:
            warnings.warn("Parameters already initialized and force_init=False. "
                          "set_params call ignored.", stacklevel=2)
            return

        self._exec_group.set_params(arg_params, aux_params)

        # because we did not update self._arg_params, they are dirty now
        self._params_dirty = True
        self.params_initialized = True

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None,
             grad_req='write'):
        """Bind the symbols to construct executors. This is necessary before one
        can perform computation with the module.

        Parameters
        ----------
        data_shapes : list of (str, tuple)
            Typically is `data_iter.provide_data`.
        label_shapes : list of (str, tuple)
            Typically is `data_iter.provide_label`.
        for_training : bool
            Default is `True`. Whether the executors should be bound for training.
        inputs_need_grad : bool
            Default is `False`. Whether the gradients to the input data need to be computed.
            Typically this is not needed. But this might be needed when implementing composition
            of modules.
        force_rebind : bool
            Default is `False`. This function does nothing if the executors are already
            bound. But with this set to `True`, the executors will be forced to rebind.
        shared_module : Module
            Default is `None`. This is used in bucketing. When not `None`, the shared module
            essentially corresponds to a different bucket -- a module with a different symbol
            but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
        grad_req : str, list of str, dict of str to str
            Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
            (default is 'write').
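
        Examples
        --------
        A minimal sketch of binding for inference. In this variant of `Module`,
        `data_shapes` and `label_shapes` hold one entry per device, so the values
        below (a single device and a placeholder image shape) are illustrative only::
            >>> mod.bind(data_shapes=[[('data', (1, 3, 224, 224))]],
            ...          label_shapes=[None], for_training=False)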
        """
        # force rebinding is typically used when switching between the training
        # and prediction phases
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True
        self._grad_req = grad_req

        if not for_training:
            assert not inputs_need_grad

        self._data_shapes, self._label_shapes = zip(
            *[_parse_data_desc(self.data_names, self.label_names, data_shape, label_shape)
              for data_shape, label_shape in zip(data_shapes, label_shapes)])
        if self._label_shapes.count(None) == len(self._label_shapes):
            self._label_shapes = None

        if shared_module is not None:
            assert isinstance(shared_module, Module) and \
                shared_module.binded and shared_module.params_initialized
            shared_group = shared_module._exec_group
        else:
            shared_group = None
        self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
                                                     self._work_load_list, self._data_shapes,
                                                     self._label_shapes, self._param_names,
                                                     for_training, inputs_need_grad,
                                                     shared_group, logger=self.logger,
                                                     fixed_param_names=self._fixed_param_names,
                                                     grad_req=grad_req,
                                                     state_names=self._state_names)

        if shared_module is not None:
            self.params_initialized = True
            self._arg_params = shared_module._arg_params
            self._aux_params = shared_module._aux_params
        elif self.params_initialized:
            # if the parameters are already initialized, we are re-binding, so
            # automatically copy the already initialized params to the executors
            self._exec_group.set_params(self._arg_params, self._aux_params)
        else:
            assert self._arg_params is None and self._aux_params is None
            param_arrays = [
                nd.zeros(x[0].shape, dtype=x[0].dtype)
                for x in self._exec_group.param_arrays
            ]
            self._arg_params = {name: arr for name, arr in zip(self._param_names, param_arrays)}

            aux_arrays = [
                nd.zeros(x[0].shape, dtype=x[0].dtype)
                for x in self._exec_group.aux_arrays
            ]
            self._aux_params = {name: arr for name, arr in zip(self._aux_names, aux_arrays)}

        if shared_module is not None and shared_module.optimizer_initialized:
            self.borrow_optimizer(shared_module)

    def reshape(self, data_shapes, label_shapes=None):
        """Reshape the module for new input shapes.

        Parameters
        ----------
        data_shapes : list of (str, tuple)
            Typically is `data_iter.provide_data`.
        label_shapes : list of (str, tuple)
            Typically is `data_iter.provide_label`.
        """
        assert self.binded

        self._data_shapes, self._label_shapes = zip(
            *[_parse_data_desc(self.data_names, self.label_names, data_shape, label_shape)
              for data_shape, label_shape in zip(data_shapes, label_shapes)])

        self._exec_group.reshape(self._data_shapes, self._label_shapes)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
        """Install and initialize optimizers.

        Parameters
        ----------
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`.
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
            just to avoid a pylint warning about dangerous default values.
        force_init : bool
            Default `False`, indicating whether to force re-initializing the
            optimizer when one is already installed.
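
        Examples
        --------
        An illustrative sketch; the learning rate and momentum values are placeholders::
            >>> mod.init_optimizer(optimizer='sgd',
            ...                    optimizer_params=(('learning_rate', 0.01),
            ...                                      ('momentum', 0.9)))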
        """
        assert self.binded and self.params_initialized

        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring...')
            return

        (kvstore, update_on_kvstore) = \
            _create_kvstore(kvstore, len(self._context), self._arg_params)

        batch_size = self._exec_group.batch_size
        if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
            batch_size *= kvstore.num_workers
        rescale_grad = 1.0 / batch_size

        if isinstance(optimizer, str):
            idx2name = {}
            if update_on_kvstore:
                idx2name.update(enumerate(self._exec_group.param_names))
            else:
                for k in range(len(self._context)):
                    idx2name.update({i*len(self._context)+k: n
                                     for i, n in enumerate(self._exec_group.param_names)})
            optimizer_params = dict(optimizer_params)
            if 'rescale_grad' not in optimizer_params:
                optimizer_params['rescale_grad'] = rescale_grad
            optimizer = opt.create(optimizer,
                                   sym=self.symbol, param_idx2name=idx2name,
                                   **optimizer_params)
        else:
            assert isinstance(optimizer, opt.Optimizer)
            if optimizer.rescale_grad != rescale_grad:
                warnings.warn(
                    "Optimizer created manually outside Module but rescale_grad " +
                    "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). " % (
                        optimizer.rescale_grad, rescale_grad) +
                    "Is this intended?", stacklevel=2)

        self._optimizer = optimizer
        self._kvstore = kvstore
        self._update_on_kvstore = update_on_kvstore
        self._updater = None

        if kvstore:
            # copy the initialized local parameters to the kvstore
            _initialize_kvstore(kvstore=kvstore,
                                param_arrays=self._exec_group.param_arrays,
                                arg_params=self._arg_params,
                                param_names=self._param_names,
                                update_on_kvstore=update_on_kvstore)
        if update_on_kvstore:
            kvstore.set_optimizer(self._optimizer)
        else:
            self._updater = opt.get_updater(optimizer)

        self.optimizer_initialized = True

        if self._preload_opt_states is not None:
            self.load_optimizer_states(self._preload_opt_states)
            self._preload_opt_states = None

    def borrow_optimizer(self, shared_module):
        """Borrow optimizer from a shared module. Used in bucketing, where exactly the same
        optimizer (esp. kvstore) is used.

        Parameters
        ----------
        shared_module : Module
        """
        assert shared_module.optimizer_initialized
        self._optimizer = shared_module._optimizer
        self._kvstore = shared_module._kvstore
        self._update_on_kvstore = shared_module._update_on_kvstore
        self._updater = shared_module._updater
        self.optimizer_initialized = True

    def forward(self, data_batch, is_train=None):
        """Forward computation.

        Parameters
        ----------
        data_batch : DataBatch
            Could be anything with a similar API implemented.
        is_train : bool
            Default is `None`, which means `is_train` takes the value of `self.for_training`.
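
        Examples
        --------
        An illustrative training-step sketch, assuming `data_batch` comes from a data
        iterator and the module is already bound, with parameters and optimizer
        initialized::
            >>> mod.forward(data_batch, is_train=True)
            >>> mod.backward()
            >>> mod.update()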
        """
        assert self.binded and self.params_initialized
        self._exec_group.forward(data_batch, is_train)

    def backward(self, out_grads=None):
        """Backward computation.

        Parameters
        ----------
        out_grads : NDArray or list of NDArray, optional
            Gradient on the outputs to be propagated back.
            This parameter is only needed when bind is called
            on outputs that are not a loss function.
        """
        assert self.binded and self.params_initialized
        self._exec_group.backward(out_grads=out_grads)

    def update(self):
        """Update parameters according to the installed optimizer and the gradients computed
        in the previous forward-backward batch.
        """
        assert self.binded and self.params_initialized and self.optimizer_initialized

        self._params_dirty = True
        if self._update_on_kvstore:
            _update_params_on_kvstore(self._exec_group.param_arrays,
                                      self._exec_group.grad_arrays,
                                      self._kvstore, self._exec_group.param_names)
        else:
            _update_params(self._exec_group.param_arrays,
                           self._exec_group.grad_arrays,
                           updater=self._updater,
                           num_device=len(self._context),
                           kvstore=self._kvstore)

    def get_outputs(self, merge_multi_context=True):
        """Get outputs of the previous forward computation.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. When data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the output of a
            single executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
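
        Examples
        --------
        An illustrative sketch of reading the first output after a forward pass::
            >>> output = mod.get_outputs()[0]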
        """
        assert self.binded and self.params_initialized
        return self._exec_group.get_outputs(merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        """Get the gradients with respect to the inputs of the module.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. When data-parallelism is used, the gradients
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the gradients from a
            single executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
        is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._exec_group.get_input_grads(merge_multi_context=merge_multi_context)

    def get_states(self, merge_multi_context=True):
        """Get states from all devices.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. When data-parallelism is used, the states
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the states of a
            single executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.binded and self.params_initialized
        return self._exec_group.get_states(merge_multi_context=merge_multi_context)

    def set_states(self, states=None, value=None):
        """Set values for states. Only one of `states` and `value` can be specified.

        Parameters
        ----------
        states : list of list of NDArrays
            Source state arrays formatted like `[[state1_dev1, state1_dev2],
            [state2_dev1, state2_dev2]]`.
        value : number
            A single scalar value for all state arrays.
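
        Examples
        --------
        An illustrative sketch that sets every state array to zero::
            >>> mod.set_states(value=0)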
        """
        assert self.binded and self.params_initialized
        self._exec_group.set_states(states, value)

    def update_metric(self, eval_metric, labels):
        """Evaluate and accumulate evaluation metric on outputs of the last forward computation.

        Parameters
        ----------
        eval_metric : EvalMetric
        labels : list of NDArray
            Typically `data_batch.label`.
        """
        self._exec_group.update_metric(eval_metric, labels)

    def _sync_params_from_devices(self):
        """Synchronize parameters from devices to CPU. This function should be called after
        calling `update` that updates the parameters on the devices, before one can read the
        latest parameters from `self._arg_params` and `self._aux_params`.
        """
        self._exec_group.get_params(self._arg_params, self._aux_params)
        self._params_dirty = False

    def save_optimizer_states(self, fname):
        """Save optimizer (updater) state to file.

        Parameters
        ----------
        fname : str
            Path to output states file.
        """
        assert self.optimizer_initialized

        if self._update_on_kvstore:
            self._kvstore.save_optimizer_states(fname)
        else:
            with open(fname, 'wb') as fout:
                fout.write(self._updater.get_states())

    def load_optimizer_states(self, fname):
        """Load optimizer (updater) state from file.

        Parameters
        ----------
        fname : str
            Path to input states file.
        """
        assert self.optimizer_initialized

        if self._update_on_kvstore:
            self._kvstore.load_optimizer_states(fname)
        else:
            with open(fname, 'rb') as fin:
                self._updater.set_states(fin.read())

    def install_monitor(self, mon):
        """Install monitor on all executors."""
        assert self.binded
        self._exec_group.install_monitor(mon)


class MutableModule(BaseModule):
    """A mutable module is a module that supports variable input data shapes.

    Parameters
    ----------
    symbol : Symbol
    data_names : list of str
    label_names : list of str
    logger : Logger
    context : Context or list of Context
    work_load_list : list of number
    max_data_shapes : list of (name, shape) tuple, designating inputs whose shapes may vary
    max_label_shapes : list of (name, shape) tuple, designating inputs whose shapes may vary
    fixed_param_prefix : list of str, indicating fixed parameters
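
    Examples
    --------
    An illustrative construction sketch; the symbol, input names and maximum shapes
    are placeholders, and `max_data_shapes` keeps the per-device nesting used
    elsewhere in this file::
        >>> mod = MutableModule(symbol, data_names=('data',), label_names=('softmax_label',),
        ...                     max_data_shapes=[[('data', (1, 3, 1000, 1000))]])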
    """
    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 max_data_shapes=None, max_label_shapes=None, fixed_param_prefix=None):
        super(MutableModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list

        self._curr_module = None
        self._max_data_shapes = max_data_shapes
        self._max_label_shapes = max_label_shapes
        self._fixed_param_prefix = fixed_param_prefix

        fixed_param_names = list()
        if fixed_param_prefix is not None:
            for name in self._symbol.list_arguments():
                for prefix in self._fixed_param_prefix:
                    if prefix in name:
                        fixed_param_names.append(name)
        self._fixed_param_names = fixed_param_names
        self._preload_opt_states = None

    def _reset_bind(self):
        self.binded = False
        self._curr_module = None

    @property
    def data_names(self):
        return self._data_names

    @property
    def output_names(self):
        return self._symbol.list_outputs()

    @property
    def data_shapes(self):
        assert self.binded
        return self._curr_module.data_shapes

    @property
    def label_shapes(self):
        assert self.binded
        return self._curr_module.label_shapes

    @property
    def output_shapes(self):
        assert self.binded
        return self._curr_module.output_shapes

    def get_params(self):
        assert self.binded and self.params_initialized
        return self._curr_module.get_params()

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        if self.params_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'
        self._curr_module.init_params(initializer=initializer, arg_params=arg_params,
                                      aux_params=aux_params, allow_missing=allow_missing,
                                      force_init=force_init)
        self.params_initialized = True

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
        # if rebinding is requested, keep the current parameters so they can be
        # restored on the freshly bound module at the end
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        max_shapes_dict = dict()
        if self._max_data_shapes is not None:
            max_shapes_dict.update(dict(self._max_data_shapes[0]))
        if self._max_label_shapes is not None:
            max_shapes_dict.update(dict(self._max_label_shapes[0]))

        max_data_shapes = list()
        for name, shape in data_shapes[0]:
            if name in max_shapes_dict:
                max_data_shapes.append((name, max_shapes_dict[name]))
            else:
                max_data_shapes.append((name, shape))

        max_label_shapes = list()
        if label_shapes.count(None) != len(label_shapes):
            for name, shape in label_shapes[0]:
                if name in max_shapes_dict:
                    max_label_shapes.append((name, max_shapes_dict[name]))
                else:
                    max_label_shapes.append((name, shape))

        if len(max_label_shapes) == 0:
            max_label_shapes = None

        module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind([max_data_shapes for _ in range(len(self._context))],
                    [max_label_shapes for _ in range(len(self._context))],
                    for_training, inputs_need_grad, force_rebind=False, shared_module=None)
        self._curr_module = module

        # copy back the saved parameters after re-binding
        if self.params_initialized:
            self.set_params(arg_params, aux_params)

    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
        """Save current progress to a checkpoint.
        Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training.

        Parameters
        ----------
        prefix : str
            The file prefix to checkpoint to.
        epoch : int
            The current epoch number.
        save_optimizer_states : bool
            Whether to save optimizer states to continue training.
        """
        self._curr_module.save_checkpoint(prefix, epoch, save_optimizer_states)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        self._curr_module._preload_opt_states = self._preload_opt_states
        self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
                                         force_init=force_init)
        self.optimizer_initialized = True

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, prefix=None, state=None):
        """Train the module parameters.

        Parameters
        ----------
        train_data : DataIter
        eval_data : DataIter
            If not `None`, it will be used as the validation set and the performance
            will be evaluated after each epoch.
        eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure used to display during training.
        epoch_end_callback : function or list of function
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`.
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
            The default value is not a `dict`, just to avoid a pylint warning about dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each minibatch during evaluation.
        initializer : Initializer
            Will be called to initialize the module parameters if they are not already initialized.
        arg_params : dict
            Default `None`. If not `None`, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has higher priority than `initializer`.
        aux_params : dict
            Default `None`. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Default `False`. Indicates whether we allow missing parameters when `arg_params`
            and `aux_params` are not `None`. If this is `True`, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Default `False`. Whether to force rebinding the executors if already bound.
        force_init : bool
            Default `False`. Indicates whether we should force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Default `0`. Indicates the starting epoch. Usually, if we are resuming from a
            checkpoint saved at a previous training phase at epoch N, then we should specify
            this value as N+1.
        num_epoch : int
            Number of epochs to run training.

        Examples
        --------
        An example of using fit for training::
            >>> # Assume training dataIter and validation dataIter are ready
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
            ...         optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            ...         num_epoch=10)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)
        if state is not None:
            self._curr_module.load_optimizer_states(state)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync parameters back from the devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # evaluation on the validation set
            if eval_data:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # end of one epoch, reset the data iterator for the next epoch
            train_data.reset()

    def forward(self, data_batch, is_train=None):
        assert self.binded and self.params_initialized

        # shapes the current module is bound with
        if self._curr_module.label_shapes is not None:
            current_shapes = [dict(self._curr_module.data_shapes[i] + self._curr_module.label_shapes[i])
                              for i in range(len(self._context))]
        else:
            current_shapes = [dict(self._curr_module.data_shapes[i]) for i in range(len(self._context))]

        # shapes provided by the incoming batch
        if is_train:
            input_shapes = [dict(data_batch.provide_data[i] + data_batch.provide_label[i])
                            for i in range(len(self._context))]
        else:
            input_shapes = [dict(data_batch.provide_data[i]) for i in range(len(data_batch.provide_data))]

        # decide whether any shape changed; if so, rebind the module
        shape_changed = len(current_shapes) != len(input_shapes)
        for pre, cur in zip(current_shapes, input_shapes):
            for k, v in pre.items():
                if v != cur[k]:
                    shape_changed = True

        if shape_changed:
            # the current module is passed as shared_module so the rebound
            # executors reuse its parameter arrays and memory
            module = Module(self._symbol, self._data_names, self._label_names,
                            logger=self.logger,
                            context=[self._context[i] for i in range(len(data_batch.provide_data))],
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names)
            module.bind(data_batch.provide_data, data_batch.provide_label,
                        self._curr_module.for_training,
                        self._curr_module.inputs_need_grad, force_rebind=False,
                        shared_module=self._curr_module)
            self._curr_module = module

        self._curr_module.forward(data_batch, is_train=is_train)

    def backward(self, out_grads=None):
        assert self.binded and self.params_initialized
        self._curr_module.backward(out_grads=out_grads)

    def update(self):
        assert self.binded and self.params_initialized and self.optimizer_initialized
        self._curr_module.update()

    def get_outputs(self, merge_multi_context=True):
        assert self.binded and self.params_initialized
        return self._curr_module.get_outputs(merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)

    def update_metric(self, eval_metric, labels):
        assert self.binded and self.params_initialized
        self._curr_module.update_metric(eval_metric, labels)

    def install_monitor(self, mon):
        """Install monitor on all executors."""
        assert self.binded
        self._curr_module.install_monitor(mon)