Intermediate - Custom Neural Network with TensorFlow Subclassing
Description
This project demonstrates a more advanced and flexible way to build neural networks in TensorFlow by subclassing the tf.keras.Model class. While the Sequential API is great for simple, linear stacks of layers, the subclassing API gives you complete control over the model's architecture and forward pass.
We will build a Multi-Layer Perceptron (MLP) to classify the MNIST handwritten digits, but this time, the model will be defined as a custom Python class. This approach is essential for implementing novel or complex architectures found in research papers.
Functionality
- Data Pipeline: The MNIST dataset is loaded and preprocessed, then fed into a tf.data.Dataset pipeline. This is an efficient way to handle data loading, shuffling, and batching.
- Custom Model Class:
  - A class CustomMLP is defined that inherits from tf.keras.Model.
  - In the __init__ method, all the necessary layers (e.g., Dense, Dropout) are defined.
  - In the call method, the forward pass is explicitly defined, detailing how input data flows through the layers. This method also includes a training flag to ensure that layers like Dropout behave correctly (i.e., they are only active during training).
- Model Training: The custom model is instantiated, compiled with an optimizer and loss function, and then trained using the standard .fit() method, showcasing its compatibility with the rest of the Keras ecosystem.
- Evaluation and Prediction: The trained model is evaluated on the test set, and predictions are made and visualized in the same way as with a Sequential model.
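The custom model class and training step described above can be sketched roughly as follows. This is a minimal illustration, not the project's actual code: the layer sizes, dropout rate, optimizer, and loss are assumptions, and random arrays stand in for MNIST so the snippet runs without a download.

```python
import numpy as np
import tensorflow as tf


class CustomMLP(tf.keras.Model):
    """Small MLP for 28x28 images. Sizes and rates here are assumed values."""

    def __init__(self, hidden_units=128, dropout_rate=0.2, num_classes=10):
        super().__init__()
        # Layers are created once here and reused on every forward pass.
        self.flatten_layer = tf.keras.layers.Flatten()
        self.hidden_layer = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.dropout_layer = tf.keras.layers.Dropout(dropout_rate)
        self.output_layer = tf.keras.layers.Dense(num_classes)  # raw logits

    def call(self, inputs, training=False):
        x = self.flatten_layer(inputs)
        x = self.hidden_layer(x)
        # Dropout only zeroes activations when training=True.
        x = self.dropout_layer(x, training=training)
        return self.output_layer(x)


# Synthetic stand-in for MNIST (assumption): 64 random images and labels.
rng = np.random.default_rng(0)
images = rng.random((64, 28, 28), dtype=np.float32)
labels = rng.integers(0, 10, size=64)

model = CustomMLP()
model.compile(
    optimizer="adam",  # assumed optimizer
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
history = model.fit(images, labels, batch_size=16, epochs=1, verbose=0)

# Inference call: dropout is disabled, so repeated calls give the same logits.
logits = model(tf.constant(images[:4]), training=False)
```

Because the final layer returns raw logits in this sketch, the loss is created with from_logits=True. A model that ends in a softmax activation would use the default instead.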
Architecture
- TensorFlow & Keras: The project is built using TensorFlow and its Keras API.
- tf.keras.Model Subclassing: This is the core concept of the project. It provides an object-oriented and highly customizable way to define a model.
- tf.data: Used to create an efficient and scalable input pipeline for training and testing the model.
- matplotlib: Used for visualizing the training history and the final predictions.
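As a rough illustration of the tf.data component, the sketch below builds a shuffled, batched, prefetched pipeline. The buffer size and batch size are assumed values, and random arrays with MNIST's shape and dtype replace the real dataset so the example is self-contained.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption): 256 uint8 images with integer labels.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(256, 28, 28), dtype=np.uint8)
y_train = rng.integers(0, 10, size=256, dtype=np.int64)

# Preprocess first: scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype("float32") / 255.0

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=256)    # buffer covers the whole toy dataset
    .batch(32)                   # assumed batch size
    .prefetch(tf.data.AUTOTUNE)  # overlap data preparation with training
)

# Pull one batch to inspect what the model would receive.
batch_images, batch_labels = next(iter(train_ds))
```

A dataset built this way can be passed directly to .fit() in place of separate image and label arrays.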
How to Run
Prerequisites
Make sure you have Python installed, along with the required libraries. You can install them using pip:
pip install tensorflow numpy matplotlib
Execution
To run the project, navigate to the project directory and execute the following command:
python intermediate_custom_neural_network.py
The script will print the model's summary, show the training progress over 10 epochs, and report the final test accuracy. It will also display plots for the training/validation loss and accuracy, as well as a sample of test images with their predicted labels.
Concepts Covered
- Keras Subclassing API: The main takeaway of this project is how to create fully custom models by inheriting from tf.keras.Model.
- Object-Oriented Model Building: Defining a model as a Python class.
- Custom Forward Pass: The ability to define exactly how data flows through the model in the call method.
- The training Argument: Understanding how to control the behavior of certain layers (like Dropout and BatchNormalization) during training vs. inference.
- tf.data Pipelines: A best practice for feeding data into TensorFlow models efficiently.
- Model Encapsulation: A cleaner way to organize complex model architectures.
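The effect of the training argument can be seen in isolation with a single Dropout layer. This is a small standalone sketch, not taken from the project script; the rate of 0.5 and the all-ones input are chosen only to make the behavior easy to read.

```python
import tensorflow as tf

dropout = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones((1, 1000))

# Inference mode: the layer passes its input through unchanged.
inference_out = dropout(x, training=False)

# Training mode: roughly half the values are zeroed, and the survivors
# are scaled by 1 / (1 - rate) = 2 so the expected sum stays the same.
training_out = dropout(x, training=True)

print("inference unique values:", sorted(set(inference_out.numpy().ravel().tolist())))
print("training unique values: ", sorted(set(training_out.numpy().ravel().tolist())))
```

In a subclassed model, forwarding the training value from call to each such layer is what lets .fit() and .predict() switch between these two behaviors automatically.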