module.py
Go to the documentation of this file.
00001 # --------------------------------------------------------
00002 # Deformable Convolutional Networks
00003 # Copyright (c) 2016 by Contributors
00004 # Copyright (c) 2017 Microsoft
00005 # Licensed under The Apache-2.0 License [see LICENSE for details]
00006 # Modified by Zheng Zhang
00007 # --------------------------------------------------------
00008 
"""Module implementations supporting the `BaseModule` API.

`Module` wraps a single `Symbol`. `MutableModule` implements the `BaseModule`
API and allows input shapes to vary between training iterations: if shapes
vary, executors are rebound, using arrays shared with the initial module that
was bound with the maximum shapes.
"""
00013 
00014 import time
00015 import logging
00016 import warnings
00017 
00018 from mxnet import context as ctx
00019 from mxnet.initializer import Uniform, InitDesc
00020 from mxnet.module.base_module import BaseModule, _check_input_names, _parse_data_desc, _as_list
00021 from mxnet.model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore, load_checkpoint, BatchEndParam
00022 from mxnet import metric
00023 # from mxnet.module.executor_group import DataParallelExecutorGroup
00024 
00025 from .DataParallelExecutorGroup import DataParallelExecutorGroup
00026 from mxnet import ndarray as nd
00027 from mxnet import optimizer as opt
00028 
00029 
class Module(BaseModule):
    """Module is a basic module that wraps a `Symbol`. It is functionally the same
    as the `FeedForward` model, except under the module API.

    This variant binds through the local `DataParallelExecutorGroup`, and its
    `bind`/`reshape` take `data_shapes`/`label_shapes` as a list of per-entry
    shape lists (presumably one entry per context -- confirm against
    `DataParallelExecutorGroup`).

    Parameters
    ----------
    symbol : Symbol
    data_names : list of str
        Default is `('data',)` for a typical model used in image classification.
    label_names : list of str
        Default is `('softmax_label',)` for a typical model used in image
        classification.
    logger : Logger
        Default is `logging`.
    context : Context or list of Context
        Default is `cpu()`.
    work_load_list : list of number
        Default `None`, indicating uniform workload.
    fixed_param_names: list of str
        Default `None`, indicating no network parameters are fixed.
    state_names : list of str
        states are similar to data and label, but not provided by data iterator.
        Instead they are initialized to 0 and can be set by set_states()
    """
    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 fixed_param_names=None, state_names=None):
        super(Module, self).__init__(logger=logger)

        # A single Context is accepted as shorthand for a one-element list.
        if isinstance(context, ctx.Context):
            context = [context]
        self._context = context
        if work_load_list is None:
            work_load_list = [1] * len(self._context)
        assert len(work_load_list) == len(self._context)
        self._work_load_list = work_load_list

        self._symbol = symbol

        data_names = list(data_names) if data_names is not None else []
        label_names = list(label_names) if label_names is not None else []
        state_names = list(state_names) if state_names is not None else []
        fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []

        _check_input_names(symbol, data_names, "data", True)
        _check_input_names(symbol, label_names, "label", False)
        _check_input_names(symbol, state_names, "state", True)
        _check_input_names(symbol, fixed_param_names, "fixed_param", True)

        # Every symbol argument that is not fed as data/label/state is a parameter.
        arg_names = symbol.list_arguments()
        input_names = data_names + label_names + state_names
        self._param_names = [x for x in arg_names if x not in input_names]
        self._fixed_param_names = fixed_param_names
        self._aux_names = symbol.list_auxiliary_states()
        self._data_names = data_names
        self._label_names = label_names
        self._state_names = state_names
        self._output_names = symbol.list_outputs()

        self._arg_params = None
        self._aux_params = None
        self._params_dirty = False

        self._optimizer = None
        self._kvstore = None
        self._update_on_kvstore = None
        self._updater = None
        self._preload_opt_states = None
        self._grad_req = None

        self._exec_group = None
        self._data_shapes = None
        self._label_shapes = None

    @staticmethod
    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
        """Create a model from previously saved checkpoint.

        Parameters
        ----------
        prefix : str
            path prefix of saved model files. You should have
            "prefix-symbol.json", "prefix-xxxx.params", and
            optionally "prefix-xxxx.states", where xxxx is the
            epoch number.
        epoch : int
            epoch to load.
        load_optimizer_states : bool
            whether to load optimizer states. Checkpoint needs
            to have been made with save_optimizer_states=True.
        data_names : list of str
            Default is `('data',)` for a typical model used in image classification.
        label_names : list of str
            Default is `('softmax_label',)` for a typical model used in image
            classification.
        logger : Logger
            Default is `logging`.
        context : Context or list of Context
            Default is `cpu()`.
        work_load_list : list of number
            Default `None`, indicating uniform workload.
        fixed_param_names: list of str
            Default `None`, indicating no network parameters are fixed.

        Returns
        -------
        Module
            A module whose parameters are marked initialized from the checkpoint.
        """
        sym, args, auxs = load_checkpoint(prefix, epoch)
        mod = Module(symbol=sym, **kwargs)
        mod._arg_params = args
        mod._aux_params = auxs
        mod.params_initialized = True
        if load_optimizer_states:
            # Deferred: the file is read by init_optimizer() once an optimizer exists.
            mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
        return mod

    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
        """Save current progress to checkpoint.
        Use mx.callback.module_checkpoint as epoch_end_callback to save during training.

        Writes `<prefix>-symbol.json` and `<prefix>-<epoch:04d>.params`, plus
        `<prefix>-<epoch:04d>.states` when `save_optimizer_states` is True.

        Parameters
        ----------
        prefix : str
            The file prefix to checkpoint to
        epoch : int
            The current epoch number
        save_optimizer_states : bool
            Whether to save optimizer states for continue training
        """
        self._symbol.save('%s-symbol.json'%prefix)
        param_name = '%s-%04d.params' % (prefix, epoch)
        self.save_params(param_name)
        # Use the module's configured logger, consistent with the other methods.
        self.logger.info('Saved checkpoint to \"%s\"', param_name)
        if save_optimizer_states:
            state_name = '%s-%04d.states' % (prefix, epoch)
            self.save_optimizer_states(state_name)
            self.logger.info('Saved optimizer state to \"%s\"', state_name)

    def _reset_bind(self):
        """Internal function to reset binded state."""
        self.binded = False
        self._exec_group = None
        self._data_shapes = None
        self._label_shapes = None

    @property
    def data_names(self):
        """A list of names for data required by this module."""
        return self._data_names

    @property
    def label_names(self):
        """A list of names for labels required by this module."""
        return self._label_names

    @property
    def output_names(self):
        """A list of names for the outputs of this module."""
        return self._output_names

    @property
    def data_shapes(self):
        """Get data shapes.

        Returns
        -------
        A tuple with one parsed `(name, shape)` list per entry passed to `bind`.
        """
        assert self.binded
        return self._data_shapes

    @property
    def label_shapes(self):
        """Get label shapes.

        Returns
        -------
        A tuple with one parsed `(name, shape)` list per entry passed to `bind`.
        The return value is `None` if no entry carried label information.
        """
        assert self.binded
        return self._label_shapes

    @property
    def output_shapes(self):
        """Get output shapes.

        Returns
        -------
        A list of `(name, shape)` pairs.
        """
        assert self.binded
        return self._exec_group.get_output_shapes()

    def get_params(self):
        """Get current parameters.

        Syncs from the devices first if an update left the host copies stale.

        Returns
        -------
        `(arg_params, aux_params)`, each a dictionary of name to parameters (in
        `NDArray`) mapping.
        """
        assert self.binded and self.params_initialized

        if self._params_dirty:
            self._sync_params_from_devices()
        return (self._arg_params, self._aux_params)

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        """Initialize the parameters and auxiliary states.

        Parameters
        ----------
        initializer : Initializer
            Called to initialize parameters if needed. When `arg_params` (or
            `aux_params`) is None, every entry of that group is passed to the
            initializer, so it must not be None in that case.
        arg_params : dict
            If not None, should be a dictionary of existing arg_params. Initialization
            will be copied from that.
        aux_params : dict
            If not None, should be a dictionary of existing aux_params. Initialization
            will be copied from that.
        allow_missing : bool
            If true, params could contain missing values, and the initializer will be
            called to fill those missing params.
        force_init : bool
            If true, will force re-initialize even if already initialized.
        allow_extra : bool
            Accepted but not used by this implementation.

        Raises
        ------
        RuntimeError
            If a name is missing from a supplied dict and `allow_missing` is False.
        """
        if self.params_initialized and not force_init:
            warnings.warn("Parameters already initialized and force_init=False. "
                          "init_params call ignored.", stacklevel=2)
            return
        assert self.binded, 'call bind before initializing the parameters'

        def _impl(name, arr, cache):
            """Fill `arr` from `cache[name]` if present, else via `initializer`."""
            if cache is not None:
                if name in cache:
                    cache_arr = cache[name]

                    # just in case the cached array is just the target itself
                    if cache_arr is not arr:
                        cache_arr.copyto(arr)
                else:
                    if not allow_missing:
                        raise RuntimeError("%s is not presented" % name)
                    if initializer is not None:
                        initializer(name, arr)
            else:
                initializer(name, arr)

        attrs = self._symbol.attr_dict()
        for name, arr in self._arg_params.items():
            desc = InitDesc(name, attrs.get(name, None))
            _impl(desc, arr, arg_params)

        for name, arr in self._aux_params.items():
            desc = InitDesc(name, attrs.get(name, None))
            _impl(desc, arr, aux_params)

        self.params_initialized = True
        self._params_dirty = False

        # copy the initialized parameters to devices
        self._exec_group.set_params(self._arg_params, self._aux_params)

    def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
        """Assign parameter and aux state values.

        Parameters
        ----------
        arg_params : dict
            Dictionary of name to value (`NDArray`) mapping.
        aux_params : dict
            Dictionary of name to value (`NDArray`) mapping.
        allow_missing : bool
            If true, params could contain missing values, and the initializer will be
            called to fill those missing params.
        force_init : bool
            If true, will force re-initialize even if already initialized.

        Examples
        --------
        An example of setting module parameters::
            >>> sym, arg_params, aux_params = \
            >>>     mx.model.load_checkpoint(model_prefix, n_epoch_load)
            >>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
        """
        if not allow_missing:
            self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
                             allow_missing=allow_missing, force_init=force_init)
            return

        if self.params_initialized and not force_init:
            warnings.warn("Parameters already initialized and force_init=False. "
                          "set_params call ignored.", stacklevel=2)
            return

        self._exec_group.set_params(arg_params, aux_params)

        # because we didn't update self._arg_params, they are dirty now.
        self._params_dirty = True
        self.params_initialized = True

    def _parse_shape_lists(self, data_shapes, label_shapes):
        """Parse parallel per-entry data/label shape lists for bind/reshape.

        Each `(data_shapes[i], label_shapes[i])` pair is checked and parsed with
        `_parse_data_desc`. `label_shapes=None` means "no labels for any entry".

        Returns
        -------
        (tuple, tuple or None)
            The parsed data entries, and the parsed label entries. The label
            entries are collapsed to `None` when every entry has no labels.
        """
        if label_shapes is None:
            # zip() cannot iterate None; pair every data entry with "no labels".
            label_shapes = [None] * len(data_shapes)
        parsed = [_parse_data_desc(self.data_names, self.label_names, data_shape, label_shape)
                  for data_shape, label_shape in zip(data_shapes, label_shapes)]
        data_descs, label_descs = zip(*parsed)
        if all(label is None for label in label_descs):
            label_descs = None
        return data_descs, label_descs

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None,
             grad_req='write'):
        """Bind the symbols to construct executors. This is necessary before one
        can perform computation with the module.

        Parameters
        ----------
        data_shapes : list of list of (str, tuple)
            One list of `(name, shape)` descriptors per entry (presumably one per
            context, as consumed by `DataParallelExecutorGroup` -- confirm).
        label_shapes : list of list of (str, tuple), optional
            Parallel to `data_shapes`. `None` (the default) means no labels.
        for_training : bool
            Default is `True`. Whether the executors should be bind for training.
        inputs_need_grad : bool
            Default is `False`. Whether the gradients to the input data need to be computed.
            Typically this is not needed. But this might be needed when implementing composition
            of modules.
        force_rebind : bool
            Default is `False`. This function does nothing if the executors are already
            binded. But with this `True`, the executors will be forced to rebind.
        shared_module : Module
            Default is `None`. This is used in bucketing. When not `None`, the shared module
            essentially corresponds to a different bucket -- a module with different symbol
            but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
        grad_req : str
            Default `'write'`. Stored and forwarded to the executor group.
        """
        # force rebinding is typically used when one want to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True
        self._grad_req = grad_req

        if not for_training:
            assert not inputs_need_grad
        # No label_shapes assertion when training: some modules might not contain
        # a loss function that consumes the labels.

        self._data_shapes, self._label_shapes = self._parse_shape_lists(data_shapes, label_shapes)

        if shared_module is not None:
            assert isinstance(shared_module, Module) and \
                    shared_module.binded and shared_module.params_initialized
            shared_group = shared_module._exec_group
        else:
            shared_group = None
        self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
                                                     self._work_load_list, self._data_shapes,
                                                     self._label_shapes, self._param_names,
                                                     for_training, inputs_need_grad,
                                                     shared_group, logger=self.logger,
                                                     fixed_param_names=self._fixed_param_names,
                                                     grad_req=grad_req,
                                                     state_names=self._state_names)
        if shared_module is not None:
            # Bucketing: reuse the shared module's host-side parameter dicts.
            self.params_initialized = True
            self._arg_params = shared_module._arg_params
            self._aux_params = shared_module._aux_params
        elif self.params_initialized:
            # if the parameters are already initialized, we are re-binding
            # so automatically copy the already initialized params
            self._exec_group.set_params(self._arg_params, self._aux_params)
        else:
            # First bind: allocate zero-filled host arrays shaped like the first
            # copy (x[0]) of each executor-group array (presumably per device).
            assert self._arg_params is None and self._aux_params is None
            param_arrays = [
                nd.zeros(x[0].shape, dtype=x[0].dtype)
                for x in self._exec_group.param_arrays
            ]
            self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}

            aux_arrays = [
                nd.zeros(x[0].shape, dtype=x[0].dtype)
                for x in self._exec_group.aux_arrays
            ]
            self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)}

        if shared_module is not None and shared_module.optimizer_initialized:
            self.borrow_optimizer(shared_module)

    def reshape(self, data_shapes, label_shapes=None):
        """Reshape the module for new input shapes.

        Parameters
        ----------
        data_shapes : list of list of (str, tuple)
            Same per-entry layout as in `bind`.
        label_shapes : list of list of (str, tuple), optional
            Parallel to `data_shapes`. `None` (the default) means no labels; as in
            `bind`, all-`None` label entries are passed on as `None`.
        """
        assert self.binded
        self._data_shapes, self._label_shapes = self._parse_shape_lists(data_shapes, label_shapes)

        self._exec_group.reshape(self._data_shapes, self._label_shapes)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
        """Install and initialize optimizers.

        Parameters
        ----------
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
            just to avoid pylint warning of dangerous default values.
        force_init : bool
            Default `False`, indicating whether we should force re-initializing the
            optimizer in the case an optimizer is already installed.
        """
        assert self.binded and self.params_initialized

        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring...')
            return

        (kvstore, update_on_kvstore) = \
                _create_kvstore(kvstore, len(self._context), self._arg_params)

        # Gradients are normalized by the global batch size; synchronous
        # distributed stores multiply it by the number of workers.
        batch_size = self._exec_group.batch_size
        if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
            batch_size *= kvstore.num_workers
        rescale_grad = 1.0/batch_size

        if isinstance(optimizer, str):
            # Map updater indices to parameter names. Without kvstore updates, the
            # index of parameter i on context k is i*num_contexts + k.
            idx2name = {}
            if update_on_kvstore:
                idx2name.update(enumerate(self._exec_group.param_names))
            else:
                for k in range(len(self._context)):
                    idx2name.update({i*len(self._context)+k: n
                                     for i, n in enumerate(self._exec_group.param_names)})
            optimizer_params = dict(optimizer_params)
            if 'rescale_grad' not in optimizer_params:
                optimizer_params['rescale_grad'] = rescale_grad
            optimizer = opt.create(optimizer,
                                   sym=self.symbol, param_idx2name=idx2name,
                                   **optimizer_params)
        else:
            assert isinstance(optimizer, opt.Optimizer)
            if optimizer.rescale_grad != rescale_grad:
                #pylint: disable=no-member
                warnings.warn(
                    "Optimizer created manually outside Module but rescale_grad " +
                    "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
                        optimizer.rescale_grad, rescale_grad) +
                    "Is this intended?", stacklevel=2)

        self._optimizer = optimizer
        self._kvstore = kvstore
        self._update_on_kvstore = update_on_kvstore
        self._updater = None

        if kvstore:
            # copy initialized local parameters to kvstore
            _initialize_kvstore(kvstore=kvstore,
                                param_arrays=self._exec_group.param_arrays,
                                arg_params=self._arg_params,
                                param_names=self._param_names,
                                update_on_kvstore=update_on_kvstore)
        if update_on_kvstore:
            kvstore.set_optimizer(self._optimizer)
        else:
            self._updater = opt.get_updater(optimizer)

        self.optimizer_initialized = True

        # Apply optimizer states requested by Module.load(load_optimizer_states=True).
        if self._preload_opt_states is not None:
            self.load_optimizer_states(self._preload_opt_states)
            self._preload_opt_states = None

    def borrow_optimizer(self, shared_module):
        """Borrow optimizer from a shared module. Used in bucketing, where exactly the same
        optimizer (esp. kvstore) is used.

        Parameters
        ----------
        shared_module : Module
            A module whose optimizer has already been initialized.
        """
        assert shared_module.optimizer_initialized
        self._optimizer = shared_module._optimizer
        self._kvstore = shared_module._kvstore
        self._update_on_kvstore = shared_module._update_on_kvstore
        self._updater = shared_module._updater
        self.optimizer_initialized = True

    def forward(self, data_batch, is_train=None):
        """Forward computation.

        Parameters
        ----------
        data_batch : DataBatch
            Could be anything with similar API implemented.
        is_train : bool
            Default is `None`, which means `is_train` takes the value of `self.for_training`.
        """
        assert self.binded and self.params_initialized
        self._exec_group.forward(data_batch, is_train)

    def backward(self, out_grads=None):
        """Backward computation.

        Parameters
        ----------
        out_grads : NDArray or list of NDArray, optional
            Gradient on the outputs to be propagated back.
            This parameter is only needed when bind is called
            on outputs that are not a loss function.
        """
        assert self.binded and self.params_initialized
        self._exec_group.backward(out_grads=out_grads)

    def update(self):
        """Update parameters according to the installed optimizer and the gradients computed
        in the previous forward-backward batch.

        Marks the host-side parameter copies as dirty; `get_params` re-syncs them.
        """
        assert self.binded and self.params_initialized and self.optimizer_initialized

        self._params_dirty = True
        if self._update_on_kvstore:
            _update_params_on_kvstore(self._exec_group.param_arrays,
                                      self._exec_group.grad_arrays,
                                      self._kvstore, self._exec_group.param_names)
        else:
            _update_params(self._exec_group.param_arrays,
                           self._exec_group.grad_arrays,
                           updater=self._updater,
                           num_device=len(self._context),
                           kvstore=self._kvstore)

    def get_outputs(self, merge_multi_context=True):
        """Get outputs of the previous forward computation.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicate that we
            should merge the collected results so that they look like from a single
            executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.binded and self.params_initialized
        return self._exec_group.get_outputs(merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        """Get the gradients with respect to the inputs of the module.

        Requires the module to have been bound with `inputs_need_grad=True`.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicate that we
            should merge the collected results so that they look like from a single
            executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
        is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._exec_group.get_input_grads(merge_multi_context=merge_multi_context)

    def get_states(self, merge_multi_context=True):
        """Get states from all devices

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the states
            will be collected from multiple devices. A `True` value indicate that we
            should merge the collected results so that they look like from a single
            executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.binded and self.params_initialized
        return self._exec_group.get_states(merge_multi_context=merge_multi_context)

    def set_states(self, states=None, value=None):
        """Set value for states. Only one of states & value can be specified.

        Parameters
        ----------
        states : list of list of NDArrays
            source states arrays formatted like [[state1_dev1, state1_dev2],
            [state2_dev1, state2_dev2]].
        value : number
            a single scalar value for all state arrays.
        """
        assert self.binded and self.params_initialized
        self._exec_group.set_states(states, value)

    def update_metric(self, eval_metric, labels):
        """Evaluate and accumulate evaluation metric on outputs of the last forward computation.

        Parameters
        ----------
        eval_metric : EvalMetric
        labels : list of NDArray
            Typically `data_batch.label`.
        """
        self._exec_group.update_metric(eval_metric, labels)

    def _sync_params_from_devices(self):
        """Synchronize parameters from devices to CPU. This function should be called after
        calling `update` that updates the parameters on the devices, before one can read the
        latest parameters from `self._arg_params` and `self._aux_params`.
        """
        self._exec_group.get_params(self._arg_params, self._aux_params)
        self._params_dirty = False

    def save_optimizer_states(self, fname):
        """Save optimizer (updater) state to file

        Parameters
        ----------
        fname : str
            Path to output states file.
        """
        assert self.optimizer_initialized

        if self._update_on_kvstore:
            self._kvstore.save_optimizer_states(fname)
        else:
            with open(fname, 'wb') as fout:
                fout.write(self._updater.get_states())

    def load_optimizer_states(self, fname):
        """Load optimizer (updater) state from file

        Parameters
        ----------
        fname : str
            Path to input states file.
        """
        assert self.optimizer_initialized

        if self._update_on_kvstore:
            self._kvstore.load_optimizer_states(fname)
        else:
            # Close the file deterministically instead of leaking the handle.
            with open(fname, 'rb') as fin:
                self._updater.set_states(fin.read())

    def install_monitor(self, mon):
        """ Install monitor on all executors """
        assert self.binded
        self._exec_group.install_monitor(mon)
00708 
00709 
00710 
00711 
class MutableModule(BaseModule):
    """A mutable module is a module that supports variable input data.

    `bind` binds an inner `Module` using the maximum input shapes. Whenever
    `forward` receives a batch whose shapes (or number of devices) differ from
    the inner module's, a new inner `Module` is bound to the batch's shapes.
    The current inner module is passed to it as `shared_module`.

    Parameters
    ----------
    symbol : Symbol
    data_names : list of str
    label_names : list of str
    logger : Logger
    context : Context or list of Context
        A single `Context` (including the default) is wrapped into a
        one-element list.
    work_load_list : list of number
    max_data_shapes : list
        Only the first element is read. It must be a list of (name, shape)
        tuples naming the data inputs whose shape varies, with their maximum
        shapes.
    max_label_shapes : list
        Same as `max_data_shapes`, for label inputs.
    fixed_param_prefix : list of str
        Every argument of `symbol` whose name *contains* one of these strings
        (substring match, not only a prefix) is treated as a fixed parameter.
    """
    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 max_data_shapes=None, max_label_shapes=None, fixed_param_prefix=None):
        super(MutableModule, self).__init__(logger=logger)
        # bind() and forward() call len() on and slice the context, so a single
        # Context (such as the default ctx.cpu()) is normalized to a list.
        if isinstance(context, ctx.Context):
            context = [context]
        self._symbol = symbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list

        self._curr_module = None
        self._max_data_shapes = max_data_shapes
        self._max_label_shapes = max_label_shapes
        self._fixed_param_prefix = fixed_param_prefix

        if fixed_param_prefix is None:
            self._fixed_param_names = []
        else:
            # any() keeps each argument once, even if it matches several prefixes.
            self._fixed_param_names = [
                name for name in self._symbol.list_arguments()
                if any(prefix in name for prefix in self._fixed_param_prefix)
            ]
        self._preload_opt_states = None

    def _reset_bind(self):
        """Forget the inner module and mark this module as unbound."""
        self.binded = False
        self._curr_module = None

    @property
    def data_names(self):
        """Names of the data inputs given at construction."""
        return self._data_names

    @property
    def output_names(self):
        """Output names of the wrapped symbol."""
        return self._symbol.list_outputs()

    @property
    def data_shapes(self):
        """Data shapes of the current inner module (requires `bind`)."""
        assert self.binded
        return self._curr_module.data_shapes

    @property
    def label_shapes(self):
        """Label shapes of the current inner module (requires `bind`)."""
        assert self.binded
        return self._curr_module.label_shapes

    @property
    def output_shapes(self):
        """Output shapes of the current inner module (requires `bind`)."""
        assert self.binded
        return self._curr_module.output_shapes

    def get_params(self):
        """Return `(arg_params, aux_params)` from the current inner module."""
        assert self.binded and self.params_initialized
        return self._curr_module.get_params()

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        """Initialize the inner module's parameters.

        This is a no-op when parameters are already initialized and
        `force_init` is False. Otherwise it delegates to the inner module's
        `init_params`.

        NOTE(review): `allow_extra` is accepted for `BaseModule` API
        compatibility but is not forwarded to the inner module. Confirm that
        this file's `Module.init_params` supports it before passing it on.
        """
        if self.params_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'
        self._curr_module.init_params(initializer=initializer, arg_params=arg_params,
                                      aux_params=aux_params, allow_missing=allow_missing,
                                      force_init=force_init)
        self.params_initialized = True

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
        """Bind an inner `Module` using the maximum input shapes.

        Parameters
        ----------
        data_shapes : list
            Per-device list. Only `data_shapes[0]` is read, as (name, shape)
            pairs. Every name listed in `max_data_shapes`/`max_label_shapes`
            is replaced by its maximum shape.
        label_shapes : list or None
            Per-device list, or None. When it is None, or every entry is None,
            the inner module is bound without label shapes.
        for_training, inputs_need_grad : bool
            Forwarded to the inner module's `bind`.
        force_rebind : bool
            Drop the current binding first (e.g. to switch to prediction).
        shared_module : None
            Must be None; sharing is not supported for `MutableModule`.
        grad_req : str
            NOTE(review): accepted for API compatibility but not forwarded to
            the inner module.

        If parameters were already initialized, they are read before
        rebinding and copied into the new inner module.
        """
        # in case we already initialized params, keep them
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad

        max_shapes_dict = dict()
        if self._max_data_shapes is not None:
            max_shapes_dict.update(dict(self._max_data_shapes[0]))
        if self._max_label_shapes is not None:
            max_shapes_dict.update(dict(self._max_label_shapes[0]))

        max_data_shapes = [(name, max_shapes_dict.get(name, shape))
                           for name, shape in data_shapes[0]]

        max_label_shapes = None
        # The default label_shapes=None previously crashed on `None.count`.
        if label_shapes is not None and label_shapes.count(None) != len(label_shapes):
            max_label_shapes = [(name, max_shapes_dict.get(name, shape))
                                for name, shape in label_shapes[0]]
            if len(max_label_shapes) == 0:
                max_label_shapes = None

        num_devices = len(self._context)
        module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind([max_data_shapes] * num_devices, [max_label_shapes] * num_devices,
                    for_training, inputs_need_grad, force_rebind=False, shared_module=None)
        self._curr_module = module
        # Mark as bound only once the inner module exists, so a failed bind does
        # not leave binded=True with no inner module. set_params below needs it.
        self.binded = True

        # copy back saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)

    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
        """Save current progress to checkpoint.
        Use mx.callback.module_checkpoint as epoch_end_callback to save during training.

        Parameters
        ----------
        prefix : str
            The file prefix to checkpoint to
        epoch : int
            The current epoch number
        save_optimizer_states : bool
            Whether to save optimizer states for continue training
        """
        self._curr_module.save_checkpoint(prefix, epoch, save_optimizer_states)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
        """Initialize the inner module's optimizer.

        This logs a warning and returns when the optimizer is already
        initialized and `force_init` is False. `self._preload_opt_states` is
        handed to the inner module before its `init_optimizer` is called.
        """
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        self._curr_module._preload_opt_states = self._preload_opt_states
        self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
                                         force_init=force_init)
        self.optimizer_initialized = True

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, prefix=None, state=None):
        """Train the module parameters.

        Parameters
        ----------
        train_data : DataIter
        eval_data : DataIter
            If not `None`, will be used as validation set and evaluate the performance
            after each epoch.
        eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure used to display during training.
        epoch_end_callback : function or list of function
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
            The default value is not a `dict`, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each minibatch during evaluation
        initializer : Initializer
            Will be called to initialize the module parameters if not already initialized.
        arg_params : dict
            Default `None`, if not `None`, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has higher priority to `initializer`.
        aux_params : dict
            Default `None`. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Default `False`. Indicate whether we allow missing parameters when `arg_params`
            and `aux_params` are not `None`. If this is `True`, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Default `False`. Whether to force rebinding the executors if already binded.
        force_init : bool
            Default `False`. Indicate whether we should force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
            checkpoint saved at a previous training phase at epoch N, then we should specify
            this value as N+1.
        num_epoch : int
            Number of epochs to run training. Required.
        validation_metric : str or EvalMetric
            Metric used on `eval_data`; defaults to `eval_metric`.
        monitor : Monitor
            If not `None`, installed on the executors, with `tic()` / `toc_print()`
            called around every training batch.
        prefix : str
            Unused by this implementation.
        state : str
            If not `None`, a file passed to the inner module's
            `load_optimizer_states` after the optimizer is initialized.

        Examples
        --------
        An example of using fit for training::
            >>> #Assume training dataIter and validation dataIter are ready
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
                        optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
                        num_epoch=10)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)
        if state is not None:
            self._curr_module.load_optimizer_states(state)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set (explicit None check, as documented)
            if eval_data is not None:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()

    @staticmethod
    def _shape_dicts(data_shapes, label_shapes):
        """Merge per-device data and label (name, shape) lists into one dict per device.

        `label_shapes` may be None (data only), or hold None for a device, in
        which case that device contributes only its data shapes.
        """
        if label_shapes is None:
            return [dict(shapes) for shapes in data_shapes]
        return [dict(list(data) + list(label or []))
                for data, label in zip(data_shapes, label_shapes)]

    def forward(self, data_batch, is_train=None):
        """Run a forward pass, rebinding the inner module first if input shapes changed.

        Shapes are compared per device. Labels take part only when `is_train`
        is truthy. A shape change, or a different number of devices, triggers a
        new inner `Module`. That module is bound to the batch's shapes on the
        first `len(data_batch.provide_data)` contexts, sharing memory with the
        current inner module.
        """
        assert self.binded and self.params_initialized

        curr = self._curr_module
        # Iterate the inner module's own per-device lists: after a rebind onto
        # fewer devices they are shorter than self._context.
        current_shapes = self._shape_dicts(curr.data_shapes, curr.label_shapes)

        if is_train:
            input_shapes = self._shape_dicts(data_batch.provide_data, data_batch.provide_label)
        else:
            input_shapes = self._shape_dicts(data_batch.provide_data, None)

        # Names the batch does not report (e.g. labels when not training) are
        # skipped instead of raising KeyError.
        shape_changed = len(current_shapes) != len(input_shapes) or any(
            k in cur and v != cur[k]
            for pre, cur in zip(current_shapes, input_shapes)
            for k, v in pre.items())

        if shape_changed:
            num_devices = len(data_batch.provide_data)
            work_load_list = self._work_load_list
            if work_load_list is not None:
                # keep the work load aligned with the (possibly reduced) contexts
                work_load_list = work_load_list[:num_devices]
            module = Module(self._symbol, self._data_names, self._label_names,
                            logger=self.logger, context=self._context[:num_devices],
                            work_load_list=work_load_list,
                            fixed_param_names=self._fixed_param_names)
            module.bind(data_batch.provide_data, data_batch.provide_label, curr.for_training,
                        curr.inputs_need_grad, force_rebind=False,
                        shared_module=curr)
            self._curr_module = module

        self._curr_module.forward(data_batch, is_train=is_train)

    def backward(self, out_grads=None):
        """Backward pass on the current inner module."""
        assert self.binded and self.params_initialized
        self._curr_module.backward(out_grads=out_grads)

    def update(self):
        """Apply one optimizer step on the current inner module."""
        assert self.binded and self.params_initialized and self.optimizer_initialized
        self._curr_module.update()

    def get_outputs(self, merge_multi_context=True):
        """Outputs of the last forward pass of the current inner module."""
        assert self.binded and self.params_initialized
        return self._curr_module.get_outputs(merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        """Input gradients from the current inner module (requires `inputs_need_grad`)."""
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)

    def update_metric(self, eval_metric, labels):
        """Accumulate `eval_metric` on the current inner module's latest outputs."""
        assert self.binded and self.params_initialized
        self._curr_module.update_metric(eval_metric, labels)

    def install_monitor(self, mon):
        """ Install monitor on all executors """
        assert self.binded
        self._curr_module.install_monitor(mon)


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30