freeze_model.py
Go to the documentation of this file.
1 # vim: expandtab:ts=4:sw=4
2 import argparse
3 import tensorflow as tf
4 import tensorflow.contrib.slim as slim
5 
6 
def _batch_norm_fn(x, scope=None):
    """Apply slim batch normalization to `x`.

    If `scope` is not given, a "/bn" suffix on the current variable
    scope name is used, so repeated calls inside one layer scope share
    the same normalization variables.
    """
    bn_scope = tf.get_variable_scope().name + "/bn" if scope is None else scope
    return slim.batch_norm(x, scope=bn_scope)
11 
12 
def create_link(
        incoming, network_builder, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        regularizer=None, is_first=False, summarize_activations=True):
    """Wrap `network_builder` in a pre-activation residual connection.

    Unless `is_first`, the input first passes through batch norm and the
    nonlinearity. The builder's output is then added to the (possibly
    projected) input. When the builder doubles the channel dimension, a
    strided 1x1 convolution projects the input to the matching shape.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor of the residual unit.
    network_builder : Callable[[tf.Tensor, str], tf.Tensor]
        Builds the inner block; called with the pre-activated tensor and
        `scope`.
    scope : str
        Variable scope prefix for this unit.
    nonlinearity : Callable
        Activation applied before the inner block (pre-activation).
    weights_initializer : Callable
        Initializer for the projection convolution weights.
    regularizer : Optional[Callable]
        Weight regularizer for the projection convolution.
    is_first : bool
        If True, skip the leading batch norm + activation.
    summarize_activations : bool
        If True, add a histogram summary of the pre-block activations.

    Returns
    -------
    tf.Tensor
        Sum of (projected) input and inner-block output.
    """
    if is_first:
        network = incoming
    else:
        network = _batch_norm_fn(incoming, scope=scope + "/bn")
        network = nonlinearity(network)
        if summarize_activations:
            tf.summary.histogram(scope + "/activations", network)

    pre_block_network = network
    post_block_network = network_builder(pre_block_network, scope)

    incoming_dim = pre_block_network.get_shape().as_list()[-1]
    outgoing_dim = post_block_network.get_shape().as_list()[-1]
    if incoming_dim != outgoing_dim:
        # BUGFIX: the original formatted `2 * incoming` (a tensor) into the
        # assert message; that raised a TypeError on the failure path instead
        # of reporting the dimension mismatch. Use the scalar dimension.
        assert outgoing_dim == 2 * incoming_dim, \
            "%d != %d" % (outgoing_dim, 2 * incoming_dim)
        # Strided 1x1 convolution projects the input to the doubled channel
        # count (and halved spatial size) so the residual sum is well-formed.
        projection = slim.conv2d(
            incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None,
            scope=scope + "/projection", weights_initializer=weights_initializer,
            biases_initializer=None, weights_regularizer=regularizer)
        network = projection + post_block_network
    else:
        network = incoming + post_block_network
    return network
