00001
00002 import argparse
00003 import tensorflow as tf
00004 import tensorflow.contrib.slim as slim
00005
00006
def _batch_norm_fn(x, scope=None):
    """Apply slim batch normalization to `x`.

    When `scope` is omitted, a "<current variable scope>/bn" scope name is
    derived so that variables land next to the calling layer's variables.
    """
    bn_scope = scope if scope is not None else "%s/bn" % tf.get_variable_scope().name
    return slim.batch_norm(x, scope=bn_scope)
00011
00012
def create_link(
        incoming, network_builder, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        regularizer=None, is_first=False, summarize_activations=True):
    """Wrap `network_builder` in a pre-activation residual connection.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor to the residual unit.
    network_builder : Callable[[tf.Tensor, str], tf.Tensor]
        Builds the residual branch; called as `network_builder(x, scope)`.
    scope : str
        Variable/summary scope prefix for this unit.
    nonlinearity : Callable
        Activation applied after the pre-activation batch norm.
    weights_initializer : tf.Initializer
        Initializer for the 1x1 projection used on dimension increase.
    regularizer : Optional[tf.Regularizer]
        Weight regularizer for the projection.
    is_first : bool
        If True, skip the pre-activation batch norm + nonlinearity (the
        incoming tensor is assumed to be already normalized/activated).
    summarize_activations : bool
        If True, add a histogram summary of the pre-block activations.

    Returns
    -------
    tf.Tensor
        Sum of the (possibly projected) shortcut and the residual branch.
    """
    if is_first:
        network = incoming
    else:
        # Pre-activation ordering: BN -> nonlinearity -> residual branch.
        network = _batch_norm_fn(incoming, scope=scope + "/bn")
        network = nonlinearity(network)
        if summarize_activations:
            tf.summary.histogram(scope + "/activations", network)

    pre_block_network = network
    post_block_network = network_builder(pre_block_network, scope)

    incoming_dim = pre_block_network.get_shape().as_list()[-1]
    outgoing_dim = post_block_network.get_shape().as_list()[-1]
    if incoming_dim != outgoing_dim:
        # Only an exact doubling of the channel dimension is supported; the
        # shortcut is then down-sampled (stride 2) and projected with a 1x1
        # convolution to match.
        # BUG FIX: the message previously formatted `2 * incoming` (a Tensor,
        # which %d cannot format) instead of `2 * incoming_dim`.
        assert outgoing_dim == 2 * incoming_dim, \
            "%d != %d" % (outgoing_dim, 2 * incoming_dim)
        projection = slim.conv2d(
            incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None,
            scope=scope+"/projection", weights_initializer=weights_initializer,
            biases_initializer=None, weights_regularizer=regularizer)
        network = projection + post_block_network
    else:
        network = incoming + post_block_network
    return network
