Python – TensorFlow: Alternating between datasets with different output shapes

I'm trying to use tf.data.Dataset for a 3D image CNN where the 3D images fed in from the training and validation sets have different shapes (training: (64, 64, 64), validation: (176, 176, 160)). I didn't even know this was possible, but I'm recreating the network from a paper, and with the classic feed_dict approach it does work. For performance reasons (and just to learn), I'm trying to switch the network over to tf.data.Dataset.

I have two datasets and iterators as shown below:

import tensorflow as tf  # TensorFlow 1.x API

def _data_parser(dataset, shape):
    # Each record stores the image and its label as raw float32 bytes
    features = {"input": tf.FixedLenFeature((), tf.string),
                "label": tf.FixedLenFeature((), tf.string)}
    parsed_features = tf.parse_single_example(dataset, features)

    # Decode and reshape to the given shape plus a trailing channel axis
    image = tf.decode_raw(parsed_features["input"], tf.float32)
    image = tf.reshape(image, shape + (1,))

    label = tf.decode_raw(parsed_features["label"], tf.float32)
    label = tf.reshape(label, shape + (1,))
    return image, label

batch_size = 16

train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size)
train_iterator = train_dataset.make_initializable_iterator()

val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()

The TensorFlow documentation has examples of switching between datasets with a reinitializable iterator or a feedable iterator, but both switch between iterators with the same output shape, which is not the case here.

In my case, how should I use tf.data.Dataset to switch between the training and validation sets?


Simply use unspecified (None) sizes in the iterator's output shape for every axis whose dimension differs between the datasets. For example:

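import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# Dummy all-zero datasets standing in for the real data:
# same rank, different shapes
training_dataset = tf.data.Dataset.from_tensors(
    np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
validation_dataset = tf.data.Dataset.from_tensors(
    np.zeros((176, 176, 160), np.float32)).repeat().batch(1)

# Reinitializable iterator whose output shape leaves every axis unspecified
iterator = tf.data.Iterator.from_structure(
    training_dataset.output_types,
    tf.TensorShape([None, None, None, None]))
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

sess = tf.InteractiveSession()

# Initialize with either dataset, then fetch a batch of the matching shape
sess.run(training_init_op)
print(sess.run(next_element).shape)  # (4, 64, 64, 64)
sess.run(validation_init_op)
print(sess.run(next_element).shape)  # (1, 176, 176, 160)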
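
To apply this to the shapes in the question, it may help to work out which axes actually need to be None. The sketch below is plain Python, not TensorFlow code. The helpers `most_general_shape` and `is_compatible` are hypothetical names introduced only for illustration. The batch shapes assume the batch sizes from the question (16 and 1) and the trailing channel axis of size 1 that `_data_parser` appends.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def most_general_shape(*shapes: Sequence[int]) -> Shape:
    """Keep a dimension only where all shapes agree; otherwise use None."""
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("all shapes must have the same rank")
    return tuple(dims[0] if len(set(dims)) == 1 else None
                 for dims in zip(*shapes))


def is_compatible(template: Shape, concrete: Sequence[int]) -> bool:
    """Return True if `concrete` fits `template`; None matches any size."""
    if len(template) != len(concrete):
        return False
    return all(t is None or t == c for t, c in zip(template, concrete))


# Batched shapes assumed from the question: (batch, depth, height, width, channel)
train_batch = (16, 64, 64, 64, 1)
val_batch = (1, 176, 176, 160, 1)

shared = most_general_shape(train_batch, val_batch)
print(shared)  # (None, None, None, None, 1)
print(is_compatible(shared, train_batch))  # True
print(is_compatible(shared, val_batch))    # True
```

Under these assumptions, only the channel axis can stay fixed. The resulting tuple is the kind of shape one would declare for each component (image and label) when building the shared iterator.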
