How to get global steps when using tf.train.MonitoredTrainingSession
When we pass global_step to Saver.save, the global step is appended to the checkpoint filename as a suffix (e.g. model.ckpt-100).
# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)
We can resume the checkpoint like this and get the last global step stored in the checkpoint:
# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
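The restored step can also be read back directly, without running the optimizer. Below is a minimal, self-contained sketch of the save/restore round trip using the TF1-style API (via tf.compat.v1 so it also runs under TensorFlow 2); the assign_add op is a stand-in for a real training op, and the paths are temporary placeholders:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API; also works under TF2

tf.disable_eager_execution()

checkpoint_dir = tempfile.mkdtemp()
checkpoints_path = os.path.join(checkpoint_dir, "model.ckpt")

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    increment_step = tf.assign_add(global_step, 1)  # stand-in for a real train op
    saver = tf.train.Saver()

    # Train for a few steps and save; the global step becomes the
    # filename suffix, e.g. model.ckpt-3.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for _ in range(3):
            session.run(increment_step)
        saver.save(session, checkpoints_path, global_step=global_step)

    # Restore the latest checkpoint and read the step back directly.
    with tf.Session() as session:
        saver.restore(session, tf.train.latest_checkpoint(checkpoint_dir))
        gstep = session.run(global_step)
        print("restored global step:", gstep)  # 3
```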
If we use tf.train.MonitoredTrainingSession, what is the equivalent way to save the global step to a checkpoint and to get gstep back?
Edit 1
Following Maxim’s suggestion, I created the global_step variable before tf.train.MonitoredTrainingSession and added a CheckpointSaverHook as follows:
global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoints_abs_path,
    save_steps=5,
    checkpoint_basename=(checkpoints_prefix + ".ckpt"))
with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:
    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))
I can see that the checkpoint files it generates look like the ones Saver.save produces. However, I cannot retrieve the global step from the checkpoint. How can I solve this?
Solution
Get the current global step tensor via tf.train.get_global_step() or via tf.train.get_or_create_global_step(). The latter should be called before training starts.
For monitored sessions, add a tf.train.CheckpointSaverHook to hooks; it internally saves the model every N steps using the defined global step tensor.
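The pieces above can be put together in a minimal runnable sketch. It drops the distributed-specific arguments from the question (master, is_chief, sync_replicas_hook, config) and uses an assign_add op as a stand-in for the real optimizer; the checkpoint directory is a temporary placeholder. It uses the TF1-style API via tf.compat.v1 so it also runs under TensorFlow 2:

```python
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API; also works under TF2

tf.disable_eager_execution()

checkpoint_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    # Create the global step before building the hook or the session.
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.assign_add(global_step, 1)  # stand-in for the optimizer op

    saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir,
                                              save_steps=5)
    # save_checkpoint_secs=None disables the session's built-in saver so
    # only our hook writes checkpoints.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=[saver_hook],
                                           save_checkpoint_secs=None) as session:
        for _ in range(7):
            gstep = session.run(train_op)
    print("last global step:", gstep)  # 7

    # A fresh session pointed at the same checkpoint_dir restores
    # automatically; the step resumes from the last saved checkpoint.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           save_checkpoint_secs=None) as session:
        restored = session.run(global_step)
    print("restored global step:", restored)
```

Note that MonitoredTrainingSession restores from checkpoint_dir on construction, so after a restart the global step picks up where the hook last saved it.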