Python – TensorFlow: Stacking tensors in a while loop


I’m trying to implement a loop that iterates over the rows of a tensor, retrieves the indexes stored in each row, uses them to gather vectors from another tensor, and finally combines those vectors into a new tensor.
The problem is that each row may contain a different number of indexes, with -1 used as padding (e.g. [[-1, -1, 1, 4, -1], [3, -1, -1, -1, -1]]: the first row holds the indexes [1, 4], the second holds [3]).
The trouble starts when I use tf.while_loop or tf.scan. With the first, I don’t understand how all the collected tensors can be stacked on top of each other. The second requires all outputs to have the same shape (there seems to be no way to tell it that the common shape of all outputs is [None, 10]).
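To make the shape problem concrete, here is a minimal eager sketch with made-up data (dense and embeddings stand in for my real tensors):

```python
import tensorflow as tf

# Toy data matching the example above: -1 marks padding.
dense = tf.constant([[-1, -1, 1, 4, -1],
                     [3, -1, -1, -1, -1]])
embeddings = tf.random.normal([5, 10])  # one 10-d vector per index

# Each row yields a different number of vectors -- this is the
# ragged output shape that tf.scan cannot express.
for i in range(2):
    row = tf.gather(dense, [i])        # shape [1, 5]
    idx = tf.where(row > 0)[:, 1]      # column positions of the positive entries
    vecs = tf.gather(embeddings, idx)  # shape [2, 10] for row 0, [1, 10] for row 1
    print(vecs.shape)
```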

Has anyone tried something like that?

I attached the code for while_loop:

i = tf.constant(0)
def body(i, merging):
    i += 1
    print('i', i)
    i_row = tf.gather(dense, [i])
    i_indices = tf.where(i_row > 0)[:, 1]
    i_vecs = tf.gather(embeddings_ph, i_indices)
    return i, i_vecs

tf.while_loop(lambda i, merging: tf.less(i, 2), body,
              loop_vars=[i, merging],
              shape_invariants=[i.get_shape(),
                                tf.TensorShape((None, 3))],
              name='vecs_gathering')

What’s missing here is stacking all the while_loop outputs (the i_vecs of each iteration) into a new tensor.

Solution

Okay, I took inspiration from the RNN implementation. I modified my code as follows and now it works perfectly:

def body(i, outputs):
    i_row = tf.gather(dense, [i])
    i_indices = tf.where(i_row > 0)[:, 1]
    i_vecs = tf.gather(embeddings_ph, i_indices)
    outputs = outputs.write(i, i_vecs)
    i += 1
    return i, outputs

outputs = tf.TensorArray(dtype=tf.float32, infer_shape=False, size=1,
                         dynamic_size=True)
_, outputs = tf.while_loop(lambda i, *_: tf.less(i, 3), body, [0, outputs])

outputs = outputs.concat()

I also want to highlight that you have to reassign the result of the write (outputs = outputs.write(...)); otherwise TensorFlow will complain that the TensorArray you declared is never used.
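For completeness, here is a self-contained version of the fix that you can run end to end (the dense and embeddings data below are made up and stand in for my real tensor and embeddings_ph placeholder):

```python
import tensorflow as tf

# Made-up inputs: -1 marks padding, positive entries are used as lookups.
dense = tf.constant([[-1, -1, 1, 4, -1],
                     [3, -1, -1, -1, -1]])
embeddings = tf.random.normal([5, 10])

def body(i, outputs):
    i_row = tf.gather(dense, [i])
    i_indices = tf.where(i_row > 0)[:, 1]
    i_vecs = tf.gather(embeddings, i_indices)
    # Reassign: TensorArray.write returns a new TensorArray object.
    outputs = outputs.write(i, i_vecs)
    return i + 1, outputs

# infer_shape=False lets each element have a different leading dimension.
outputs = tf.TensorArray(dtype=tf.float32, size=1,
                         dynamic_size=True, infer_shape=False)
_, outputs = tf.while_loop(lambda i, *_: tf.less(i, 2), body, [0, outputs])

result = outputs.concat()  # two vectors from row 0 plus one from row 1
print(result.shape)
```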
