Python – How to get global steps when using tf.train.MonitoredTrainingSession

Here is a solution to the question of how to get the global step when using tf.train.MonitoredTrainingSession.

How to get global steps when using tf.train.MonitoredTrainingSession

When we pass global_step to Saver.save, it appends the step number to the checkpoint filename as a suffix (e.g. model.ckpt-1500).

# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)
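Because the step is encoded in the filename suffix, it can also be recovered without restoring the session at all, just by parsing the checkpoint path. A minimal sketch in plain Python (the path below is a made-up example, not one from the question):

```python
import os

def step_from_checkpoint_path(path):
    """Parse the global step from a checkpoint path such as 'model.ckpt-1500'."""
    basename = os.path.basename(path)      # e.g. "model.ckpt-1500"
    suffix = basename.rsplit("-", 1)[-1]   # text after the last dash
    return int(suffix)

print(step_from_checkpoint_path("/tmp/checkpoints/model.ckpt-1500"))  # 1500
```

This mirrors what tf.train internally does when it resolves the latest checkpoint for a directory.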

We can resume the checkpoint like this and get the last global step stored in the checkpoint:

# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)

If we use tf.train.MonitoredTrainingSession, what is the equivalent of saving a global step to a checkpoint and getting gstep?

Edit 1

Following Maxim’s suggestion, I created the global_step variable before tf.train.MonitoredTrainingSession and added a CheckpointSaverHook as follows:

global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
                                                    save_steps=5,
                                                    checkpoint_basename=(checkpoints_prefix + ".ckpt"))

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,                     
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:

    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))

I can see that the checkpoint files it generates are similar to those produced by Saver.save. However, it still cannot retrieve the global step from the checkpoint. How can I solve this?

Solution

The current global step can be obtained via tf.train.get_global_step(), or via tf.train.get_or_create_global_step(), which creates the global step tensor if it does not already exist. The latter should be called before training starts.

For monitored sessions, add a tf.train.CheckpointSaverHook to hooks; it internally saves the model every N steps using the defined global step tensor.
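To illustrate the save-every-N-steps mechanics without requiring a TensorFlow runtime, here is a conceptual plain-Python sketch; the class and method names below are made up for illustration and are not the real tf.train.CheckpointSaverHook API:

```python
class ToySaverHook:
    """Conceptual stand-in for a checkpoint-saving hook (not the real TF API)."""

    def __init__(self, save_steps):
        self.save_steps = save_steps
        self.saved_at = []  # global steps at which a "checkpoint" was written

    def after_run(self, global_step):
        # Save whenever the global step reaches a multiple of save_steps,
        # which is how save_steps=5 in the question produces a checkpoint
        # every 5 training steps.
        if global_step % self.save_steps == 0:
            self.saved_at.append(global_step)

# The training loop advances the global step and notifies the hook,
# mirroring how MonitoredTrainingSession drives its hooks after each run call.
hook = ToySaverHook(save_steps=5)
for step in range(1, 16):
    hook.after_run(step)
print(hook.saved_at)  # [5, 10, 15]
```

The key point is that the hook reads the same global step tensor created by tf.train.get_or_create_global_step(), which is why that tensor must exist before the session starts.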
