import logging
import numpy as np

from mxnet import context as ctx
from mxnet import ndarray as nd
from mxnet.io import DataDesc
from mxnet.executor_manager import _split_input_slice


def _load_general(data, targets, major_axis):
    """Load a list of arrays into a list of arrays specified by slices."""
    for d_src, d_targets in zip(data, targets):
        if isinstance(d_targets, nd.NDArray):
            d_src.copyto(d_targets)
        elif isinstance(d_src, (list, tuple)):
            # one source array per destination (e.g. one per device)
            for src, dst in zip(d_src, d_targets):
                src.copyto(dst)
        else:
            raise NotImplementedError


def _load_data(batch, targets, major_axis):
    """Load data into sliced arrays."""
    _load_general(batch.data, targets, major_axis)


def _load_label(batch, targets, major_axis):
    """Load label into sliced arrays."""
    _load_general(batch.label, targets, major_axis)


def _merge_multi_context(outputs, major_axis):
    """Merge outputs that live on multiple contexts into one, so that they look
    as if they lived on a single context.
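
    Examples
    --------
    A small sketch of the merge (the two arrays and the batch axis 0 below are
    illustrative assumptions)::

        import mxnet as mx

        a = mx.nd.zeros((2, 3))  # output slice computed on device 1
        b = mx.nd.ones((2, 3))   # output slice computed on device 2
        merged = _merge_multi_context([[a, b]], [0])
        # merged[0] has shape (4, 3): the per-device slices are concatenated
        # along the batch axis (axis 0 here).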
    """
    rets = []
    for tensors, axis in zip(outputs, major_axis):
        if axis >= 0:
            rets.append(nd.concatenate(tensors, axis=axis, always_copy=False))
        else:
            # negative axis means this output has no batch-size dimension;
            # every device holds the same value, so simply take the first
            # copy without concatenating.
            rets.append(tensors[0])
    return rets


class DataParallelExecutorGroup(object):
    """DataParallelExecutorGroup is a group of executors that live on a group of devices.
    This is a helper class used to implement data parallelization. Each mini-batch will
    be split and run on the devices.

    Parameters
    ----------
    symbol : Symbol
        The common symbolic computation graph for all executors.
    contexts : list
        A list of contexts.
    workload : list
        If not `None`, could be a list of numbers that specify the workload to be assigned
        to different contexts. Larger numbers indicate heavier workloads.
    data_shapes : list
        A list with one entry per context; each entry is itself a list of `DataDesc`
        describing the data fed to that context. Note the order is important and should
        be the same as the order in which the `DataIter` provides the data.
    label_shapes : list
        A list with one entry per context; each entry is itself a list of `DataDesc`
        describing the label fed to that context. Note the order is important and should
        be the same as the order in which the `DataIter` provides the label.
    param_names : list
        A list of strings, indicating the names of parameters (e.g. weights, filters, etc.)
        in the computation graph.
    for_training : bool
        Indicates whether the executors should be bound for training. When not doing training,
        the memory for gradients will not be allocated.
    inputs_need_grad : bool
        Indicates whether the gradients for the input data should be computed. This is currently
        not used. It will be useful for implementing composition of modules.
    shared_group : DataParallelExecutorGroup
        Default is `None`. This is used in bucketing. When not `None`, it should be an executor
        group corresponding to a different bucket. In other words, it will correspond to a
        different symbol with the same set of parameters (e.g. unrolled RNNs with different
        lengths). In this case, a lot of memory will be shared.
    logger : Logger
        Default is `logging`.
    fixed_param_names : list of str
        Parameters to be fixed during training. For parameters in this list, no space will be
        allocated for the gradient and no gradient calculation will be performed.
    grad_req : str, list of str, dict of str to str
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (default to 'write').
        Can be specified globally (str) or for each argument (list, dict).
    state_names : list of str
        Names of additional inputs that behave like data but are not provided by the data
        iterator. They are initialized to zero and can be set with `set_states()`.
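
    Examples
    --------
    A minimal construction sketch. The symbol, contexts, and shapes below are
    illustrative assumptions, not requirements of this class; note that
    `data_shapes` and `label_shapes` carry one list of `DataDesc` per context::

        import mxnet as mx
        from mxnet.io import DataDesc

        net = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=10)
        net = mx.sym.SoftmaxOutput(net, name='softmax')

        contexts = [mx.cpu(0), mx.cpu(1)]
        data_shapes = [[DataDesc('data', (32, 100))] for _ in contexts]
        label_shapes = [[DataDesc('softmax_label', (32,))] for _ in contexts]
        param_names = [name for name in net.list_arguments()
                       if name not in ('data', 'softmax_label')]

        group = DataParallelExecutorGroup(net, contexts, [1] * len(contexts),
                                          data_shapes, label_shapes, param_names,
                                          for_training=True, inputs_need_grad=False)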
    """
    def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                 for_training, inputs_need_grad, shared_group=None, logger=logging,
                 fixed_param_names=None, grad_req='write', state_names=None):
        self.param_names = param_names
        self.arg_names = symbol.list_arguments()
        self.aux_names = symbol.list_auxiliary_states()

        self.symbol = symbol
        self.contexts = contexts
        self.workload = workload

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad

        self.logger = logger

        self.fixed_param_names = fixed_param_names
        if self.fixed_param_names is None:
            self.fixed_param_names = []

        self.state_names = state_names
        if self.state_names is None:
            self.state_names = []

        if not for_training:
            grad_req = 'null'

        data_names = [x.name for x in data_shapes[0]]

        # Normalize grad_req into a dict keyed by argument name: parameters follow
        # the requested setting (unless fixed), data inputs only get gradients when
        # inputs_need_grad is set, and everything else defaults to 'null'.
        if isinstance(grad_req, str):
            self.grad_req = {}
            for k in self.arg_names:
                if k in self.param_names:
                    self.grad_req[k] = 'null' if k in self.fixed_param_names else grad_req
                elif k in data_names:
                    self.grad_req[k] = grad_req if self.inputs_need_grad else 'null'
                else:
                    self.grad_req[k] = 'null'
        elif isinstance(grad_req, (list, tuple)):
            assert len(grad_req) == len(self.arg_names)
            self.grad_req = dict(zip(self.arg_names, grad_req))
        elif isinstance(grad_req, dict):
            self.grad_req = {}
            for k in self.arg_names:
                if k in self.param_names:
                    self.grad_req[k] = 'null' if k in self.fixed_param_names else 'write'
                elif k in data_names:
                    self.grad_req[k] = 'write' if self.inputs_need_grad else 'null'
                else:
                    self.grad_req[k] = 'null'
            self.grad_req.update(grad_req)
        else:
            raise ValueError("grad_req must be one of str, list, tuple, or dict.")

        if shared_group is not None:
            self.shared_data_arrays = shared_group.shared_data_arrays
        else:
            self.shared_data_arrays = [{} for _ in contexts]

        self.batch_size = len(data_shapes)
        self.slices = None
        self.execs = []
        self._default_execs = None
        self.data_arrays = None
        self.label_arrays = None
        self.param_arrays = None
        self.state_arrays = None
        self.grad_arrays = None
        self.aux_arrays = None
        self.input_grad_arrays = None

        self.data_shapes = None
        self.label_shapes = None
        self.data_layouts = None
        self.label_layouts = None
        self.output_layouts = [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__'))
                               for name in self.symbol.list_outputs()]
        self.bind_exec(data_shapes, label_shapes, shared_group)

    def decide_slices(self, data_shapes):
        """Decide the slices for each context according to the workload.

        Parameters
        ----------
        data_shapes : list
            list of (name, shape) specifying the shapes for the input data or label.
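
        Examples
        --------
        A sketch of the underlying split (the batch size of 8 and the equal
        workload list are illustrative assumptions)::

            from mxnet.executor_manager import _split_input_slice

            slices = _split_input_slice(8, [1, 1])
            # slices == [slice(0, 4, None), slice(4, 8, None)]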
        """
        assert len(data_shapes) > 0
        major_axis = [DataDesc.get_batch_axis(x.layout) for x in data_shapes]

        for (name, shape), axis in zip(data_shapes, major_axis):
            if axis == -1:
                continue

            batch_size = shape[axis]
            if self.batch_size is not None:
                assert batch_size == self.batch_size, ("all data must have the same batch size: "
                                                       + ("batch_size = %d, but " % self.batch_size)
                                                       + ("%s has shape %s" % (name, shape)))
            else:
                self.batch_size = batch_size
                self.slices = _split_input_slice(self.batch_size, self.workload)

        return major_axis

    def _collect_arrays(self):
        """Collect internal arrays from executors."""
        self.data_arrays = [[e.arg_dict[name] for name, _ in self.data_shapes[0]]
                            for e in self.execs]

        self.state_arrays = [[e.arg_dict[name] for e in self.execs]
                             for name in self.state_names]

        if self.label_shapes is not None:
            self.label_arrays = [[e.arg_dict[name] for name, _ in self.label_shapes[0]]
                                 for e in self.execs]
        else:
            self.label_arrays = None

        self.param_arrays = [[exec_.arg_arrays[i] for exec_ in self.execs]
                             for i, name in enumerate(self.arg_names)
                             if name in self.param_names]
        if self.for_training:
            self.grad_arrays = [[exec_.grad_arrays[i] for exec_ in self.execs]
                                for i, name in enumerate(self.arg_names)
                                if name in self.param_names]
        else:
            self.grad_arrays = None

        data_names = [x.name for x in self.data_shapes[0]]
        if self.inputs_need_grad:
            self.input_grad_arrays = [[exec_.grad_arrays[i] for exec_ in self.execs]
                                      for i, name in enumerate(self.arg_names)
                                      if name in data_names]
        else:
            self.input_grad_arrays = None

        self.aux_arrays = [[exec_.aux_arrays[i] for exec_ in self.execs]
                           for i in range(len(self.aux_names))]

    def bind_exec(self, data_shapes, label_shapes, shared_group=None, reshape=False):
        """Bind executors on their respective devices.

        Parameters
        ----------
        data_shapes : list
            Data shape specifications, one list of `DataDesc` per context.
        label_shapes : list
            Label shape specifications, one list of `DataDesc` per context, or `None`.
        shared_group : DataParallelExecutorGroup
            Executor group to share memory with, or `None`.
        reshape : bool
            Whether to reshape the already bound executors instead of binding new ones.
        """
        assert reshape or not self.execs

        for i in range(len(self.contexts)):
            data_shapes_i = data_shapes[i]
            if label_shapes is not None:
                label_shapes_i = label_shapes[i]
            else:
                label_shapes_i = []

            if reshape:
                self.execs[i] = self._default_execs[i].reshape(
                    allow_up_sizing=True, **dict(data_shapes_i + label_shapes_i))
            else:
                self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
                                                      shared_group))

        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self._collect_arrays()

    def reshape(self, data_shapes, label_shapes):
        """Reshape executors.

        Parameters
        ----------
        data_shapes : list
            New data shape specifications, one list of `DataDesc` per context.
        label_shapes : list
            New label shape specifications, one list of `DataDesc` per context, or `None`.
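
        Examples
        --------
        A sketch of rebinding for a larger per-context batch; the names and
        shapes continue the illustrative setup from the class docstring::

            from mxnet.io import DataDesc

            new_data_shapes = [[DataDesc('data', (64, 100))] for _ in group.contexts]
            new_label_shapes = [[DataDesc('softmax_label', (64,))] for _ in group.contexts]
            group.reshape(new_data_shapes, new_label_shapes)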
        """
        if self._default_execs is None:
            self._default_execs = [i for i in self.execs]
        for i in range(len(self.contexts)):
            self.execs[i] = self._default_execs[i].reshape(
                allow_up_sizing=True,
                **dict(data_shapes[i] + (label_shapes[i] if label_shapes is not None else [])))
        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self._collect_arrays()

    def set_params(self, arg_params, aux_params):
        """Assign, i.e. copy, parameters to all the executors.

        Parameters
        ----------
        arg_params : dict
            A dictionary of name to `NDArray` parameter mapping.
        aux_params : dict
            A dictionary of name to `NDArray` auxiliary variable mapping.
        """
        for exec_ in self.execs:
            exec_.copy_params_from(arg_params, aux_params)

    def get_params(self, arg_params, aux_params):
        """Copy data from each executor to `arg_params` and `aux_params`.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Target parameter arrays.
        aux_params : dict of str to NDArray
            Target aux arrays.

        Notes
        -----
        - This function will update the `NDArray`s in `arg_params` and `aux_params` in place.
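
        Examples
        --------
        A sketch of pulling the device-averaged parameters back into plain
        dictionaries. The zero-initialized dictionaries below are only
        illustrative; any dicts mapping each name to an `NDArray` of the right
        shape will do::

            import mxnet as mx

            arg_params = {name: mx.nd.zeros(blocks[0].shape)
                          for name, blocks in zip(group.param_names, group.param_arrays)}
            aux_params = {name: mx.nd.zeros(blocks[0].shape)
                          for name, blocks in zip(group.aux_names, group.aux_arrays)}
            group.get_params(arg_params, aux_params)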
        """
        for name, block in zip(self.param_names, self.param_arrays):
            weight = sum(w.copyto(ctx.cpu()) for w in block) / len(block)
            weight.astype(arg_params[name].dtype).copyto(arg_params[name])
        for name, block in zip(self.aux_names, self.aux_arrays):
            weight = sum(w.copyto(ctx.cpu()) for w in block) / len(block)
            weight.astype(aux_params[name].dtype).copyto(aux_params[name])

    def forward(self, data_batch, is_train=None):
        """Split `data_batch` according to workload and run forward on each device.

        Parameters
        ----------
        data_batch : DataBatch
            Or could be any object implementing a similar interface.
        is_train : bool
            A hint for the backend, indicating whether we are in the training phase.
            Default is `None`, in which case `self.for_training` is used.
        """
        _load_data(data_batch, self.data_arrays, self.data_layouts)
        if is_train is None:
            is_train = self.for_training

        if self.label_arrays is not None:
            assert not is_train or data_batch.label
            if data_batch.label:
                _load_label(data_batch, self.label_arrays, self.label_layouts)

        for exec_ in self.execs:
            exec_.forward(is_train=is_train)

    def get_outputs(self, merge_multi_context=True):
        """Get outputs of the previous forward computation.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the output of a
            single executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
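
        Examples
        --------
        A sketch of the two return layouts for a group bound on two devices with
        a single output (the names are illustrative)::

            merged = group.get_outputs()
            # merged == [out]: one NDArray per output
            per_device = group.get_outputs(merge_multi_context=False)
            # per_device == [[out_dev1, out_dev2]]: one list of NDArrays per output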
        """
        outputs = [[exec_.outputs[i] for exec_ in self.execs]
                   for i in range(len(self.execs[0].outputs))]
        if merge_multi_context:
            outputs = _merge_multi_context(outputs, self.output_layouts)
        return outputs

    def get_states(self, merge_multi_context=True):
        """Get states from all devices.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the states
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the states of a
            single executor. Currently only `merge_multi_context=False` is supported.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
        is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert not merge_multi_context, \
            "merge_multi_context=True is not supported for get_states yet."
        return self.state_arrays

    def set_states(self, states=None, value=None):
        """Set value for states. Only one of `states` and `value` can be specified.

        Parameters
        ----------
        states : list of list of NDArray
            Source state arrays formatted like `[[state1_dev1, state1_dev2],
            [state2_dev1, state2_dev2]]`.
        value : number
            A single scalar value for all state arrays.
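
        Examples
        --------
        A sketch of resetting every state array to zero (this assumes the group
        was constructed with non-empty `state_names`)::

            group.set_states(value=0)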
        """
        if states is not None:
            assert value is None, "Only one of states & value can be specified."
            _load_general(states, self.state_arrays, (0,) * len(states))
        else:
            assert value is not None, "At least one of states & value must be specified."
            assert states is None, "Only one of states & value can be specified."
            for d_dst in self.state_arrays:
                for dst in d_dst:
                    dst[:] = value

    def get_input_grads(self, merge_multi_context=True):
        """Get the gradients with respect to the inputs of the module.

        Parameters
        ----------
        merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the gradients
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like the gradients of a
            single executor.

        Returns
        -------
        If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
        is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
        elements are `NDArray`.
        """
        assert self.inputs_need_grad
        if merge_multi_context:
            return _merge_multi_context(self.input_grad_arrays, self.data_layouts)
        return self.input_grad_arrays

    def backward(self, out_grads=None):
        """Run backward on all devices. A backward should be called after
        a call to the forward function. Backward cannot be called unless
        `self.for_training` is `True`.

        Parameters
        ----------
        out_grads : NDArray or list of NDArray, optional
            Gradient on the outputs to be propagated back.
            This parameter is only needed when bind is called
            on outputs that are not a loss function.
        """
        assert self.for_training, 're-bind with for_training=True to run backward'
        if out_grads is None:
            out_grads = []

        for i, exec_ in enumerate(self.execs):
            # Slicing of out_grads across devices is not implemented here; each
            # executor runs backward from its own loss outputs.
            out_grads_slice = []
            exec_.backward(out_grads=out_grads_slice)

    def update_metric(self, eval_metric, labels):
        """Accumulate the performance according to `eval_metric` on all devices.

        Parameters
        ----------
        eval_metric : EvalMetric
            The metric used for evaluation.
        labels : list of list of NDArray
            One list of label arrays per device, typically taken from the `label`
            of a `DataBatch`.
        """
        for texec, device_labels in zip(self.execs, labels):
            eval_metric.update(device_labels, texec.outputs)

    def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
        """Internal utility function to bind the i-th executor."""
        shared_exec = None if shared_group is None else shared_group.execs[i]
        context = self.contexts[i]
        shared_data_arrays = self.shared_data_arrays[i]

        input_shapes = dict(data_shapes)
        if label_shapes is not None:
            input_shapes.update(dict(label_shapes))

        arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
        assert arg_shapes is not None, "shape inference failed"

        input_types = {x.name: x.dtype for x in data_shapes}
        if label_shapes is not None:
            input_types.update({x.name: x.dtype for x in label_shapes})
        arg_types, _, aux_types = self.symbol.infer_type(**input_types)
        assert arg_types is not None, "type inference failed"

        arg_arrays = []
        grad_arrays = {} if self.for_training else None

        def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
            """Internal helper to get a memory block or re-use it by re-shaping."""
            if name in shared_data_arrays:
                arg_arr = shared_data_arrays[name]

                if np.prod(arg_arr.shape) >= np.prod(arg_shape):
                    # the existing shared array is big enough; re-use it directly
                    assert arg_arr.dtype == arg_type
                    arg_arr = arg_arr.reshape(arg_shape)
                else:
                    logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
                                   (', which is larger than already allocated ') +
                                   ('shape %s' % (arg_arr.shape,)) +
                                   ('. Need to re-allocate. Consider putting ') +
                                   ('default_bucket_key to') +
                                   (' be the bucket taking the largest input for better ') +
                                   ('memory sharing.'))
                    arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)

                    # replace the shared array with the bigger one
                    shared_data_arrays[name] = arg_arr
            else:
                arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
                shared_data_arrays[name] = arg_arr

            return arg_arr

        # create or borrow the arrays (and gradients) for each argument
        for j in range(len(self.arg_names)):
            name = self.arg_names[j]
            if name in self.param_names:  # model parameter
                if shared_exec is None:
                    arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
                    if self.grad_req[name] != 'null':
                        grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
                        grad_arrays[name] = grad_arr
                else:
                    arg_arr = shared_exec.arg_dict[name]
                    assert arg_arr.shape == arg_shapes[j]
                    assert arg_arr.dtype == arg_types[j]
                    if self.grad_req[name] != 'null':
                        grad_arrays[name] = shared_exec.grad_dict[name]
            else:  # data, label, or state
                arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
                                          context, self.logger)

                # data might also need a gradient (e.g. when inputs_need_grad is set)
                if self.grad_req[name] != 'null':
                    grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
                                                        arg_shapes[j], arg_types[j], context,
                                                        self.logger)

            arg_arrays.append(arg_arr)

        # create or borrow the auxiliary states
        if shared_exec is None:
            aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
        else:
            for j, arr in enumerate(shared_exec.aux_arrays):
                assert aux_shapes[j] == arr.shape
                assert aux_types[j] == arr.dtype
            aux_arrays = shared_exec.aux_arrays[:]

        executor = self.symbol.bind(ctx=context, args=arg_arrays,
                                    args_grad=grad_arrays, aux_states=aux_arrays,
                                    grad_req=self.grad_req, shared_exec=shared_exec)

        return executor

    def _sliced_shape(self, shapes, i, major_axis):
        """Get the sliced shapes for the i-th executor.

        Parameters
        ----------
        shapes : list of (str, tuple)
            The original (name, shape) pairs.
        i : int
            Which executor we are dealing with.
        major_axis : list of int
            The batch axis for each input, or -1 if an input has no batch axis.
        """
        sliced_shapes = []
        for desc, axis in zip(shapes, major_axis):
            shape = list(desc.shape)
            if axis >= 0:
                shape[axis] = self.slices[i].stop - self.slices[i].start
            sliced_shapes.append(DataDesc(desc.name, tuple(shape), desc.dtype, desc.layout))
        return sliced_shapes

    def install_monitor(self, mon):
        """Install monitor on all executors."""
        for exe in self.execs:
            mon.install(exe)