Using TensorFlow’s top_k and scatter_nd
I’m trying to write an operation in TensorFlow that propagates only the top k values of each feature map.
Example:
With k=1, the input has shape [batch_size, x, y, channels]; assume it is [1, 2, 2, 3].
The output should have the same shape, and with k=1 each x, y position will have only one nonzero value across its channels.
Example in numpy:
input = [[[[6.4 1.4 1.3] [2.1 6.5 4.8]] [[2.3 9.2 2.8] [7.9 5.1 0.6]]]]
The output should be:
[[[[6.4 0. 0.] [0. 6.5 0.]] [[0. 9.2 0.] [7.9 0. 0.]]]]
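To pin down the intended behaviour, here is a minimal NumPy sketch of the operation described above. The function name keep_top_k is made up for illustration, and it assumes "top k" is taken along the last (channels) axis, as in the example:

```python
import numpy as np

def keep_top_k(x, k):
    """Zero out everything except the k largest values along the last axis."""
    # Indices of the k largest entries along the last axis (ascending sort, take the tail).
    top_idx = np.argsort(x, axis=-1)[..., -k:]
    # Boolean mask with the same shape as x, True at the kept positions.
    mask = np.zeros(x.shape, dtype=bool)
    np.put_along_axis(mask, top_idx, True, axis=-1)
    # Keep the original value where the mask is set, zero elsewhere.
    return np.where(mask, x, 0.0)

x = np.array([[[[6.4, 1.4, 1.3], [2.1, 6.5, 4.8]],
               [[2.3, 9.2, 2.8], [7.9, 5.1, 0.6]]]])
out = keep_top_k(x, k=1)
print(out)
```

For the input above with k=1 this reproduces the expected output: one nonzero value per x, y position.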
In order to do this in TensorFlow, I want to use nn.top_k and then scatter_nd.
The problem is that top_k returns the indices of the requested elements in a very different form than scatter_nd needs.
top_k returns them with shape (1, 2, 2, 1).
scatter_nd needs a list of the full coordinates of each value, as shown below:
[[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1], [0, 1, 1, 0]]
Does anyone know a way to convert between them, or maybe a different approach to this operation entirely?
Solution
tf.nn.top_k() returns the indices of the top k values along the last dimension only, so you have to add all the other dimensions back. The easiest way is with tf.where().
Code (tested):
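import tensorflow as tf  # uses the TF 1.x graph-mode API (tf.Session)

inp = tf.constant([[[[6.4, 1.4, 1.3], [2.1, 6.5, 4.8]],
                    [[2.3, 9.2, 2.8], [7.9, 5.1, 10.6]]]])
# Indices of the 2 largest values along the last axis; shape (1, 2, 2, 2)
t, idx = tf.nn.top_k(inp, k=2)
# One-hot encode each index over the 3 channels; shape (1, 2, 2, 2, 3)
idx_one_hot = tf.one_hot(idx, depth=3)
# Sum over the k axis to get a 0/1 mask with the input's shape
idx_red = tf.reduce_sum(idx_one_hot, axis=-2)
# Full coordinates of every nonzero mask entry
idx2 = tf.where(tf.not_equal(idx_red, 0))
with tf.Session() as sess:
    print(sess.run(idx2))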
Output (note that I’ve changed the last number in your example so that index 2 shows up as well; having only 0 and 1 looked a bit misleading, as if it were a bool tensor):
[[0 0 0 0]
[0 0 0 1]
[0 0 1 1]
[0 0 1 2]
[0 1 0 1]
[0 1 0 2]
[0 1 1 0]
[0 1 1 2]]
Note that this loses the order in which top_k reports the indices along the last dimension; they come out in increasing order of the index itself instead.
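To go from those coordinates to the masked tensor the question asks for, one possible continuation is to read the values with tf.gather_nd and write them back with tf.scatter_nd. This is only a sketch: it assumes TensorFlow 2.x with eager execution (unlike the session-based code above), and the function name keep_top_k is made up for illustration.

```python
import tensorflow as tf

def keep_top_k(inp, k):
    """Keep the k largest values along the last axis and zero the rest."""
    # idx has the shape of inp with the last axis replaced by k.
    _, idx = tf.math.top_k(inp, k=k)
    # One-hot over the channels, summed over the k axis: a 0/1 mask shaped like inp.
    mask = tf.reduce_sum(tf.one_hot(idx, depth=inp.shape[-1]), axis=-2)
    # Full coordinates of the kept entries, shape [num_kept, rank], dtype int64.
    coords = tf.where(tf.not_equal(mask, 0))
    # Values of inp at those coordinates.
    values = tf.gather_nd(inp, coords)
    # The output shape is requested as int64 to match the dtype of coords.
    out_shape = tf.shape(inp, out_type=tf.int64)
    return tf.scatter_nd(coords, values, out_shape)

# The original input from the question, with k=1.
inp = tf.constant([[[[6.4, 1.4, 1.3], [2.1, 6.5, 4.8]],
                    [[2.3, 9.2, 2.8], [7.9, 5.1, 0.6]]]])
result = keep_top_k(inp, k=1)
print(result.numpy())
```

Since top_k returns distinct indices, the mask already contains only 0s and 1s in the input's shape, so inp * mask should give the same result without scatter_nd at all, if a different approach is acceptable.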