Go to the documentation of this file.00001 import mxnet as mx
00002
00003
def load_checkpoint(prefix, epoch):
    """
    Load the parameters saved for one epoch of a model checkpoint.

    The file read is ``'<prefix>-<epoch zero-padded to 4 digits>.params'``.
    Every saved key must have the form ``'<type>:<name>'`` (a key with no
    ``':'`` raises ValueError). Entries whose type is ``'arg'`` or ``'aux'``
    are collected into the matching dict under ``<name>``; entries of any
    other type are ignored.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :return: (arg_params, aux_params)
        arg_params : dict of str to NDArray
            Model parameter, dict of name to NDArray of net's weights.
        aux_params : dict of str to NDArray
            Model parameter, dict of name to NDArray of net's auxiliary states.
    """
    fname = '%s-%04d.params' % (prefix, epoch)
    loaded = mx.nd.load(fname)
    # One destination dict per recognised parameter type.
    buckets = {'arg': {}, 'aux': {}}
    for key, value in loaded.items():
        # Split on the first ':' only; the parameter name may itself contain ':'.
        kind, param_name = key.split(':', 1)
        target = buckets.get(kind)
        if target is not None:
            target[param_name] = value
    return buckets['arg'], buckets['aux']
00025
00026
def convert_context(params, ctx):
    """
    Place every array of a parameter dict on the given context.

    The input dict is not modified; a new dict with the same keys (in the
    same order) is returned.

    :param params: dict of str to NDArray
    :param ctx: the context to convert to
    :return: dict of str to NDArray, each value being
        ``params[name].as_in_context(ctx)``
    """
    return {name: arr.as_in_context(ctx) for name, arr in params.items()}
00037
00038
def load_param(prefix, epoch, convert=False, ctx=None, process=False):
    """
    Wrapper for load_checkpoint with optional context conversion and
    promotion of test-time parameters.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :param convert: if True, every loaded array (arg and aux) is passed
        through ``convert_context`` so it lives on ``ctx`` (e.g. a GPU).
    :param ctx: target context used when ``convert`` is True; when None,
        ``mx.cpu()`` is used. Ignored when ``convert`` is False.
    :param process: if True, every arg parameter whose name ends with
        ``'_test'`` is moved to the same name without that suffix,
        overwriting any existing entry there, and the ``'_test'`` entry is
        removed (e.g. ``'fc_weight_test'`` replaces ``'fc_weight'``).
        aux params are not touched.
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)
    if convert:
        if ctx is None:
            ctx = mx.cpu()
        arg_params = convert_context(arg_params, ctx)
        aux_params = convert_context(aux_params, ctx)
    if process:
        suffix = '_test'
        # Collect the keys first: arg_params is mutated inside the loop.
        tests = [k for k in arg_params if k.endswith(suffix)]
        for test in tests:
            # Strip only the trailing suffix. str.replace would also delete
            # '_test' found elsewhere in the name (e.g. 'fc_testing_w_test'
            # -> 'fcing_w'), yielding a key that matches no real parameter.
            arg_params[test[:-len(suffix)]] = arg_params.pop(test)
    return arg_params, aux_params