TensorFlow: Alternating between datasets with different output shapes
I’m trying to convert a 3D image CNN to tf.data.Dataset. The shapes of the 3D images from the training and validation sets are different (training: (64, 64, 64), validation: (176, 176, 160)). I didn’t even know that was possible, but I’m recreating this network from a paper, and with the classic feed_dict approach the network really does work. For performance reasons (and just to learn), I’m trying to switch the network over to tf.data.Dataset.
I have two datasets and iterators as shown below:
def _data_parser(dataset, shape):
    features = {"input": tf.FixedLenFeature((), tf.string),
                "label": tf.FixedLenFeature((), tf.string)}
    parsed_features = tf.parse_single_example(dataset, features)
    image = tf.decode_raw(parsed_features["input"], tf.float32)
    image = tf.reshape(image, shape + (1,))
    label = tf.decode_raw(parsed_features["label"], tf.float32)
    label = tf.reshape(label, shape + (1,))
    return image, label
train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size) # batch_size = 16
train_iterator = train_dataset.make_initializable_iterator()
val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()
The TensorFlow documentation gives examples of switching between datasets with a reinitializable iterator or a feedable iterator, but both examples switch between datasets with the same output shape, which is not the case here.
In my case, how should I use tf.data.Dataset and tf.data.Iterator to switch between the training and validation sets?
Solution
Simply provide unspecified (None) dimensions for the axes whose sizes differ between the datasets. For example:
import numpy as np
import tensorflow as tf
training_dataset = tf.data.Dataset.from_tensors(np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
validation_dataset = tf.data.Dataset.from_tensors(np.zeros((176, 176, 160), np.float32)).repeat().batch(1)
iterator = tf.data.Iterator.from_structure(
    training_dataset.output_types,
    tf.TensorShape([None, None, None, None]))
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
sess = tf.InteractiveSession()
sess.run(training_init_op)
print(sess.run(next_element).shape)
sess.run(validation_init_op)
print(sess.run(next_element).shape)