symbol.py
Go to the documentation of this file.
00001 # --------------------------------------------------------
00002 # Deformable Convolutional Networks
00003 # Copyright (c) 2017 Microsoft
00004 # Licensed under The Apache-2.0 License [see LICENSE for details]
00005 # Written by Yuwen Xiong
00006 # --------------------------------------------------------
00007 
00008 import numpy as np
class Symbol:
    """Base class wrapping a network symbol and the shapes inferred for it.

    Subclasses build the network in ``get_symbol`` (which must also assign the
    result to ``self.sym``) and initialize parameters in ``init_weights``.
    After ``infer_shape`` has run, ``arg_shape_dict``, ``out_shape_dict`` and
    ``aux_shape_dict`` map argument / output / auxiliary-state names to their
    inferred shapes.
    """

    def __init__(self):
        # Filled in by infer_shape(); None until then.
        self.arg_shape_dict = None
        self.out_shape_dict = None
        self.aux_shape_dict = None
        # The generated symbol; subclasses assign it in get_symbol().
        self.sym = None

    @property
    def symbol(self):
        """The symbol stored in ``self.sym`` (None until a subclass sets it)."""
        return self.sym

    def get_symbol(self, cfg, is_train=True):
        """Build and return the network symbol.

        Implementations must also assign the returned symbol to ``self.sym``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def init_weights(self, cfg, arg_params, aux_params):
        """Initialize parameters in ``arg_params`` / ``aux_params``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_msra_std(self, shape):
        """Return the MSRA initialization std ``sqrt(2 / fan_in)`` for *shape*.

        ``fan_in`` is ``shape[1]`` multiplied by the product of any trailing
        dimensions ``shape[2:]``. Presumably *shape* is laid out as
        (out_channels, in_channels, kernel...) -- confirm against callers.
        ``shape`` must have at least two dimensions (``shape[1]`` is read).
        """
        fan_in = float(shape[1])
        if len(shape) > 2:
            fan_in *= np.prod(shape[2:])
        # float fan_in keeps the division true-division on Python 2 as well.
        return np.sqrt(2 / fan_in)

    def infer_shape(self, data_shape_dict):
        """Infer and cache argument, output and auxiliary-state shapes.

        Args:
            data_shape_dict: mapping of input name -> shape, forwarded as
                keyword arguments to ``self.sym.infer_shape``.

        Stores the results in ``self.arg_shape_dict``, ``self.out_shape_dict``
        and ``self.aux_shape_dict``, keyed by the names from
        ``list_arguments()``, ``list_outputs()`` and
        ``list_auxiliary_states()`` respectively.
        """
        arg_shape, out_shape, aux_shape = self.sym.infer_shape(**data_shape_dict)
        self.arg_shape_dict = dict(zip(self.sym.list_arguments(), arg_shape))
        self.out_shape_dict = dict(zip(self.sym.list_outputs(), out_shape))
        self.aux_shape_dict = dict(zip(self.sym.list_auxiliary_states(), aux_shape))

    def check_parameter_shapes(self, arg_params, aux_params, data_shape_dict, is_train=True):
        """Verify every parameter is present and matches its inferred shape.

        ``infer_shape`` must have been called first so the cached shape dicts
        are populated.

        Args:
            arg_params: mapping of argument name -> array (anything with ``.shape``).
            aux_params: mapping of auxiliary-state name -> array.
            data_shape_dict: input names; these arguments are skipped.
            is_train: when False, arguments whose name contains ``'label'``
                are also skipped.

        Raises:
            AssertionError: if a parameter is missing or its shape differs
                from the inferred one. This is raised explicitly rather than
                via ``assert`` so that the check still runs under ``python -O``.
        """
        for name in self.sym.list_arguments():
            # Inputs are fed at run time; label inputs are not needed at inference.
            if name in data_shape_dict or (not is_train and 'label' in name):
                continue
            if name not in arg_params:
                raise AssertionError(name + ' not initialized')
            inferred = self.arg_shape_dict[name]
            if arg_params[name].shape != inferred:
                raise AssertionError(
                    'shape inconsistent for ' + name + ' inferred ' + str(inferred)
                    + ' provided ' + str(arg_params[name].shape))
        for name in self.sym.list_auxiliary_states():
            if name not in aux_params:
                raise AssertionError(name + ' not initialized')
            inferred = self.aux_shape_dict[name]
            if aux_params[name].shape != inferred:
                raise AssertionError(
                    'shape inconsistent for ' + name + ' inferred ' + str(inferred)
                    + ' provided ' + str(aux_params[name].shape))


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:31