freeze_model.py
# vim: expandtab:ts=4:sw=4
import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim


def _batch_norm_fn(x, scope=None):
    if scope is None:
        scope = tf.get_variable_scope().name + "/bn"
    return slim.batch_norm(x, scope=scope)


def create_link(
        incoming, network_builder, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        regularizer=None, is_first=False, summarize_activations=True):
    if is_first:
        network = incoming
    else:
        network = _batch_norm_fn(incoming, scope=scope + "/bn")
        network = nonlinearity(network)
        if summarize_activations:
            tf.summary.histogram(scope + "/activations", network)

    pre_block_network = network
    post_block_network = network_builder(pre_block_network, scope)

    incoming_dim = pre_block_network.get_shape().as_list()[-1]
    outgoing_dim = post_block_network.get_shape().as_list()[-1]
    if incoming_dim != outgoing_dim:
        assert outgoing_dim == 2 * incoming_dim, \
            "%d != %d" % (outgoing_dim, 2 * incoming_dim)
        projection = slim.conv2d(
            incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None,
            scope=scope + "/projection", weights_initializer=weights_initializer,
            biases_initializer=None, weights_regularizer=regularizer)
        network = projection + post_block_network
    else:
        network = incoming + post_block_network
    return network


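# A shape sketch of the link above (illustrative, not part of the original
# file): when the inner block doubles the channel count and halves the
# spatial resolution, the strided 1x1 "projection" convolution maps the
# identity branch to the same shape before the element-wise addition, e.g.
# [N, 64, 32, 32] -> [N, 32, 16, 64] on both branches.

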
def create_inner_block(
        incoming, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        bias_initializer=tf.zeros_initializer(), regularizer=None,
        increase_dim=False, summarize_activations=True):
    n = incoming.get_shape().as_list()[-1]
    stride = 1
    if increase_dim:
        n *= 2
        stride = 2

    incoming = slim.conv2d(
        incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME",
        normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/1")
    if summarize_activations:
        tf.summary.histogram(incoming.name + "/activations", incoming)

    incoming = slim.dropout(incoming, keep_prob=0.6)

    incoming = slim.conv2d(
        incoming, n, [3, 3], 1, activation_fn=None, padding="SAME",
        normalizer_fn=None, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/2")
    return incoming


def residual_block(incoming, scope, nonlinearity=tf.nn.elu,
                   weights_initializer=tf.truncated_normal_initializer(
                       stddev=1e-3),
                   bias_initializer=tf.zeros_initializer(), regularizer=None,
                   increase_dim=False, is_first=False,
                   summarize_activations=True):

    def network_builder(x, s):
        return create_inner_block(
            x, s, nonlinearity, weights_initializer, bias_initializer,
            regularizer, increase_dim, summarize_activations)

    return create_link(
        incoming, network_builder, scope, nonlinearity, weights_initializer,
        regularizer, is_first, summarize_activations)


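# Minimal usage sketch (hypothetical tensors, not part of the original file):
#
#     x = tf.placeholder(tf.float32, (None, 64, 32, 32))
#     y = residual_block(x, "conv_demo", increase_dim=True)
#     # y has shape (None, 32, 16, 64): spatial dims halved, channels doubled.

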
def _create_network(incoming, reuse=None, weight_decay=1e-8):
    nonlinearity = tf.nn.elu
    conv_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    conv_bias_init = tf.zeros_initializer()
    conv_regularizer = slim.l2_regularizer(weight_decay)
    fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    fc_bias_init = tf.zeros_initializer()
    fc_regularizer = slim.l2_regularizer(weight_decay)

    def batch_norm_fn(x):
        return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn")

    network = incoming
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1",
        weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2",
        weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)

    # NOTE(nwojke): This is missing a padding="SAME" to match the CNN
    # architecture in Table 1 of the paper. Information on how this affects
    # performance on MOT 16 training sequences can be found in
    # issue 10 https://github.com/nwojke/deep_sort/issues/10
    network = slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1")

    network = residual_block(
        network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False, is_first=True)
    network = residual_block(
        network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    feature_dim = network.get_shape().as_list()[-1]
    network = slim.flatten(network)

    network = slim.dropout(network, keep_prob=0.6)
    network = slim.fully_connected(
        network, feature_dim, activation_fn=nonlinearity,
        normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer,
        scope="fc1", weights_initializer=fc_weight_init,
        biases_initializer=fc_bias_init)

    features = network

    # Features in rows, normalize axis 1.
    features = slim.batch_norm(features, scope="ball", reuse=reuse)
    feature_norm = tf.sqrt(
        tf.constant(1e-8, tf.float32) +
        tf.reduce_sum(tf.square(features), [1], keepdims=True))
    features = features / feature_norm
    return features, None


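# With the layer widths above (32 channels doubled twice by the increase_dim
# blocks), fc1 emits a 128-dimensional embedding per image, which the
# row-wise normalization projects onto the unit hypersphere.

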
def _network_factory(weight_decay=1e-8):

    def factory_fn(image, reuse):
        with slim.arg_scope([slim.batch_norm, slim.dropout],
                            is_training=False):
            with slim.arg_scope([slim.conv2d, slim.fully_connected,
                                 slim.batch_norm, slim.layer_norm],
                                reuse=reuse):
                features, logits = _create_network(
                    image, reuse=reuse, weight_decay=weight_decay)
                return features, logits

    return factory_fn


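# Note that the arg_scope pins batch_norm and dropout to inference mode
# (is_training=False), so the frozen graph uses the stored moving statistics
# and dropout reduces to the identity.

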
def _preprocess(image):
    image = image[:, :, ::-1]  # BGR to RGB
    return image


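# The channel flip implies callers feed OpenCV-style BGR crops; the
# placeholder in main() fixes the expected input to uint8 patches of
# height 128 and width 64.

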
def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Freeze old model")
    parser.add_argument(
        "--checkpoint_in",
        default="resources/networks/mars-small128.ckpt-68577",
        help="Path to checkpoint file")
    parser.add_argument(
        "--graphdef_out",
        default="resources/networks/mars-small128.pb",
        help="Path to output frozen graph definition (.pb) file")
    return parser.parse_args()


def main():
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        image_var = tf.map_fn(
            lambda x: _preprocess(x), tf.cast(input_var, tf.float32),
            back_prop=False)

        factory_fn = _network_factory()
        features, _ = factory_fn(image_var, reuse=None)
        features = tf.identity(features, name="features")

        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(session, args.checkpoint_in)

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(),
            [features.name.split(":")[0]])
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(output_graph_def.SerializeToString())


if __name__ == "__main__":
    main()
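
# Sketch: consuming the frozen graph (illustrative, not part of the original
# file; assumes TF 1.x and a BGR uint8 batch `batch` of shape [N, 128, 64, 3]):
#
#     with tf.gfile.GFile("resources/networks/mars-small128.pb", "rb") as f:
#         graph_def = tf.GraphDef()
#         graph_def.ParseFromString(f.read())
#     with tf.Graph().as_default() as graph:
#         tf.import_graph_def(graph_def, name="net")
#         images = graph.get_tensor_by_name("net/images:0")
#         features = graph.get_tensor_by_name("net/features:0")
#         with tf.Session(graph=graph) as session:
#             embeddings = session.run(features, feed_dict={images: batch})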

