3 import tensorflow
as tf
4 import tensorflow.contrib.slim
as slim
def _batch_norm_fn(x, scope=None):
    """Apply slim batch normalization to `x`.

    When no `scope` is given, derive one from the enclosing variable
    scope so BN variables land under `<current_scope>/bn`.
    NOTE(review): the `if scope is None` guard is reconstructed from the
    mangled source; it is implied by the `scope=None` default.
    """
    if scope is None:
        scope = tf.get_variable_scope().name + "/bn"
    return slim.batch_norm(x, scope=scope)
def create_link(
        incoming, network_builder, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        regularizer=None, is_first=False, summarize_activations=True):
    """Wrap `network_builder` in a pre-activation residual connection.

    Applies batch-norm + nonlinearity to `incoming` (skipped when
    `is_first` is set), feeds the result through `network_builder`, and
    adds the shortcut. If the builder doubled the channel count, the
    shortcut is projected with a strided 1x1 convolution first.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor of the link.
    network_builder : Callable[[tf.Tensor, str], tf.Tensor]
        Builds the residual branch; receives (tensor, scope).
    scope : str
        Name scope prefix for the created variables/summaries.
    nonlinearity, weights_initializer, regularizer
        Standard layer configuration, forwarded to the projection conv.
    is_first : bool
        Skip the pre-activation (used for the first block after a conv).
    summarize_activations : bool
        Emit a histogram summary of the pre-block activations.

    Returns
    -------
    tf.Tensor
        The residual sum (projected) shortcut + block output.
    """
    # NOTE(review): the is_first/else pre-activation branch is
    # reconstructed from the mangled source (original lines 17-20).
    if is_first:
        network = incoming
    else:
        network = _batch_norm_fn(incoming, scope=scope + "/bn")
        network = nonlinearity(network)
        if summarize_activations:
            tf.summary.histogram(scope + "/activations", network)

    pre_block_network = network
    post_block_network = network_builder(pre_block_network, scope)

    incoming_dim = pre_block_network.get_shape().as_list()[-1]
    outgoing_dim = post_block_network.get_shape().as_list()[-1]
    if incoming_dim != outgoing_dim:
        # BUG FIX: the message previously formatted `2 * incoming` (a
        # Tensor) with %d, which would raise on assertion failure; use
        # the integer `incoming_dim` instead.
        assert outgoing_dim == 2 * incoming_dim, \
            "%d != %d" % (outgoing_dim, 2 * incoming_dim)
        # Project the shortcut to the doubled channel count with a
        # stride-2 1x1 convolution (no bias, no activation).
        projection = slim.conv2d(
            incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None,
            scope=scope + "/projection",
            weights_initializer=weights_initializer,
            biases_initializer=None, weights_regularizer=regularizer)
        network = projection + post_block_network
    else:
        network = incoming + post_block_network
    return network
