Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008 import numpy as np
class Symbol:
    """Base wrapper around a generated network symbol and its inferred shapes.

    Subclasses implement :meth:`get_symbol` (which must also store the result
    in ``self.sym``) and :meth:`init_weights`. The three shape dictionaries
    stay ``None`` until :meth:`infer_shape` has been called.
    """

    def __init__(self):
        # Filled in by infer_shape(): name -> inferred shape.
        self.arg_shape_dict = None
        self.out_shape_dict = None
        self.aux_shape_dict = None
        # The generated symbol; subclasses assign it in get_symbol().
        self.sym = None

    @property
    def symbol(self):
        """The symbol stored in ``self.sym`` (``None`` until assigned)."""
        return self.sym

    def get_symbol(self, cfg, is_train=True):
        """
        Return a generated symbol; it must also be assigned to self.sym.
        """
        raise NotImplementedError()

    def init_weights(self, cfg, arg_params, aux_params):
        """Initialize parameters in ``arg_params`` / ``aux_params``; subclass hook."""
        raise NotImplementedError()

    def get_msra_std(self, shape):
        """Return the MSRA initialization standard deviation ``sqrt(2 / fan_in)``.

        ``fan_in`` is ``shape[1]`` multiplied by the product of any trailing
        dimensions ``shape[2:]``; *shape* must therefore have at least two
        entries.
        """
        fan_in = float(shape[1])
        if len(shape) > 2:
            fan_in *= np.prod(shape[2:])
        # No debug print here: this is a pure computation.
        return np.sqrt(2 / fan_in)

    def infer_shape(self, data_shape_dict):
        """Infer and cache argument, output and auxiliary-state shapes.

        :param data_shape_dict: mapping of input name -> shape, forwarded as
            keyword arguments to ``self.sym.infer_shape``.

        The results are stored in ``arg_shape_dict``, ``out_shape_dict`` and
        ``aux_shape_dict``, keyed by the names returned from
        ``list_arguments()``, ``list_outputs()`` and
        ``list_auxiliary_states()`` respectively.
        """
        arg_shape, out_shape, aux_shape = self.sym.infer_shape(**data_shape_dict)
        self.arg_shape_dict = dict(zip(self.sym.list_arguments(), arg_shape))
        self.out_shape_dict = dict(zip(self.sym.list_outputs(), out_shape))
        self.aux_shape_dict = dict(zip(self.sym.list_auxiliary_states(), aux_shape))

    @staticmethod
    def _check_param(name, params, inferred_shape):
        """Raise AssertionError if *name* is absent from *params* or mis-shaped."""
        if name not in params:
            raise AssertionError(name + ' not initialized')
        provided = params[name].shape
        if provided != inferred_shape:
            raise AssertionError('shape inconsistent for ' + name + ' inferred ' + str(inferred_shape)
                                 + ' provided ' + str(provided))

    def check_parameter_shapes(self, arg_params, aux_params, data_shape_dict, is_train=True):
        """Verify every parameter is present and matches its inferred shape.

        :meth:`infer_shape` must have been called first. Arguments named in
        ``data_shape_dict`` are skipped. When ``is_train`` is false, arguments
        whose name contains ``'label'`` are also skipped. All auxiliary states
        are checked against ``aux_params``.

        :raises AssertionError: if a parameter is missing or its shape differs
            from the inferred one. The error is raised explicitly rather than
            via ``assert``, so the check still runs under ``python -O``.
        """
        for k in self.sym.list_arguments():
            if k in data_shape_dict or (not is_train and 'label' in k):
                continue
            self._check_param(k, arg_params, self.arg_shape_dict[k])
        for k in self.sym.list_auxiliary_states():
            self._check_param(k, aux_params, self.aux_shape_dict[k])