41 
42 
def create_inner_block(
        incoming, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(1e-3),
        bias_initializer=tf.zeros_initializer(), regularizer=None,
        increase_dim=False, summarize_activations=True):
    """Build the two-convolution inner block of a residual unit.

    NOTE(review): the `def` header line was lost in the extracted copy of
    this file; it is restored here from the generated signature listing.

    Structure: 3x3 conv (batch-normalized, activated) -> dropout
    (keep_prob=0.6) -> 3x3 conv (linear, no normalization). When
    `increase_dim` is set, the first convolution doubles the channel
    count and uses stride 2.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor; channel count is read from its static shape.
    scope : str
        Variable scope prefix; the two convolutions use "/1" and "/2".
    nonlinearity : Callable
        Activation for the first convolution.
    weights_initializer, bias_initializer, regularizer
        Passed through to both slim.conv2d calls.
    increase_dim : bool
        Double channels and downsample spatially (stride 2) if True.
    summarize_activations : bool
        Add a histogram summary after the first convolution if True.

    Returns
    -------
    tf.Tensor
        Output of the second (linear) convolution.
    """
    n = incoming.get_shape().as_list()[-1]
    stride = 1
    if increase_dim:
        n *= 2
        stride = 2

    incoming = slim.conv2d(
        incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME",
        normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/1")
    if summarize_activations:
        tf.summary.histogram(incoming.name + "/activations", incoming)

    incoming = slim.dropout(incoming, keep_prob=0.6)

    incoming = slim.conv2d(
        incoming, n, [3, 3], 1, activation_fn=None, padding="SAME",
        normalizer_fn=None, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/2")
    return incoming
70 
71 
72 def residual_block(incoming, scope, nonlinearity=tf.nn.elu,
73  weights_initializer=tf.truncated_normal_initializer(1e3),
74  bias_initializer=tf.zeros_initializer(), regularizer=None,
75  increase_dim=False, is_first=False,
76  summarize_activations=True):
77 
78  def network_builder(x, s):
79  return create_inner_block(
80  x, s, nonlinearity, weights_initializer, bias_initializer,
81  regularizer, increase_dim, summarize_activations)
82 
83  return create_link(
84  incoming, network_builder, scope, nonlinearity, weights_initializer,
85  regularizer, is_first, summarize_activations)
86 
87 
def _create_network(incoming, reuse=None, weight_decay=1e-8):
    """Build the CNN feature extractor and return L2-normalized features.

    Architecture: two 3x3 convolutions, a 3x3/2 max-pool, then three pairs
    of residual blocks (channels doubling in the "conv3_1" and "conv4_1"
    blocks), a fully connected layer, and a final batch norm followed by
    row-wise L2 normalization.

    Parameters
    ----------
    incoming : tf.Tensor
        Input image batch. Presumably float32 NHWC (the caller in `main`
        casts uint8 images to float32) -- confirm against caller.
    reuse : Optional[bool]
        Variable reuse flag, forwarded to the final batch norm scope.
    weight_decay : float
        L2 regularization strength for conv and fc weights.

    Returns
    -------
    (tf.Tensor, None)
        Unit-norm feature rows and a `None` placeholder for logits.
    """
    nonlinearity = tf.nn.elu
    conv_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    conv_bias_init = tf.zeros_initializer()
    conv_regularizer = slim.l2_regularizer(weight_decay)
    fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    fc_bias_init = tf.zeros_initializer()
    fc_regularizer = slim.l2_regularizer(weight_decay)

    def batch_norm_fn(x):
        # Scope the batch norm under the current variable scope name so each
        # layer gets its own "<layer>/bn" variables.
        return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn")

    network = incoming
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1",
        weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2",
        weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)

    # NOTE(nwojke): This is missing a padding="SAME" to match the CNN
    # architecture in Table 1 of the paper. Information on how this affects
    # performance on MOT 16 training sequences can be found in
    # issue 10 https://github.com/nwojke/deep_sort/issues/10
    network = slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1")

    network = residual_block(
        network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False, is_first=True)
    network = residual_block(
        network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    # NOTE(review): feature_dim is read BEFORE flattening, i.e. it is the
    # channel count of the last residual block, not the flattened size.
    feature_dim = network.get_shape().as_list()[-1]
    network = slim.flatten(network)

    network = slim.dropout(network, keep_prob=0.6)
    network = slim.fully_connected(
        network, feature_dim, activation_fn=nonlinearity,
        normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer,
        scope="fc1", weights_initializer=fc_weight_init,
        biases_initializer=fc_bias_init)

    features = network

    # Features in rows, normalize axis 1.
    features = slim.batch_norm(features, scope="ball", reuse=reuse)
    # 1e-8 guards against division by zero for all-zero feature rows.
    feature_norm = tf.sqrt(
        tf.constant(1e-8, tf.float32) +
        tf.reduce_sum(tf.square(features), [1], keepdims=True))
    features = features / feature_norm
    # Second element is a logits placeholder; this inference-only variant
    # has no classification head.
    return features, None
158 
159 
def _network_factory(weight_decay=1e-8):
    """Return a builder closure that constructs the network in inference mode.

    The returned callable takes `(image, reuse)` and yields
    `(features, logits)` from `_create_network`.
    """

    def factory_fn(image, reuse):
        # Inference configuration: batch-norm statistics frozen, dropout off.
        inference_mode = slim.arg_scope(
            [slim.batch_norm, slim.dropout], is_training=False)
        variable_reuse = slim.arg_scope(
            [slim.conv2d, slim.fully_connected, slim.batch_norm,
             slim.layer_norm], reuse=reuse)
        with inference_mode, variable_reuse:
            features, logits = _create_network(
                image, reuse=reuse, weight_decay=weight_decay)
        return features, logits

    return factory_fn
173 
174 
175 def _preprocess(image):
176  image = image[:, :, ::-1] # BGR to RGB
177  return image
178 
179 
181  """Parse command line arguments.
182  """
183  parser = argparse.ArgumentParser(description="Freeze old model")
184  parser.add_argument(
185  "--checkpoint_in",
186  default="resources/networks/mars-small128.ckpt-68577",
187  help="Path to checkpoint file")
188  parser.add_argument(
189  "--graphdef_out",
190  default="resources/networks/mars-small128.pb")
191  return parser.parse_args()
192 
193 
def main():
    """Freeze the model: build the graph, restore checkpoint weights,
    convert variables to constants, and write a serialized GraphDef.
    """
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        # Input: batches of 128x64 3-channel uint8 images (BGR, given that
        # _preprocess flips the channel axis to RGB).
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        # Cast to float32 and convert each image's channel order per-element.
        image_var = tf.map_fn(
            lambda x: _preprocess(x), tf.cast(input_var, tf.float32),
            back_prop=False)

        factory_fn = _network_factory()
        features, _ = factory_fn(image_var, reuse=None)
        # Give the output a stable name for consumers of the frozen graph.
        features = tf.identity(features, name="features")

        # Restore trained weights before baking them into constants.
        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(session, args.checkpoint_in)

        # Replace variables with their restored values; keep only the
        # subgraph reachable from the "features" output op.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(),
            [features.name.split(":")[0]])
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(output_graph_def.SerializeToString())
def residual_block(incoming, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(1e3), bias_initializer=tf.zeros_initializer(), regularizer=None, increase_dim=False, is_first=False, summarize_activations=True)
Definition: freeze_model.py:76
def parse_args()
def create_inner_block(incoming, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(1e-3), bias_initializer=tf.zeros_initializer(), regularizer=None, increase_dim=False, summarize_activations=True)
Definition: freeze_model.py:47
def _batch_norm_fn(x, scope=None)
Definition: freeze_model.py:7
def _network_factory(weight_decay=1e-8)
def _preprocess(image)
def create_link(incoming, network_builder, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(stddev=1e-3), regularizer=None, is_first=False, summarize_activations=True)
Definition: freeze_model.py:16
def _create_network(incoming, reuse=None, weight_decay=1e-8)
Definition: freeze_model.py:88


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27