Illustration

Blog

A detailed example of how to use data generators with Keras

Fork   Star
python keras 2 fit_generator large dataset multiprocessing

By Afshine Amidi and Shervine Amidi

Motivation

Have you ever had to load a dataset that was so memory consuming that you wished a magic trick could seamlessly take care of that? Large datasets are increasingly becoming part of our lives, as we are able to harness an ever-growing quantity of data.

We have to keep in mind that in some cases, even the most state-of-the-art configuration won't have enough memory space to process the data the way we used to do it. That is the reason why we need to find other ways to do that task efficiently. In this blog post, we are going to show you how to generate your dataset on multiple cores in real time and feed it right away to your deep learning model.

The framework used in this tutorial is the one provided by Python's high-level package Keras, which can be used on top of a GPU installation of either TensorFlow or Theano.

Tutorial

Previous situation

Before reading this article, your Keras script probably looked like this:

import numpy as np
from keras.models import Sequential

# Load entire dataset
X, y = np.load('some_training_set_with_labels.npy')

# Design model
model = Sequential()
[...] # Your architecture
model.compile()

# Train model on your dataset
model.fit(x=X, y=y)

This article is all about changing the line loading the entire dataset at once. Indeed, this task may cause issues as all of the training samples may not be able to fit in memory at the same time.

In order to do so, let's dive into a step by step recipe that builds a data generator suited for this situation. By the way, the following code is a good skeleton to use for your own project; you can copy/paste the following pieces of code and fill the blanks accordingly.

Notations

Before getting started, let's go through a few organizational tips that are particularly useful when dealing with large datasets.

Let ID be the Python string that identifies a given sample of the dataset. A good way to keep track of samples and their labels is to adopt the following framework:

  1. Create a dictionary called partition where you gather:

    • in partition['train'] a list of training IDs
    • in partition['validation'] a list of validation IDs
  2. Create a dictionary called labels where for each ID of the dataset, the associated label is given by labels[ID]

For example, let's say that our training set contains id-1, id-2 and id-3 with respective labels 0, 1 and 2, with a validation set containing id-4 with label 1. In that case, the Python variables partition and labels look like

>>> partition
{'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}

and

>>> labels
{'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}

Also, for the sake of modularity, we will write Keras code and customized classes in separate files, so that your folder looks like

folder/
├── my_classes.py
├── keras_script.py
└── data/

where data/ is assumed to be the folder containing your dataset.

Finally, it is good to note that the code in this tutorial is aimed at being general and minimal, so that you can easily adapt it for your own dataset.

Data generator

Now, let's go through the details of how to set the Python class DataGenerator, which will be used for real-time data feeding to your Keras model.

First, let's write the initialization function of the class. We make the latter inherit the properties of keras.utils.Sequence so that we can leverage nice functionalities such as multiprocessing.

def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1,
             n_classes=10, shuffle=True):
    'Initialization'
    self.dim = dim
    self.batch_size = batch_size
    self.labels = labels
    self.list_IDs = list_IDs
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.shuffle = shuffle
    self.on_epoch_end()

We put as arguments relevant information about the data, such as dimension sizes (e.g. a volume of length 32 will have dim=(32,32,32)), number of channels, number of classes, batch size, or decide whether we want to shuffle our data at generation. We also store important information such as labels and the list of IDs that we wish to generate at each pass.

Here, the method on_epoch_end is triggered once at the very beginning as well as at the end of each epoch. If the shuffle parameter is set to True, we will get a new order of exploration at each pass (or just keep a linear exploration scheme otherwise).

def on_epoch_end(self):
  'Updates indexes after each epoch'
  self.indexes = np.arange(len(self.list_IDs))
  if self.shuffle == True:
      np.random.shuffle(self.indexes)

Shuffling the order in which examples are fed to the classifier is helpful so that batches between epochs do not look alike. Doing so will eventually make our model more robust.

Another method that is core to the generation process is the one that achieves the most crucial job: producing batches of data. The private method in charge of this task is called __data_generation and takes as argument the list of IDs of the target batch.

def __data_generation(self, list_IDs_temp):
  'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
  # Initialization
  X = np.empty((self.batch_size, *self.dim, self.n_channels))
  y = np.empty((self.batch_size), dtype=int)

  # Generate data
  for i, ID in enumerate(list_IDs_temp):
      # Store sample
      X[i,] = np.load('data/' + ID + '.npy')

      # Store class
      y[i] = self.labels[ID]

  return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

During data generation, this code reads the NumPy array of each example from its corresponding file ID.npy. Since our code is multicore-friendly, note that you can do more complex operations instead (e.g. computations from source files) without worrying that data generation becomes a bottleneck in the training process.

Also, please note that we used Keras' keras.utils.to_categorical function to convert our numerical labels stored in y to a binary form (e.g. in a 6-class problem, the third label corresponds to [0 0 1 0 0 0]) suited for classification.

Now comes the part where we build up all these components together. Each call requests a batch index between 0 and the total number of batches, where the latter is specified in the __len__ method.

def __len__(self):
  'Denotes the number of batches per epoch'
  return int(np.floor(len(self.list_IDs) / self.batch_size))

A common practice is to set this value to $$\biggl\lfloor\frac{\#\textrm{ samples}}{\textrm{batch size}}\biggr\rfloor$$ so that the model sees the training samples at most once per epoch.

Now, when the batch corresponding to a given index is called, the generator executes the __getitem__ method to generate it.

def __getitem__(self, index):
  'Generate one batch of data'
  # Generate indexes of the batch
  indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

  # Find list of IDs
  list_IDs_temp = [self.list_IDs[k] for k in indexes]

  # Generate data
  X, y = self.__data_generation(list_IDs_temp)

  return X, y

The complete code corresponding to the steps that we described in this section is shown below.

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1,
                 n_classes=10, shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i,] = np.load('data/' + ID + '.npy')

            # Store class
            y[i] = self.labels[ID]

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

Keras script

Now, we have to modify our Keras script accordingly so that it accepts the generator that we just created.

import numpy as np

from keras.models import Sequential
from my_classes import DataGenerator

# Parameters
params = {'dim': (32,32,32),
          'batch_size': 64,
          'n_classes': 6,
          'n_channels': 1,
          'shuffle': True}

# Datasets
partition = # IDs
labels = # Labels

# Generators
training_generator = DataGenerator(partition['train'], labels, **params)
validation_generator = DataGenerator(partition['validation'], labels, **params)

# Design model
model = Sequential()
[...] # Architecture
model.compile()

# Train model on dataset
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=6)

As you can see, we called from model the fit_generator method instead of fit, where we just had to give our training generator as one of the arguments. Keras takes care of the rest!

Note that our implementation enables the use of the multiprocessing argument of fit_generator, where the number of threads specified in workers are those that generate batches in parallel. A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation).

Conclusion

This is it! You can now run your Keras script with the command

python3 keras_script.py

and you will see that during the training phase, data is generated in parallel by the CPU and then directly fed to the GPU.

You can find a complete example of this strategy on applied on a specific example on GitHub where codes of data generation as well as the Keras script are available.

You may also like...

Artificial Intelligence cheatsheets

  • • Reflex-based models
  • • States-based models
  • • Variables-based models
  • • Logic-based models

Machine Learning cheatsheets

  • • Supervised learning
  • • Unsupervised learning
  • • Deep learning
  • • Machine learning tips and tricks

Deep Learning cheatsheets

  • • Convolutional neural networks
  • • Recurrent neural networks
  • • Deep learning tips and tricks