
Keras: Data Preprocessing and Augmentation

Data preprocessing and augmentation are crucial steps in deep learning workflows, especially when working with images, text, or sequential data. Keras provides utilities to prepare your data for model training and to enlarge your effective dataset through augmentation, which helps reduce overfitting and improve generalization.

1. Image Preprocessing and Augmentation

Keras provides the ImageDataGenerator utility for real-time augmentation and preprocessing of image data. It is deprecated and not recommended for new code. In TensorFlow 2.x, the preferred approach is to load images with tf.keras.utils.image_dataset_from_directory and transform them with tf.keras.layers.Rescaling and the random augmentation layers in tf.keras.layers (RandomFlip, RandomRotation, RandomZoom, and so on), which also give more fine-grained control.
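As a minimal sketch of the preferred loading path, the snippet below writes a few random PNG files into a temporary directory, loads them with image_dataset_from_directory, and rescales them inside a tf.data pipeline. The temporary directory, the class folder names, and the image sizes are assumptions made for illustration only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Build a tiny throwaway dataset: two class folders with two random PNGs each.
# (Hypothetical data; real projects would point at an existing image folder.)
root = tempfile.mkdtemp()
for class_name in ["class_a", "class_b"]:
    class_dir = os.path.join(root, class_name)
    os.makedirs(class_dir, exist_ok=True)
    for i in range(2):
        pixels = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
        tf.io.write_file(os.path.join(class_dir, f"img_{i}.png"), tf.io.encode_png(pixels))

# Labels are inferred from the subdirectory names, in alphabetical order.
ds = tf.keras.utils.image_dataset_from_directory(
    root,
    image_size=(64, 64),  # Every image is resized to this size
    batch_size=4,
    shuffle=False,        # Keep file order so the labels are easy to inspect
)
class_names = ds.class_names

# Rescale pixel values to [0, 1] as a step in the input pipeline.
rescale = tf.keras.layers.Rescaling(1.0 / 255)
ds = ds.map(lambda images, labels: (rescale(images), labels))

images, labels = next(iter(ds))
print(class_names, images.shape, labels.numpy())
```

Applying Rescaling in the pipeline, as here, and placing it inside the model are both common choices; keeping it in the model makes the saved model accept raw pixel values.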

Using ImageDataGenerator (Legacy but common for simple cases)

ImageDataGenerator can perform operations like rescaling, rotation, width/height shifts, shear, zoom, and flips.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import os

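# Expected directory structure (one subdirectory per class):
# data/
# ├── train/
# │   ├── class_a/
# │   │   ├── img_train_class_a_0.png
# │   │   └── img_train_class_a_1.png
# │   └── class_b/
# │       ├── img_train_class_b_0.png
# │       └── img_train_class_b_1.png
# └── validation/
#     ├── class_a/
#     │   ├── img_validation_class_a_0.png
#     │   └── img_validation_class_a_1.png
#     └── class_b/
#         ├── img_validation_class_b_0.png
#         └── img_validation_class_b_1.png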

# Create dummy image files for demonstration
# In a real scenario, you'd have actual images.
def create_dummy_images(base_dir="data"):
    for split in ["train", "validation"]:
        for class_name in ["class_a", "class_b"]:
            os.makedirs(os.path.join(base_dir, split, class_name), exist_ok=True)
            for i in range(2):
                dummy_image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
                plt.imsave(os.path.join(base_dir, split, class_name, f"img_{split}_{class_name}_{i}.png"), dummy_image)

# Call the function to create dummy images.
# Uncomment on the first run: flow_from_directory below fails if data/ is missing.
# create_dummy_images()

# Define the ImageDataGenerator with augmentation parameters
train_datagen = ImageDataGenerator(
    rescale=1./255,          # Normalize pixel values to [0, 1]
    rotation_range=40,       # Rotate randomly within [-40, 40] degrees
    width_shift_range=0.2,   # Shift horizontally by up to 20% of the width
    height_shift_range=0.2,  # Shift vertically by up to 20% of the height
    shear_range=0.2,         # Shear intensity (shear angle in degrees)
    zoom_range=0.2,          # Zoom factor sampled from [0.8, 1.2]
    horizontal_flip=True,    # Randomly flip images horizontally
    fill_mode='nearest'      # Strategy for filling in newly exposed pixels
)

test_datagen = ImageDataGenerator(rescale=1./255) # Only rescale for test data (no augmentation)

# Load images from directory using flow_from_directory
# It automatically infers labels from subdirectory names
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(64, 64), # Resize images to this size
    batch_size=32,
    class_mode='binary' # 'binary' for 2 classes, 'categorical' for >2 classes
)

validation_generator = test_datagen.flow_from_directory(
    'data/validation',
    target_size=(64, 64),
    batch_size=32,
    class_mode='binary'
)

# You can iterate through the generator to see augmented images
# for _ in range(3):
#     img_batch, label_batch = next(train_generator)
#     plt.imshow(img_batch[0]) # Show first image in batch
#     plt.title(f"Augmented Image (Label: {label_batch[0]})")
#     plt.show()
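To make two of the settings above concrete, here is a NumPy-only sketch of what rescale=1./255 and a horizontal flip do to a single image. It is an illustration of the arithmetic under simple assumptions, not the ImageDataGenerator implementation, and the 4x4 random image is made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)  # height, width, channels

# rescale=1./255: multiply every pixel by 1/255 so values land in [0, 1].
rescaled = image.astype("float32") * (1.0 / 255)

# horizontal_flip=True: when the flip is chosen, columns are mirrored left to right.
flipped = rescaled[:, ::-1, :]

print("value range:", rescaled.min(), rescaled.max())
print("first column of flipped equals last column of original:",
      np.array_equal(flipped[:, 0, :], rescaled[:, -1, :]))
```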

Using Keras Preprocessing Layers

For a more integrated and performant approach, especially with tf.data, Keras offers dedicated preprocessing layers.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# For demonstration, create a small dummy dataset in memory.
# In a real scenario, you'd use tf.keras.utils.image_dataset_from_directory
# to load your images from disk.
print("Creating dummy image data...")
(x_train, y_train), (x_test, y_test) = (
    (np.random.rand(100, 32, 32, 3), np.random.randint(0, 10, 100)),
    (np.random.rand(20, 32, 32, 3), np.random.randint(0, 10, 20))
)
# Scale the random values from [0, 1) to the usual [0, 255) pixel range
x_train = x_train.astype('float32') * 255
x_test = x_test.astype('float32') * 255

# Create a tf.data.Dataset
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32).prefetch(tf.data.AUTOTUNE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32).prefetch(tf.data.AUTOTUNE)

# Data augmentation layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.2),
    # Add other augmentation layers as needed
])

# Rescaling layer (normalize pixel values)
preprocess_rescale = layers.Rescaling(1./255)

# Build a model incorporating these layers
model_with_preprocessing = keras.Sequential([
    layers.Input(shape=(32, 32, 3)),
    data_augmentation,         # Random augmentation: active during fit(), skipped at inference
    preprocess_rescale,        # Rescaling: applied in both training and inference
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model_with_preprocessing.compile(optimizer='adam',
                                 loss='sparse_categorical_crossentropy',
                                 metrics=['accuracy'])

print("\nModel with Keras preprocessing layers:")
model_with_preprocessing.summary()

# Train the model (using the prepared tf.data.Dataset)
print("\nTraining Model with Keras Preprocessing Layers (dummy data)...")
history_layers = model_with_preprocessing.fit(
    train_ds,
    epochs=1, # Reduced for demonstration
    validation_data=test_ds,
    verbose=1
)
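The random augmentation layers behave differently in training and inference, which is why they can sit inside the model. The short check below is a sketch of that behavior on a made-up batch: with training=False the layer passes its input through unchanged, while Rescaling always applies.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical batch of two 32x32 RGB images with values in [0, 255).
batch = (np.random.rand(2, 32, 32, 3) * 255).astype("float32")

flip = layers.RandomFlip("horizontal")
rescale = layers.Rescaling(1.0 / 255)

# training=False is what evaluate() and predict() use: the random layer is a pass-through.
at_inference = np.asarray(flip(batch, training=False))

# training=True is what fit() uses: the flip may be applied, and the shape is preserved.
at_training = np.asarray(flip(batch, training=True))

# Rescaling has no random component, so it runs in both modes.
scaled = np.asarray(rescale(batch))

print("unchanged at inference:", np.allclose(at_inference, batch))
print("shape during training:", at_training.shape)
print("max after rescaling:", scaled.max())
```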

2. Text Preprocessing

Keras offers utilities for tokenizing text, converting text to sequences, and padding sequences.

The TextVectorization layer standardizes text (by default, it lowercases and strips punctuation), splits it into tokens on whitespace, and maps each token to an integer index.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Sample text data
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "A lazy cat sleeps on the mat.",
    "Quickly, the fox ran away.",
    "The dog barks loudly."
]

# Create a TextVectorization layer
vectorize_layer = layers.TextVectorization(
    max_tokens=None,          # Cap on vocabulary size; None means no limit
    output_mode='int',        # Output as integer indices
    output_sequence_length=10 # Pad/truncate sequences to this length
)

# Adapt the layer to your dataset to build the vocabulary
vectorize_layer.adapt(texts)

print("\nVocabulary:", vectorize_layer.get_vocabulary()[:10]) # Show first 10 words

# Vectorize the text data
vectorized_texts = vectorize_layer(tf.constant(texts))
print("\nVectorized texts:\n", vectorized_texts)

# Example usage in a model
text_input = keras.Input(shape=(1,), dtype=tf.string)
x = vectorize_layer(text_input)
x = layers.Embedding(input_dim=len(vectorize_layer.get_vocabulary()), output_dim=64)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(1, activation='sigmoid')(x)
text_model = keras.Model(text_input, output)

text_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print("\nText Model with TextVectorization layer:")
text_model.summary()
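Two reserved indices are worth knowing when reading the integer output: with output_mode='int', index 0 is used for padding and index 1 for out-of-vocabulary tokens. The sketch below shows this on a small made-up corpus; the sentences and the sequence length of 6 are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical two-sentence corpus used only to build a vocabulary.
corpus = ["The dog barks loudly.", "The lazy dog sleeps."]

vectorizer = layers.TextVectorization(output_mode="int", output_sequence_length=6)
vectorizer.adapt(corpus)

vocab = vectorizer.get_vocabulary()
print("Reserved entries:", vocab[:2])  # Padding token and out-of-vocabulary token

# "unicorn" was never seen during adapt(), so it maps to the OOV index 1.
# The sentence has three tokens, so the remaining three positions are padded with 0.
ids = np.asarray(vectorizer(tf.constant(["the unicorn barks!"])))[0]
print("Token ids:", ids)
```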

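Tokenizer (Legacy but common for simple cases)

The Tokenizer class vectorizes a text corpus by mapping each word to an integer index. It is deprecated in favor of TextVectorization and does not operate on tensors, but it still appears in a lot of existing code.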

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Sample text data
docs = ['well done!', 'good work', 'great effort', 'nice work', 'excellent!']

# Create a tokenizer
tokenizer = Tokenizer(num_words=100) # Keep only the most frequent words (the top num_words - 1)
tokenizer.fit_on_texts(docs)

# Convert text to sequences of integers
sequences = tokenizer.texts_to_sequences(docs)
print("\nText Sequences:", sequences)

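# Pad sequences to a uniform length (by default, zeros are added at the front
# and sequences longer than maxlen are truncated from the front)
padded_sequences = pad_sequences(sequences, maxlen=5)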
print("Padded Sequences:\n", padded_sequences)

# Word index (mapping word to integer)
print("\nWord Index (first 5):", list(tokenizer.word_index.items())[:5])
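The two steps above, frequency-ranked indexing followed by front padding, can be sketched in plain Python. This is a simplified stand-in written for illustration, not the Keras implementation: the helper names tokenize and pad_pre are hypothetical, and the punctuation handling is cruder than the Tokenizer's default filter.

```python
import string
from collections import Counter

docs = ['well done!', 'good work', 'great effort', 'nice work', 'excellent!']

def tokenize(text):
    """Lowercase, replace punctuation with spaces, and split on whitespace."""
    table = str.maketrans(string.punctuation, " " * len(string.punctuation))
    return text.lower().translate(table).split()

def pad_pre(seq, maxlen):
    """Keep the last maxlen items and left-pad with zeros."""
    seq = seq[-maxlen:]
    return [0] * (maxlen - len(seq)) + seq

# Count words across the corpus.
counts = Counter()
for doc in docs:
    counts.update(tokenize(doc))

# The most frequent word gets index 1; index 0 is left free for padding.
# sorted() is stable, so ties keep first-seen order.
ranked = sorted(counts, key=lambda word: -counts[word])
word_index = {word: i + 1 for i, word in enumerate(ranked)}

sequences = [[word_index[w] for w in tokenize(doc)] for doc in docs]
padded = [pad_pre(seq, maxlen=5) for seq in sequences]

print(sequences)  # [[2, 3], [4, 1], [5, 6], [7, 1], [8]]
print(padded[0])  # [0, 0, 0, 2, 3]
```

Here "work" is the only word that appears twice, so it receives index 1 and every other word is numbered in order of first appearance.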

Further Topics:

  • Different augmentation strategies (e.g., CutMix, Mixup).
  • Preprocessing for numerical and tabular data (Normalization, CategoryEncoding).
  • Using tf.data API for efficient data pipelines.
  • Custom data generators.
  • Pre-trained embeddings (Word2Vec, GloVe, BERT) for text.
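As a pointer for the first topic, Mixup blends pairs of training examples and their labels with a weight drawn from a Beta distribution. The NumPy sketch below shows the core idea under stated assumptions: the function name mixup, the alpha value of 0.2, and the random data are all chosen for illustration, and labels are assumed to be one-hot encoded.

```python
import numpy as np

def mixup(images, one_hot_labels, alpha=0.2, rng=None):
    """Blend each example with a randomly chosen partner from the same batch."""
    if rng is None:
        rng = np.random.default_rng()
    lam = rng.beta(alpha, alpha)          # Mixing weight in [0, 1]
    perm = rng.permutation(len(images))   # Partner index for every example
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[perm]
    return mixed_images, mixed_labels, lam

rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3)).astype("float32")
labels = np.eye(10, dtype="float32")[rng.integers(0, 10, size=8)]  # One-hot labels

mixed_images, mixed_labels, lam = mixup(images, labels, rng=rng)
print("lambda:", lam)
print("label rows still sum to 1:", np.allclose(mixed_labels.sum(axis=1), 1.0))
```

Because the labels become soft targets, a model trained on this output would use a loss that accepts probability vectors, such as categorical_crossentropy, rather than the sparse variant.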

Effective data preprocessing and augmentation are vital for robust and high-performing deep learning models, particularly in domains like computer vision and natural language processing.