def create_inner_block(
        incoming, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(1e-3),
        bias_initializer=tf.zeros_initializer(), regularizer=None,
        increase_dim=False, summarize_activations=True):
    """Two 3x3 convolutions forming the inner branch of a residual block.

    First conv: batch-normalized + nonlinearity; second conv: linear,
    un-normalized (activation happens in the next link). Dropout with
    keep_prob=0.6 sits between them.

    NOTE(review): the positional argument of
    tf.truncated_normal_initializer is `mean`, not `stddev`; sibling
    functions pass `stddev=1e-3` explicitly. Left unchanged here —
    confirm intent before altering (checkpoint was trained with it).

    Parameters
    ----------
    incoming : tf.Tensor
        Input feature map, NHWC.
    scope : str
        Variable-scope prefix; the convs use `scope + "/1"` and `"/2"`.
    increase_dim : bool
        Double the channel count and downsample with stride 2.
    summarize_activations : bool
        Emit an activation histogram after the first conv.

    Returns
    -------
    tf.Tensor
        Output of the second convolution.
    """
    n = incoming.get_shape().as_list()[-1]
    # NOTE(review): reconstructed (original lines 49-53 missing):
    # doubling + stride-2 when increase_dim — confirm against original.
    stride = 1
    if increase_dim:
        n *= 2
        stride = 2

    incoming = slim.conv2d(
        incoming, n, [3, 3], stride, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=_batch_norm_fn,
        weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/1")
    if summarize_activations:
        tf.summary.histogram(incoming.name + "/activations", incoming)

    incoming = slim.dropout(incoming, keep_prob=0.6)

    incoming = slim.conv2d(
        incoming, n, [3, 3], 1, activation_fn=None, padding="SAME",
        normalizer_fn=None, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/2")
    return incoming
def residual_block(incoming, scope, nonlinearity=tf.nn.elu,
                   weights_initializer=tf.truncated_normal_initializer(
                       stddev=1e-3),
                   bias_initializer=tf.zeros_initializer(), regularizer=None,
                   increase_dim=False, is_first=False,
                   summarize_activations=True):
    """Residual block: `create_inner_block` wired through `create_link`.

    BUG FIX: the default weight initializer was
    `tf.truncated_normal_initializer(1e3)` — a scale of 1000 (and
    positionally the `mean` argument). Every other initializer in this
    file uses stddev=1e-3; this was clearly a dropped-minus typo. All
    in-file callers override this default, so the change is safe.

    Parameters mirror `create_inner_block` / `create_link`; `is_first`
    disables the pre-activation on the very first block.

    Returns
    -------
    tf.Tensor
        Output of the residual link.
    """
    def network_builder(x, s):
        # Inner branch shares this block's layer configuration.
        return create_inner_block(
            x, s, nonlinearity, weights_initializer, bias_initializer,
            regularizer, increase_dim, summarize_activations)

    return create_link(
        incoming, network_builder, scope, nonlinearity, weights_initializer,
        regularizer, is_first, summarize_activations)
def _create_network(incoming, reuse=None, weight_decay=1e-8):
    """Build the appearance-descriptor CNN.

    Architecture: two 3x3 convs (32 ch) -> 3x3/2 max-pool -> three pairs
    of residual blocks (channels doubling at conv3_1 and conv4_1) ->
    flatten -> dropout -> fully-connected -> batch-norm ("ball") ->
    L2 normalization of the feature rows.

    Parameters
    ----------
    incoming : tf.Tensor
        Preprocessed image batch, NHWC float32.
    reuse : bool | None
        Variable reuse flag, forwarded to the final batch norm.
    weight_decay : float
        L2 regularization strength for conv and fc weights.

    Returns
    -------
    (tf.Tensor, None)
        Unit-norm feature rows and `None` (no classifier logits in the
        freeze graph).
    """
    nonlinearity = tf.nn.elu
    conv_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    conv_bias_init = tf.zeros_initializer()
    conv_regularizer = slim.l2_regularizer(weight_decay)
    fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    fc_bias_init = tf.zeros_initializer()
    fc_regularizer = slim.l2_regularizer(weight_decay)

    def batch_norm_fn(x):
        # Scope BN variables under the current variable scope.
        return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn")

    network = incoming
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1",
        weights_initializer=conv_weight_init,
        biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)
    network = slim.conv2d(
        network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
        padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2",
        weights_initializer=conv_weight_init,
        biases_initializer=conv_bias_init,
        weights_regularizer=conv_regularizer)

    network = slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1")

    network = residual_block(
        network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False, is_first=True)
    network = residual_block(
        network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    network = residual_block(
        network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=True)
    network = residual_block(
        network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init,
        conv_regularizer, increase_dim=False)

    feature_dim = network.get_shape().as_list()[-1]
    network = slim.flatten(network)

    network = slim.dropout(network, keep_prob=0.6)
    network = slim.fully_connected(
        network, feature_dim, activation_fn=nonlinearity,
        normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer,
        scope="fc1", weights_initializer=fc_weight_init,
        biases_initializer=fc_bias_init)

    features = network

    # Normalize feature rows to unit length; the epsilon guards the
    # sqrt against zero vectors.
    features = slim.batch_norm(features, scope="ball", reuse=reuse)
    feature_norm = tf.sqrt(
        tf.constant(1e-8, tf.float32) +
        tf.reduce_sum(tf.square(features), [1], keepdims=True))
    features = features / feature_norm
    return features, None
def _network_factory(weight_decay=1e-8):
    """Return a `factory_fn(image, reuse)` that builds `_create_network`
    in inference mode.

    The inner factory pins `is_training=False` for batch norm / dropout
    and threads the `reuse` flag through all variable-creating layers.
    NOTE(review): the arg_scope keyword lines (`is_training=False`,
    `reuse=reuse`) and the trailing `return factory_fn` are reconstructed
    from the mangled source — `return factory_fn` is implied by the
    caller in `main`.
    """
    def factory_fn(image, reuse):
        with slim.arg_scope([slim.batch_norm, slim.dropout],
                            is_training=False):
            with slim.arg_scope([slim.conv2d, slim.fully_connected,
                                 slim.batch_norm, slim.layer_norm],
                                reuse=reuse):
                features, logits = _create_network(
                    image, reuse=reuse, weight_decay=weight_decay)
                return features, logits

    return factory_fn
176 image = image[:, :, ::-1]
181 """Parse command line arguments. 183 parser = argparse.ArgumentParser(description=
"Freeze old model")
186 default=
"resources/networks/mars-small128.ckpt-68577",
187 help=
"Path to checkpoint file")
190 default=
"resources/networks/mars-small128.pb")
191 return parser.parse_args()
def main():
    """Freeze a checkpointed model into a serialized GraphDef file.

    Builds the inference graph on uint8 image batches (N, 128, 64, 3),
    restores variables from `--checkpoint_in`, folds them into
    constants, and writes the result to `--graphdef_out`.
    NOTE(review): `args = parse_args()`, `back_prop=False` and the
    `factory_fn = _network_factory()` line are reconstructed from the
    mangled source (the factory call is implied by its later use).
    """
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        # Per-image preprocessing; map_fn keeps the batch dim dynamic.
        image_var = tf.map_fn(
            lambda x: _preprocess(x), tf.cast(input_var, tf.float32),
            back_prop=False)

        factory_fn = _network_factory()
        features, _ = factory_fn(image_var, reuse=None)
        # Give the output a stable, well-known node name.
        features = tf.identity(features, name="features")

        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(session, args.checkpoint_in)

        # Fold restored variables into constants; keep only the
        # subgraph feeding the "features" node.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(),
            [features.name.split(":")[0]])
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(output_graph_def.SerializeToString())
if __name__ == "__main__":
    # Script entry point: only run when executed directly.
    main()
def residual_block(incoming, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(1e3), bias_initializer=tf.zeros_initializer(), regularizer=None, increase_dim=False, is_first=False, summarize_activations=True)
def create_inner_block(incoming, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(1e-3), bias_initializer=tf.zeros_initializer(), regularizer=None, increase_dim=False, summarize_activations=True)
def _batch_norm_fn(x, scope=None)
def _network_factory(weight_decay=1e-8)
def create_link(incoming, network_builder, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(stddev=1e-3), regularizer=None, is_first=False, summarize_activations=True)
def _create_network(incoming, reuse=None, weight_decay=1e-8)