
intermediate_custom_neural_network

Intermediate - Custom Neural Network with TensorFlow Subclassing

Description

This project demonstrates a more advanced and flexible way to build neural networks in TensorFlow by subclassing the tf.keras.Model class. While the Sequential API is great for simple, linear stacks of layers, the subclassing API gives you complete control over the model's architecture and forward pass.

We will build a Multi-Layer Perceptron (MLP) to classify the MNIST handwritten digits, but this time the model will be defined as a custom Python class. This approach is especially useful for implementing novel or complex architectures, such as those found in research papers.
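A minimal sketch of what such a class might look like. The single hidden layer of 128 ReLU units and the dropout rate of 0.2 are illustrative assumptions, since the README does not give layer sizes. The output layer returns raw logits (no softmax).

```python
import tensorflow as tf


class CustomMLP(tf.keras.Model):
    """An MLP for 28x28 grayscale digits, defined by subclassing tf.keras.Model.

    Layer sizes and dropout rate are illustrative assumptions.
    """

    def __init__(self, hidden_units=128, dropout_rate=0.2, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        # Layers are created once here and reused on every forward pass.
        self.flatten = tf.keras.layers.Flatten()
        self.hidden = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.classifier = tf.keras.layers.Dense(num_classes)  # raw logits

    def call(self, inputs, training=False):
        # The forward pass is ordinary Python: flatten -> dense -> dropout -> logits.
        x = self.flatten(inputs)
        x = self.hidden(x)
        x = self.dropout(x, training=training)  # only drops units when training=True
        return self.classifier(x)
```

Calling the model once, for example `CustomMLP()(tf.zeros((2, 28, 28)))`, builds its weights and returns logits of shape `(2, 10)`. Because a subclassed model has no fixed input shape until it is called, it typically has to be built this way (or by `.fit()`) before `model.summary()` can report parameter counts.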

Functionality

  1. Data Pipeline: The MNIST dataset is loaded and preprocessed, then fed into a tf.data.Dataset pipeline. This is an efficient way to handle data loading, shuffling, and batching.
  2. Custom Model Class:
    • A class CustomMLP is defined that inherits from tf.keras.Model.
    • In the __init__ method, all the necessary layers (e.g., Dense, Dropout) are defined.
    • In the call method, the forward pass is explicitly defined, detailing how input data flows through the layers. The method also accepts a training argument and passes it to layers such as Dropout, so that they are active only during training and disabled at inference.
  3. Model Training: The custom model is instantiated, compiled with an optimizer and loss function, and then trained using the standard .fit() method, showcasing its compatibility with the rest of the Keras ecosystem.
  4. Evaluation and Prediction: The trained model is evaluated on the test set, and predictions are made and visualized in the same way as with a Sequential model.
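Steps 3 and 4 could be sketched as two small helpers that work with any Keras model, for example a CustomMLP instance. The Adam optimizer, learning rate of 1e-3 and sparse categorical cross-entropy loss are assumptions (the README does not name them); `from_logits=True` assumes the model's last layer outputs raw logits. The helper names are hypothetical.

```python
import tensorflow as tf


def compile_and_train(model, train_ds, val_ds, epochs=10):
    """Compile a logits-producing Keras model and train it with .fit()."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # assumed optimizer
        # from_logits=True because the final Dense layer is assumed to have no softmax.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # .fit() is used exactly as it would be for a Sequential model.
    return model.fit(train_ds, validation_data=val_ds, epochs=epochs)


def evaluate_and_predict(model, test_ds, images):
    """Return test accuracy and the predicted class index for each image."""
    _, test_acc = model.evaluate(test_ds, verbose=0)
    logits = model(images, training=False)  # inference mode: layers such as Dropout are disabled
    return test_acc, tf.argmax(logits, axis=-1)
```

The returned `History` object's `.history` dict holds per-epoch `loss`, `accuracy`, `val_loss` and `val_accuracy`, which is what the training plots are drawn from.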

Architecture

  • TensorFlow & Keras: The project is built using TensorFlow and its Keras API.
  • tf.keras.Model Subclassing: This is the core concept of the project. It provides an object-oriented and highly customizable way to define a model.
  • tf.data: Used to create an efficient and scalable input pipeline for training and testing the model.
  • matplotlib: Used for visualizing the training history and the final predictions.
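The tf.data input pipeline could look roughly like this. The preprocessing step (scaling pixels to [0, 1]), the batch size of 64 and the function names are assumptions; the README only states that the data is preprocessed, shuffled and batched.

```python
import tensorflow as tf


def preprocess(images, labels):
    # Assumed preprocessing: scale uint8 pixels in [0, 255] to float32 in [0.0, 1.0].
    return tf.cast(images, tf.float32) / 255.0, labels


def make_dataset(images, labels, batch_size=64, shuffle=True):
    """Build a shuffled, batched, prefetched tf.data.Dataset from in-memory arrays."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        # A buffer as large as the dataset gives a full shuffle, redone every epoch.
        ds = ds.shuffle(buffer_size=len(images))
    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    # Prefetching overlaps input preparation with training on the previous batch.
    return ds.prefetch(tf.data.AUTOTUNE)


# Example wiring with MNIST (downloads the data on first use):
# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# train_ds = make_dataset(x_train, y_train)
# test_ds = make_dataset(x_test, y_test, shuffle=False)
```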

How to Run

Prerequisites

Make sure you have Python installed, along with the required libraries. You can install them using pip:

pip install tensorflow numpy matplotlib

Execution

To run the project, navigate to the project directory and execute the following command:

python intermediate_custom_neural_network.py

The script will print the model's summary, show the training progress over 10 epochs, and report the final test accuracy. It will also display plots for the training/validation loss and accuracy, as well as a sample of test images with their predicted labels.
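The plots described above could be produced with helpers along these lines. The function names and figure layout are hypothetical; `plot_history` expects a Keras `History.history` dict recorded with an accuracy metric and validation data.

```python
import matplotlib.pyplot as plt


def plot_history(history):
    """Plot training/validation loss and accuracy from a History.history dict."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    epochs = range(1, len(history["loss"]) + 1)
    for ax, metric in ((ax_loss, "loss"), (ax_acc, "accuracy")):
        ax.plot(epochs, history[metric], label=f"train {metric}")
        ax.plot(epochs, history[f"val_{metric}"], label=f"val {metric}")
        ax.set_xlabel("epoch")
        ax.set_title(metric)
        ax.legend()
    fig.tight_layout()
    return fig


def plot_predictions(images, predicted_labels, n=10):
    """Show the first n images with their predicted class as the title."""
    n = min(n, len(images))
    fig, axes = plt.subplots(1, n, figsize=(1.5 * n, 2), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i], cmap="gray")
        ax.set_title(f"pred: {int(predicted_labels[i])}")
        ax.axis("off")
    fig.tight_layout()
    return fig
```

Both functions return the figure rather than displaying it; call `plt.show()` afterwards to open the windows.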

Concepts Covered

  • Keras Subclassing API: The main takeaway of this project, namely how to create fully custom models by inheriting from tf.keras.Model.
  • Object-Oriented Model Building: Defining a model as a Python class.
  • Custom Forward Pass: The ability to define exactly how data flows through the model in the call method.
  • The training Argument: Understanding how to control the behavior of certain layers (like Dropout and BatchNormalization) during training vs. inference.
  • tf.data Pipelines: A best practice for feeding data into TensorFlow models efficiently.
  • Model Encapsulation: A cleaner way to organize complex model architectures.
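The effect of the training argument can be checked directly on standalone layers. This sketch uses freshly created layers with their default settings; the helper names and input values are illustrative.

```python
import tensorflow as tf


def dropout_outputs(rate=0.5, width=8):
    """Apply the same Dropout layer to a vector of ones in both modes."""
    layer = tf.keras.layers.Dropout(rate)
    x = tf.ones((1, width))
    inference = layer(x, training=False)  # identity: nothing is dropped
    training = layer(x, training=True)    # some units zeroed, survivors scaled by 1 / (1 - rate)
    return inference.numpy(), training.numpy()


def batchnorm_outputs(values=(1.0, 2.0, 3.0, 4.0)):
    """Apply a fresh BatchNormalization layer to one batch in both modes."""
    layer = tf.keras.layers.BatchNormalization()
    x = tf.constant(values, shape=(len(values), 1))
    # Training mode normalizes with this batch's own mean and variance.
    training = layer(x, training=True)
    # Inference mode uses the moving statistics, which start at mean 0 and
    # variance 1 and are only nudged slightly by the single training call above
    # (default momentum 0.99), so the output stays close to the input.
    inference = layer(x, training=False)
    return training.numpy(), inference.numpy()
```

With `rate=0.5`, the training-mode Dropout output contains only 0.0 and 2.0, while the inference-mode output equals the input. The training-mode BatchNormalization output has mean close to zero.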
