load_model.py
Go to the documentation of this file.
00001 import mxnet as mx
00002 
00003 
def load_checkpoint(prefix, epoch):
    """
    Read the parameter file '<prefix>-<epoch as 4 digits>.params' and split
    its entries into argument and auxiliary parameters.
    Every key in the file is split once on ':' into '<type>:<name>'.
    Entries of type 'arg' go to arg_params and entries of type 'aux' go to
    aux_params. Entries of any other type are skipped. A key without ':'
    makes the unpacking fail with ValueError.
    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    """
    param_file = '%s-%04d.params' % (prefix, epoch)
    saved = mx.nd.load(param_file)
    arg_params, aux_params = {}, {}
    # Route each entry to the dict matching its type tag.
    buckets = {'arg': arg_params, 'aux': aux_params}
    for key, array in saved.items():
        kind, param_name = key.split(':', 1)
        target = buckets.get(kind)
        if target is not None:
            target[param_name] = array
    return arg_params, aux_params
00025 
00026 
def convert_context(params, ctx):
    """
    Build a new parameter dict with every value moved to another context.
    Each value is replaced by the result of calling its as_in_context(ctx).
    The input dict itself is not modified.
    :param params: dict of str to NDArray
    :param ctx: the context to convert to
    :return: new dict of str to NDArray with context ctx
    """
    return {name: array.as_in_context(ctx) for name, array in params.items()}
00037 
00038 
def load_param(prefix, epoch, convert=False, ctx=None, process=False):
    """
    Wrapper around load_checkpoint. It can optionally move the loaded arrays
    to a context and rename '_test' argument parameters.
    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :param convert: if True, move every arg and aux array onto ctx using
        convert_context. If False, the arrays are returned as loaded.
    :param ctx: target context used when convert is True. If convert is True
        and ctx is None, mx.cpu() is used. Ignored when convert is False.
    :param process: if True, every arg param whose name contains '_test' is
        renamed by deleting each '_test' occurrence from its name, e.g.
        'bbox_pred_weight_test' -> 'bbox_pred_weight'. The renamed entry
        overwrites any existing entry with the resulting name. aux params are
        not renamed.
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)
    if convert:
        # Fall back to CPU rather than failing when no context was supplied.
        if ctx is None:
            ctx = mx.cpu()
        arg_params = convert_context(arg_params, ctx)
        aux_params = convert_context(aux_params, ctx)
    if process:
        # Snapshot the matching keys first, because the loop mutates arg_params.
        test_keys = [key for key in arg_params if '_test' in key]
        for key in test_keys:
            # str.replace removes ALL occurrences of '_test', not only a suffix.
            arg_params[key.replace('_test', '')] = arg_params.pop(key)
    return arg_params, aux_params


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30