00041
00042
def create_inner_block(
        incoming, scope, nonlinearity=tf.nn.elu,
        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
        bias_initializer=tf.zeros_initializer(), regularizer=None,
        increase_dim=False, summarize_activations=True):
    """Build the two-convolution residual branch of a residual unit.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor (NHWC).
    scope : str
        Scope prefix; the two convolutions use `scope + "/1"` and `"/2"`.
    nonlinearity : Callable
        Activation for the first convolution (second is left linear; the
        enclosing link applies pre-activation for the next unit).
    weights_initializer : tf.Initializer
        Initializer for convolution weights.
        BUG FIX: previously `tf.truncated_normal_initializer(1e-3)` passed
        1e-3 positionally, i.e. as `mean` (leaving stddev at 1.0); the
        intent — consistent with every other initializer in this file —
        is `stddev=1e-3`.
    bias_initializer : tf.Initializer
        Initializer for convolution biases.
    regularizer : Optional[tf.Regularizer]
        Weight regularizer.
    increase_dim : bool
        If True, double the channel count and down-sample with stride 2.
    summarize_activations : bool
        If True, add a histogram summary after the first convolution.

    Returns
    -------
    tf.Tensor
        Output of the residual branch.
    """
    n = incoming.get_shape().as_list()[-1]
    stride = 1
    if increase_dim:
        n *= 2
        stride = 2

    incoming = slim.conv2d(
        incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME",
        normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/1")
    if summarize_activations:
        tf.summary.histogram(incoming.name + "/activations", incoming)

    incoming = slim.dropout(incoming, keep_prob=0.6)

    incoming = slim.conv2d(
        incoming, n, [3, 3], 1, activation_fn=None, padding="SAME",
        normalizer_fn=None, weights_initializer=weights_initializer,
        biases_initializer=bias_initializer, weights_regularizer=regularizer,
        scope=scope + "/2")
    return incoming
00070
00071
def residual_block(incoming, scope, nonlinearity=tf.nn.elu,
                   weights_initializer=tf.truncated_normal_initializer(
                       stddev=1e-3),
                   bias_initializer=tf.zeros_initializer(), regularizer=None,
                   increase_dim=False, is_first=False,
                   summarize_activations=True):
    """Build one pre-activation residual unit (link + inner block).

    BUG FIX: the default initializer was
    `tf.truncated_normal_initializer(1e3)` — a stddev/mean of 1000, clearly
    a typo for 1e-3 and inconsistent with every other initializer in this
    file. Corrected to `stddev=1e-3`.

    Parameters
    ----------
    incoming : tf.Tensor
        Input tensor (NHWC).
    scope : str
        Variable scope prefix for this unit.
    nonlinearity, weights_initializer, bias_initializer, regularizer
        Forwarded to `create_link` / `create_inner_block`.
    increase_dim : bool
        If True, the unit doubles channels and down-samples by 2.
    is_first : bool
        If True, skip the pre-activation of the shortcut (first unit after
        a stem that already ends in BN + nonlinearity).
    summarize_activations : bool
        If True, record activation histograms.

    Returns
    -------
    tf.Tensor
        Output of the residual unit.
    """

    def network_builder(x, s):
        # Residual branch: two 3x3 convolutions (see create_inner_block).
        return create_inner_block(
            x, s, nonlinearity, weights_initializer, bias_initializer,
            regularizer, increase_dim, summarize_activations)

    return create_link(
        incoming, network_builder, scope, nonlinearity, weights_initializer,
        regularizer, is_first, summarize_activations)
00086
00087
def _create_network(incoming, reuse=None, weight_decay=1e-8):
    """Build the appearance-descriptor CNN.

    Parameters
    ----------
    incoming : tf.Tensor
        Batch of input images (NHWC float tensor).
    reuse : Optional[bool]
        Variable-reuse flag forwarded to the final "ball" batch-norm layer.
    weight_decay : float
        L2 regularization strength for convolution and FC weights.

    Returns
    -------
    (tf.Tensor, None)
        The L2-normalized feature tensor and `None` in place of logits.
    """
    nonlinearity = tf.nn.elu
    conv_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    conv_bias_init = tf.zeros_initializer()
    conv_regularizer = slim.l2_regularizer(weight_decay)
    fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3)
    fc_bias_init = tf.zeros_initializer()
    fc_regularizer = slim.l2_regularizer(weight_decay)

    def batch_norm_fn(x):
        # Scope BN variables next to the layer currently being built.
        return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn")

    net = incoming

    # Stem: two 3x3/32 convolutions with batch norm, then max pooling.
    for conv_scope in ("conv1_1", "conv1_2"):
        net = slim.conv2d(
            net, 32, [3, 3], stride=1, activation_fn=nonlinearity,
            padding="SAME", normalizer_fn=batch_norm_fn, scope=conv_scope,
            weights_initializer=conv_weight_init,
            biases_initializer=conv_bias_init,
            weights_regularizer=conv_regularizer)

    net = slim.max_pool2d(net, [3, 3], [2, 2], scope="pool1")

    # Residual stack: (scope, increase_dim, is_first) for each unit.
    residual_specs = [
        ("conv2_1", False, True),
        ("conv2_3", False, False),
        ("conv3_1", True, False),
        ("conv3_3", False, False),
        ("conv4_1", True, False),
        ("conv4_3", False, False),
    ]
    for block_scope, increase_dim, is_first in residual_specs:
        net = residual_block(
            net, block_scope, nonlinearity, conv_weight_init, conv_bias_init,
            conv_regularizer, increase_dim=increase_dim, is_first=is_first)

    feature_dim = net.get_shape().as_list()[-1]
    net = slim.flatten(net)

    net = slim.dropout(net, keep_prob=0.6)
    net = slim.fully_connected(
        net, feature_dim, activation_fn=nonlinearity,
        normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer,
        scope="fc1", weights_initializer=fc_weight_init,
        biases_initializer=fc_bias_init)

    features = net

    # Project features onto the unit hypersphere (cosine-softmax "ball"
    # layer): batch norm followed by numerically-safe L2 normalization.
    features = slim.batch_norm(features, scope="ball", reuse=reuse)
    feature_norm = tf.sqrt(
        tf.constant(1e-8, tf.float32) +
        tf.reduce_sum(tf.square(features), [1], keepdims=True))
    features = features / feature_norm
    return features, None
00158
00159
def _network_factory(weight_decay=1e-8):
    """Return a factory `fn(image, reuse) -> (features, logits)`.

    The returned function builds the descriptor network in inference mode
    (batch norm and dropout frozen) with variable reuse controlled by the
    caller.
    """

    def factory_fn(image, reuse):
        inference_scope = slim.arg_scope(
            [slim.batch_norm, slim.dropout], is_training=False)
        reuse_scope = slim.arg_scope(
            [slim.conv2d, slim.fully_connected, slim.batch_norm,
             slim.layer_norm], reuse=reuse)
        with inference_scope, reuse_scope:
            return _create_network(
                image, reuse=reuse, weight_decay=weight_decay)

    return factory_fn
00173
00174
00175 def _preprocess(image):
00176 image = image[:, :, ::-1]
00177 return image
00178
00179
def parse_args():
    """Parse command line arguments.

    Returns
    -------
    argparse.Namespace
        Holds `checkpoint_in` (checkpoint path) and `graphdef_out`
        (destination for the frozen graph definition).
    """
    arg_parser = argparse.ArgumentParser(description="Freeze old model")
    arg_parser.add_argument(
        "--checkpoint_in",
        help="Path to checkpoint file",
        default="resources/networks/mars-small128.ckpt-68577")
    arg_parser.add_argument(
        "--graphdef_out",
        default="resources/networks/mars-small128.pb")
    return arg_parser.parse_args()
00192
00193
def main():
    """Restore checkpoint weights and write a frozen GraphDef to disk."""
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as sess:
        # uint8 NHWC input; preprocessing casts to float and flips channels.
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        image_var = tf.map_fn(
            lambda x: _preprocess(x), tf.cast(input_var, tf.float32),
            back_prop=False)

        network_fn = _network_factory()
        features, _ = network_fn(image_var, reuse=None)
        # Stable node name for consumers of the frozen graph.
        features = tf.identity(features, name="features")

        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(sess, args.checkpoint_in)

        output_node = features.name.split(":")[0]
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, tf.get_default_graph().as_graph_def(), [output_node])
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(frozen_graph_def.SerializeToString())


if __name__ == "__main__":
    main()