Keras: Transfer Learning
Transfer learning is a powerful deep learning technique where a model developed for a task is reused as the starting point for a model on a second task. It's particularly effective in computer vision and natural language processing when you have limited data for your specific problem. Keras makes it straightforward to implement transfer learning using pre-trained models.
Why Transfer Learning?
- Leverage Pre-trained Knowledge: Models trained on vast datasets (like ImageNet for images or large text corpora for NLP) have learned rich, general-purpose features. Transfer learning allows you to "transfer" this knowledge to your specific task.
- Faster Training: You don't train from scratch, significantly reducing training time.
- Better Performance with Less Data: Especially beneficial when your target dataset is small, as it helps prevent overfitting and improves generalization.
Common Strategies for Transfer Learning in Keras
- Feature Extraction:
  - Use a pre-trained model (e.g., VGG16, ResNet50, MobileNet) as a fixed feature extractor.
  - This involves taking the convolutional base of a pre-trained network and running new data through it.
  - You then train a new, smaller classifier (e.g., a few Dense layers) on top of these extracted features.
  - The weights of the pre-trained base model are frozen (not updated during training). This is a good strategy for smaller datasets where the new task's features are similar to what the base model learned.
- Fine-tuning:
  - Unfreeze some or all layers of the pre-trained model and re-train them (with a very low learning rate) on your new dataset.
  - This allows the model to adapt the general features to your specific task.
  - Typically, you start with feature extraction and then move to fine-tuning if you have a larger dataset or if the new task is somewhat different from the original task.
  - It's common to fine-tune only the top layers of the pre-trained model, leaving earlier (more generic feature-extracting) layers frozen.
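Both strategies come down to toggling the `trainable` attribute. The sketch below shows how freezing the whole base, and later unfreezing only its top layer, changes the set of weights the optimizer will update. It uses a tiny hypothetical stand-in for a pre-trained base with random weights, so nothing is downloaded.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical stand-in for a pre-trained convolutional base (random weights).
base = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, activation="relu", name="conv_a"),
    layers.Conv2D(16, 3, activation="relu", name="conv_b"),
    layers.Conv2D(32, 3, activation="relu", name="conv_c"),
])

inputs = keras.Input(shape=(32, 32, 3))
x = base(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

# Feature extraction: freeze the entire base; only the Dense head is trainable.
base.trainable = False
frozen_count = len(model.trainable_weights)  # Dense kernel + bias

# Fine-tuning: unfreeze the base, then re-freeze everything except its last layer.
base.trainable = True
for layer in base.layers[:-1]:
    layer.trainable = False
fine_tune_count = len(model.trainable_weights)  # conv_c kernel + bias, Dense kernel + bias

# Remember to re-compile the model after changing `trainable` and before calling fit().
print(frozen_count, fine_tune_count)  # 2 4
```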
Example: Feature Extraction with a Pre-trained MobileNetV2
Let's use MobileNetV2, pre-trained on ImageNet, as a feature extractor for a binary image classification task.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2
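# Note: recent TensorFlow releases mark ImageDataGenerator as deprecated in favor of
# tf.keras.utils.image_dataset_from_directory plus preprocessing layers; it still works here.
from tensorflow.keras.preprocessing.image import ImageDataGenerator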
import numpy as np
import os
import matplotlib.pyplot as plt
# --- 1. Create Dummy Dataset (if not already present) ---
# We'll simulate a binary classification task with a small dataset.
# base_dir should contain 'train' and 'validation' subdirectories,
# each with 'cats' and 'dogs' subdirectories holding images.
base_dir = 'transfer_learning_data'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
os.makedirs(os.path.join(train_dir, 'cats'), exist_ok=True)
os.makedirs(os.path.join(train_dir, 'dogs'), exist_ok=True)
os.makedirs(os.path.join(validation_dir, 'cats'), exist_ok=True)
os.makedirs(os.path.join(validation_dir, 'dogs'), exist_ok=True)
# Create dummy image files: 50 per class for training, 20 per class for validation.
# The images are random noise, so expect accuracy near 50%; use real images for a real task.
def create_dummy_images_for_tl(directory, num_images=50):
    for class_name in ['cats', 'dogs']:
        class_path = os.path.join(directory, class_name)
        for i in range(num_images):
            img_data = np.random.randint(0, 256, size=(150, 150, 3), dtype=np.uint8)
            plt.imsave(os.path.join(class_path, f'{class_name}_{i}.png'), img_data)

# Comment out these two calls if the folders already contain your own images.
create_dummy_images_for_tl(train_dir, num_images=50)
create_dummy_images_for_tl(validation_dir, num_images=20)
print("Dummy dataset created/ensured.")
# --- 2. Data Preprocessing and Augmentation ---
# ImageNet weights for MobileNetV2 are available for 96, 128, 160, 192 and 224 px inputs;
# other sizes (down to 32 px) still work, but Keras warns and loads the 224 px weights.
IMG_SIZE = (160, 160)
BATCH_SIZE = 32
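# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1], which is what its
# preprocess_input does (rescale=1./255 would give [0, 1] and mismatch the weights).
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
validation_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)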
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary'
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary'
)
# --- 3. Load Pre-trained Base Model ---
# include_top=False: Don't include the ImageNet classification head.
# weights='imagenet': Load weights pre-trained on ImageNet.
# input_shape: Specify the input shape of your images.
base_model = MobileNetV2(input_shape=IMG_SIZE + (3,),
                         include_top=False,
                         weights='imagenet')
# --- 4. Freeze the Base Model ---
# This means its weights will not be updated during training.
base_model.trainable = False
# --- 5. Build the New Classifier Head ---
# We'll add our own classification layers on top of the frozen base.
inputs = keras.Input(shape=IMG_SIZE + (3,))
# Important: training=False keeps the base's BatchNormalization layers in inference mode,
# even after the base is unfrozen later for fine-tuning.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x) # Reduces spatial dimensions
x = layers.Dropout(0.2)(x) # Regularization
outputs = layers.Dense(1, activation='sigmoid')(x) # Binary classification output
model = keras.Model(inputs, outputs)
# --- 6. Compile the Model ---
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # Use a low learning rate
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()
# --- 7. Train the New Head ---
print("\n--- Training the new classification head (Feature Extraction) ---")
epochs = 10 # Short training for demonstration
history_feature_extraction = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator
)
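Because the frozen base never changes, an alternative to the end-to-end model above is to run the data through the base once, cache its outputs, and train only the small head on those arrays. This is much cheaper per epoch, but it cannot use on-the-fly data augmentation, because each image passes through the base only once. The sketch below makes two assumptions so it runs offline: it uses `weights=None` instead of `'imagenet'`, and it uses random arrays as hypothetical placeholder images and labels.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2

# weights=None avoids a download for this sketch; use weights='imagenet' in practice.
feature_base = MobileNetV2(input_shape=(96, 96, 3), include_top=False,
                           weights=None, pooling="avg")
feature_base.trainable = False

# Hypothetical placeholder data: 16 images already scaled to [-1, 1], binary labels.
rng = np.random.default_rng(0)
images = rng.uniform(-1.0, 1.0, size=(16, 96, 96, 3)).astype("float32")
labels = rng.integers(0, 2, size=(16,)).astype("float32")

# 1. Run every image through the frozen base once and cache the pooled features.
features = feature_base.predict(images, batch_size=8, verbose=0)

# 2. Train a small classifier on the cached feature vectors only.
head = keras.Sequential([
    keras.Input(shape=(features.shape[1],)),
    layers.Dropout(0.2),
    layers.Dense(1, activation="sigmoid"),
])
head.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
head.fit(features, labels, epochs=2, batch_size=8, verbose=0)

probs = head.predict(features, verbose=0)
print(features.shape, probs.shape)  # (16, 1280) (16, 1)
```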
Example: Fine-tuning the Pre-trained MobileNetV2
After feature extraction, you might want to fine-tune the base model for better performance.
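Before unfreezing, it helps to see why the base was called with `base_model(inputs, training=False)`. That argument keeps BatchNormalization layers in inference mode even when they are trainable. As a result, their moving mean and variance are not overwritten by small fine-tuning batches. The sketch below checks this on a tiny hypothetical base with one BatchNormalization layer, random weights, and random placeholder data.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Tiny hypothetical base (random weights) containing one BatchNormalization layer.
base = keras.Sequential([
    keras.Input(shape=(8,)),
    layers.Dense(4),
    layers.BatchNormalization(name="bn"),
])

inputs = keras.Input(shape=(8,))
x = base(inputs, training=False)  # BN inside `base` runs in inference mode in this model
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

base.trainable = True  # unfreeze the whole base, as in fine-tuning
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=1e-5),
              loss="binary_crossentropy")

bn = base.get_layer("bn")
# get_weights() returns [gamma, beta, moving_mean, moving_variance]
mean_before = bn.get_weights()[2].copy()

# Random placeholder data; in training mode BN would move its moving mean away from zero.
rng = np.random.default_rng(0)
x_data = (rng.normal(size=(64, 8)) * 5.0 + 3.0).astype("float32")
y_data = rng.integers(0, 2, size=(64, 1)).astype("float32")
model.fit(x_data, y_data, epochs=1, batch_size=16, verbose=0)

mean_after = bn.get_weights()[2]
print(np.allclose(mean_before, mean_after))  # True: moving statistics untouched
```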
# --- 1. Unfreeze some layers of the Base Model ---
# First, unfreeze the base model
base_model.trainable = True
# Then re-freeze the layers closer to the input (which learn more generic features)
# and keep the later layers (which learn more task-specific features) trainable.
# Experiment with how many layers to unfreeze; here we unfreeze block 13 onwards.
# Find out how many layers are in the base_model
print(f"\nNumber of layers in the base model: {len(base_model.layers)}")
# Keras names MobileNetV2 block layers 'block_<N>_...' (e.g. 'block_13_expand'), so
# locate the first layer of block 13 and freeze every layer before it.
# (Inspect [layer.name for layer in base_model.layers] to confirm the names.)
fine_tune_at = next(i for i, layer in enumerate(base_model.layers)
                    if layer.name.startswith('block_13_'))
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# --- 2. Re-compile the Model with a very low learning rate ---
# It's important to re-compile the model AFTER unfreezing layers for changes to take effect.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00001),  # Very low learning rate
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()  # The trainable parameter count should now include the unfrozen blocks
# --- 3. Continue Training (Fine-tuning) ---
print("\n--- Fine-tuning the model ---")
epochs_fine_tune = 10
total_epochs = epochs + epochs_fine_tune
history_fine_tune = model.fit(
    train_generator,
    epochs=total_epochs,
    # history.epoch is 0-indexed ([0, ..., 9]), so its length is the next epoch to run
    initial_epoch=len(history_feature_extraction.epoch),  # Start where feature extraction left off
    validation_data=validation_generator
)
# --- 4. Save the Fine-tuned Model ---
model.save('fine_tuned_mobilenetv2_cats_dogs.h5')
print("\nFine-tuned model saved as fine_tuned_mobilenetv2_cats_dogs.h5")
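Using the saved model requires knowing how labels were assigned. `flow_from_directory` numbers classes alphabetically by subdirectory name, so here `train_generator.class_indices` is `{'cats': 0, 'dogs': 1}`. With `class_mode='binary'`, the sigmoid output is the predicted probability of index 1. The hypothetical helper below turns a batch of such probabilities into class names. The probabilities in the example call are made-up inputs, not model output.

```python
import numpy as np

def decode_binary_predictions(probs, class_indices, threshold=0.5):
    """Map sigmoid outputs of shape (n, 1) or (n,) to class names.

    class_indices is a dict like the one flow_from_directory produces,
    e.g. {'cats': 0, 'dogs': 1}; probs are probabilities of class index 1.
    """
    index_to_name = {index: name for name, index in class_indices.items()}
    probs = np.asarray(probs).reshape(-1)
    predicted = (probs >= threshold).astype(int)
    return [index_to_name[i] for i in predicted]

# Usage with the model from this tutorial (needs real, preprocessed images):
#   model = keras.models.load_model('fine_tuned_mobilenetv2_cats_dogs.h5')
#   probs = model.predict(batch_of_preprocessed_images)  # hypothetical batch
#   names = decode_binary_predictions(probs, train_generator.class_indices)

example = decode_binary_predictions([[0.1], [0.93], [0.5]], {"cats": 0, "dogs": 1})
print(example)  # ['cats', 'dogs', 'dogs']
```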
Considerations for Transfer Learning:
- Model Choice: Select a pre-trained model appropriate for your domain (e.g., image models for images, NLP models for text).
- Data Size & Similarity:
  - Small dataset, similar to original: Feature extraction is usually sufficient.
  - Small dataset, different from original: Feature extraction, perhaps fine-tuning only the top layers.
  - Large dataset, similar to original: Fine-tuning the entire model (with a small learning rate).
  - Large dataset, different from original: Consider training from scratch, but a pre-trained model is often a good initialization.
- Learning Rate: Always use a very small learning rate for fine-tuning to avoid corrupting the pre-trained weights.
- Regularization: Dropout and other regularization techniques are important to prevent overfitting during fine-tuning.
- Input Size: Ensure your input data is reshaped/resized to match the expected input size of the pre-trained model.
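Matching the expected input means matching the pixel range the weights were trained with, as well as the spatial size. Each Keras application ships a `preprocess_input` function, and for MobileNetV2 it scales values from [0, 255] to [-1, 1]. The sketch below resizes a batch and then applies that function. It assumes uint8 RGB images and a 160x160 target, which is one of the sizes with MobileNetV2 ImageNet weights. The input batch is random placeholder data.

```python
import numpy as np
import tensorflow as tf

def prepare_for_mobilenetv2(images, size=(160, 160)):
    """Resize a batch of uint8 RGB images and scale pixels to [-1, 1]."""
    resized = tf.image.resize(images, size)  # bilinear; float32, still within [0, 255]
    return tf.keras.applications.mobilenet_v2.preprocess_input(resized)

# Two random placeholder "images" of a different size, standing in for real data.
batch = np.random.default_rng(0).integers(0, 256, size=(2, 120, 90, 3), dtype=np.uint8)
prepared = prepare_for_mobilenetv2(batch)
print(prepared.shape)  # (2, 160, 160, 3)
```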
Transfer learning is a cornerstone of modern deep learning, often letting you reach strong performance with limited data and compute.