Java – Permanent update of variables in tensorflow-java (during inference)

I have trained a model using TensorFlow in Python and I want to run inference with the TensorFlow Java API. I have loaded the trained model/graph into Java. After this, I want to permanently update a variable in the graph. I know that tf.Variable.load(value, session) can be used in Python to update the value of a variable. Is there a similar method in Java?

So far, I have tried the following methods.

// g and s are the loaded graph and session, respectively
s.runner().feed(variableName,updatedTensorValue)

But the line above uses updatedTensorValue for variableName only during the fetch calls executed in that same statement; the variable itself is not updated.

g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();

Instead of updating the value, the line above tries to add the same variable to the graph, so it throws an exception.

An alternative to permanently updating the variable in the graph would be to call feed(variableName, updatedTensorValue) on every fetch call. I will be running the inference code on multiple instances, so I am wondering whether this extra feed call would add overhead.

Thanks

Solution

The way to do most things in TensorFlow is to execute an operation. You are on the right track in trying to execute the Assign operation, but the invocation is incorrect: the value to assign is not an "attribute" of the Assign operation but an input tensor. (See the raw definition of the operation, though admittedly that definition may not be easy to follow unless you are familiar with TensorFlow internals.)

However, you don't need to add operations to the graph in Java to achieve this. Instead, you can do exactly what tf.Variable.load does in Python: execute the tf.Variable.initializer operation, feeding in the new value as its input.

For example, consider the following graph built in Python:

import tensorflow as tf

var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()

# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables:         ', init.name)
print('myvar.initializer:          ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)

Now, let's replicate the behavior of Python's var.load() in Java and assign the value 3.0 to the variable:

try (Tensor<Float> newValue = Tensors.create(3.0f)) {
  s.runner()
    .feed("myvar/initial_value", newValue) // myvar.initializer.inputs[1].name
    .addTarget("myvar/Assign")             // myvar.initializer.name
    .run();
}
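To check that the assignment really persists across run() calls, the snippet above can be put into a small end-to-end program. The following is only a sketch, assuming the TensorFlow 1.x Java API (org.tensorflow.Graph, Session, Tensor, Tensors) and the graph written to /tmp/graph.pb by the Python script above. The class name UpdateVariable and the op name "init" are assumptions; use whatever names the Python script actually prints.

```java
import java.nio.file.Files;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

// Hypothetical class name; the op names below are assumptions taken from
// the output of the Python script and may differ in your graph.
final class UpdateVariable {
  static final String GRAPH_PATH = "/tmp/graph.pb";
  static final String INIT_ALL_OP = "init";                 // init.name (assumed)
  static final String VARIABLE = "myvar";
  static final String ASSIGN_OP = "myvar/Assign";           // myvar.initializer.name
  static final String ASSIGN_INPUT = "myvar/initial_value"; // myvar.initializer.inputs[1]

  public static void main(String[] args) throws Exception {
    byte[] graphDef = Files.readAllBytes(Paths.get(GRAPH_PATH));
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g)) {
        // Initialize all variables with the values stored in the graph.
        s.runner().addTarget(INIT_ALL_OP).run();
        printVariable(s, "after init");

        // Re-run the variable's own initializer, feeding a new input value.
        try (Tensor<Float> newValue = Tensors.create(3.0f)) {
          s.runner().feed(ASSIGN_INPUT, newValue).addTarget(ASSIGN_OP).run();
        }

        // A later run() with no feed reads the value held by the session.
        printVariable(s, "after assign");
      }
    }
  }

  // Fetches the scalar float variable and prints its current value.
  private static void printVariable(Session s, String label) {
    try (Tensor<?> value = s.runner().fetch(VARIABLE).run().get(0)) {
      System.out.println(label + ": " + value.floatValue());
    }
  }
}
```

If the names match your graph, the second fetch should report the assigned value rather than the initial one, without any feed on that call. The new value lives in the session's variable state, so it has to be assigned once per Session and is not written back to graph.pb.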

Hope this helps.
