tf.nn.softmax
It performs softmax along the last dimension by default. For example:
import tensorflow as tf
import numpy as np
N = 10
C = 2
a = np.random.randn(N, C, C*4) * 100
aa = tf.constant(a)
bb = tf.nn.softmax(aa)
with tf.Session() as ss:
    bbval = ss.run(bb)
    print(bbval)
    print(np.sum(bbval, axis=2))  # <== the sums along the last axis are all 1.0
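To double-check that the normalization really happens over the last axis, here is a minimal NumPy sketch (reusing a and bbval from the snippet above) that reproduces tf.nn.softmax by hand:
e = np.exp(a - np.max(a, axis=-1, keepdims=True))  # subtract the per-row max for numerical stability
manual = e / np.sum(e, axis=-1, keepdims=True)     # normalize over the last axis
print(np.allclose(manual, bbval))                  # <== True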
Restore pre-trained models
In TensorFlow, tf.train.Saver is a module to save and restore variables to disk. If we need to specify which variables to restore, we pass a variable dict or list to tf.train.Saver. Using a list is simple (see the sketch just below), so here we focus on how to use a dict.
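For completeness, a minimal sketch of the list form; v1 and v2 are hypothetical variables in the current graph, and with a list each variable is looked up in the checkpoint under its own name:
v1 = tf.get_variable('w1', shape=[3])  # hypothetical variable
v2 = tf.get_variable('b1', shape=[3])  # hypothetical variable
saver = tf.train.Saver([v1, v2])       # saves/restores only v1 and v2, keyed by their own op names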
name_to_vars = { ... } # key: the variable name saved in the checkpoint file, value: the TensorFlow variable we need to restore.
saver_for_restore = tf.train.Saver(name_to_vars)
Note that in name_to_vars, the key is a string: the variable name saved in the checkpoint file, not the name you defined in your graph. The value is the TensorFlow variable to restore, i.e. the variable defined in your current graph. See the following code:
# FLAGS, CENet3D and tfconfig come from the surrounding project.
def save_ckpt():
    FLAGS.mode = 'train'
    net = CENet3D(FLAGS)
    name_to_vars = {}
    for var in net.all_vars:
        if 'discriminator_im' in var.op.name:
            # The checkpoint stores this variable under 'discriminator', so
            # map the checkpoint name to the variable in the current graph.
            name = var.op.name.replace('discriminator_im', 'discriminator')
            name_to_vars[name] = var
        if 'generator_vd' in var.op.name or 'discriminator_vd' in var.op.name:
            # These names are identical in the checkpoint and the graph.
            name = var.op.name
            name_to_vars[name] = var
    saver_only_for_restore = tf.train.Saver(name_to_vars)  # restores only the mapped subset
    saver_for_save = tf.train.Saver(net.all_vars)          # saves every variable
    with tf.Session(config=tfconfig) as sess:
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        ckpt = tf.train.get_checkpoint_state(FLAGS.log_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver_only_for_restore.restore(sess, ckpt.model_checkpoint_path)
            print('model restored...')
            saver_for_save.save(sess, os.path.join('./', 'initial_weights.ckpt'))
        else:
            raise ValueError('no model found!')
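If you are unsure which names a checkpoint actually stores (i.e. the keys you should put in name_to_vars), tf.train.list_variables can read them straight from the checkpoint; a minimal sketch:
ckpt = tf.train.get_checkpoint_state(FLAGS.log_path)
for name, shape in tf.train.list_variables(ckpt.model_checkpoint_path):
    print(name, shape)  # each entry is (name stored in the checkpoint, its shape)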