TensorFlow 2.0: a function with the @tf.function decorator does not accept NumPy functions
I’m writing a function for a model in TensorFlow 2.0. It takes `image_batch` (a batch of images as a NumPy RGB array) and performs some data augmentation steps I need. The line causing my problem is:
```python
@tf.function
def augment_data(image_batch, labels):
    import numpy as np
    from tensorflow.image import flip_left_right
    image_batch = np.append(image_batch, flip_left_right(image_batch), axis=0)
    [ ... ]
```
NumPy's `.append()` function no longer works once I add the `@tf.function` decorator. It raises:
ValueError: zero-dimensional arrays cannot be concatenated
When I call `np.append()` outside the function, or remove the `@tf.function` decorator, the code runs without problems.
Is this expected? Do I have to remove the decorator to make it work, or is this a bug because TensorFlow 2.0 is still in beta? Either way, how do I fix it?
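One way to see what changes under the decorator is to record what the Python body observes while `tf.function` traces it. The probe below is a minimal sketch, assuming TensorFlow 2.x; `probe` and `seen` are illustrative names, not part of the original code. It relies on the documented behavior that `tf.executing_eagerly()` returns `False` inside a traced function.

```python
import numpy as np
import tensorflow as tf

# Filled in by the Python body of probe() while tf.function traces it.
seen = {}

@tf.function
def probe(image_batch):
    # These Python statements run during tracing, when inputs are symbolic graph tensors.
    seen["eager_inside"] = tf.executing_eagerly()
    seen["is_tensor"] = isinstance(image_batch, tf.Tensor)
    return image_batch

batch = np.zeros((2, 4, 4, 3), dtype=np.float32)
probe(batch)
print(tf.executing_eagerly())  # True: eager execution outside tf.function
print(seen["eager_inside"])    # False: graph mode while tracing
print(seen["is_tensor"])       # True: the NumPy batch arrived as a tf.Tensor
```

Because the NumPy array is converted to a tensor with no concrete values during tracing, a NumPy routine such as `np.append` has no data to work with.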
Solution
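This is expected behavior. `tf.function` traces the Python function into a graph, so inside it `image_batch` is a symbolic `tf.Tensor` without concrete values, and NumPy cannot operate on it. You can keep the decorator by wrapping the NumPy operations in `tf.py_function`, which runs ordinary Python code eagerly from inside the graph: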
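```python
import numpy as np
import tensorflow as tf

def append(image_batch, flipped):
    # Runs eagerly inside tf.py_function, so NumPy receives concrete values.
    return np.append(image_batch, flipped, axis=0)

@tf.function
def augment_data(image_batch):
    # A single dtype for Tout (not a list) returns one tensor; it must match NumPy's result dtype.
    image = tf.py_function(append, inp=[image_batch, tf.image.flip_left_right(image_batch)], Tout=image_batch.dtype)
    return image
```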
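If `np.append` is the only NumPy call, `tf.py_function` may not be needed at all: `tf.concat` along axis 0 does the same job with a TensorFlow op, so the whole augmentation stays in the graph. The sketch below swaps in `tf.concat` for `np.append` and keeps the `labels` argument from the question; duplicating the labels assumes a horizontal flip does not change an image's class, which is an assumption about this dataset rather than something stated above.

```python
import numpy as np
import tensorflow as tf

@tf.function
def augment_data(image_batch, labels):
    # Mirror every image horizontally; the dtype is preserved.
    flipped = tf.image.flip_left_right(image_batch)
    # tf.concat is the graph-friendly counterpart of np.append(..., axis=0).
    images = tf.concat([image_batch, flipped], axis=0)
    # Assumption: a horizontal flip keeps the label, so the labels are duplicated.
    labels = tf.concat([labels, labels], axis=0)
    return images, labels

batch = np.random.rand(2, 4, 4, 3).astype(np.float32)
labels = np.array([0, 1], dtype=np.int32)
images, out_labels = augment_data(batch, labels)
print(images.shape)        # (4, 4, 4, 3)
print(out_labels.numpy())  # [0 1 0 1]
```

Staying with TensorFlow ops also keeps static shape information, which a `tf.py_function` output lacks, and avoids the documented `tf.py_function` limitation that its Python body is not serialized with the graph, so a model depending on it cannot be saved and restored in a different environment.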