# Project 2: Image Classifier (MNIST) with Keras Sequential API

"""
Description:
This project demonstrates how to build a simple image classifier using TensorFlow's Keras Sequential API.
We will train a neural network to classify handwritten digits from the MNIST dataset.
This is a classic "Hello World" for deep learning and is excellent for beginners to grasp
neural network fundamentals, data preprocessing, and model training/evaluation.

Use Case:
Automated recognition of handwritten digits, a foundational task for OCR (Optical Character Recognition).

Concepts Covered:
- Loading and preprocessing image data (MNIST dataset)
- Keras Sequential API for building neural networks
- Dense (fully connected) layers and Dropout regularization
- Activation functions (ReLU, Softmax)
- Model compilation (optimizer, loss function, metrics)
- Model training (.fit() method)
- Model evaluation (.evaluate() method)
- Making predictions (.predict() method)
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

print("TensorFlow Version:", tf.__version__)

# 1. Load and Preprocess the MNIST Dataset
# MNIST contains 60,000 training and 10,000 test images of handwritten digits (0-9).
print("\nLoading MNIST dataset...")
train_split, test_split = keras.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# Report the raw array shapes before any preprocessing is applied.
print(f"x_train shape: {x_train.shape} (60,000 images, 28x28 pixels)")
for arr_name, arr in (("y_train", y_train), ("x_test", x_test), ("y_test", y_test)):
    print(f"{arr_name} shape: {arr.shape}")

# Flatten every 28x28 grayscale image into a 784-element vector and rescale
# the 0-255 pixel intensities into the [0, 1] range as float32.
x_train, x_test = [
    images.reshape(-1, 784).astype("float32") / 255.0
    for images in (x_train, x_test)
]

# Confirm the shapes after preprocessing.
for arr_name, arr in (("x_train", x_train), ("x_test", x_test)):
    print(f"{arr_name} shape after flattening and normalization: {arr.shape}")

# 2. Build the Neural Network Model using Keras Sequential API
# The Sequential model is a linear stack of layers: each layer's output feeds the next.
print("\nBuilding the model...")
model = keras.Sequential([
    # Input specification: each sample is a flat vector of 784 features
    # (a 28x28 image flattened during preprocessing). Declaring it with
    # keras.Input is the recommended form; passing `input_shape=` to a layer
    # is deprecated and triggers a warning in recent Keras versions.
    keras.Input(shape=(784,)),

    # Hidden layer: a fully connected (Dense) layer with 128 units and ReLU activation.
    layers.Dense(128, activation="relu"),

    # Dropout layer: during training, randomly zeroes 20% of the previous
    # layer's outputs on each update, which helps prevent overfitting.
    layers.Dropout(0.2),

    # Output layer: 10 units, one per digit class (0-9).
    # Softmax turns the outputs into a probability distribution over the classes.
    layers.Dense(10, activation="softmax"),
])

# Display the model architecture summary (layers, output shapes, parameter counts).
model.summary()

# 3. Compile the Model
# compile() attaches the optimizer, loss and metrics that fit()/evaluate() will use.
print("\nCompiling the model...")
model.compile(
    # Integer class labels (0, 1, ..., 9) call for the "sparse" variant of
    # cross-entropy; one-hot labels would need categorical_crossentropy instead.
    loss="sparse_categorical_crossentropy",
    # Adam adapts per-parameter learning rates and is a solid default.
    optimizer="adam",
    # Tracked and reported during both training and evaluation.
    metrics=["accuracy"],
)

# 4. Train the Model
# fit() runs a fixed number of epochs, each a full pass over the training data,
# updating the weights after every mini-batch.
print("\nTraining the model...")
EPOCHS = 10             # full passes through the training dataset
BATCH_SIZE = 32         # samples per gradient update
VALIDATION_SPLIT = 0.1  # fraction of training data held out for validation
history = model.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=VALIDATION_SPLIT,
    verbose=1,  # show a progress bar while training
)

# Plot training history (loss and accuracy) side by side.
# History holds one value per epoch. Pass explicit x values so the "Epoch" axis
# counts from 1; plt.plot with a single series would number points from 0.
epochs_range = range(1, len(history.history["loss"]) + 1)

plt.figure(figsize=(12, 4))
for position, (metric, label, title) in enumerate(
    [("loss", "Loss", "Model Loss"), ("accuracy", "Accuracy", "Model Accuracy")],
    start=1,
):
    plt.subplot(1, 2, position)
    plt.plot(epochs_range, history.history[metric], label=f"Train {label}")
    # Validation curves exist only when fit() was given validation data.
    val_key = f"val_{metric}"
    if val_key in history.history:
        plt.plot(epochs_range, history.history[val_key], label=f"Validation {label}")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(label)
    plt.legend()
plt.tight_layout()  # keep the two subplots' labels from overlapping
plt.show()

# 5. Evaluate the Model
# Measure performance on the held-out test set, which was not used for training.
print("\nEvaluating the model on test data...")
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)  # verbose=0: no progress bar
for caption, value in (("Test Loss", test_loss), ("Test Accuracy", test_accuracy)):
    print(f"{caption}: {value:.4f}")

# 6. Make Predictions
# Run the trained model on a handful of test images and decode its outputs.
print("\nMaking predictions on a few test samples...")
sample_images = x_test[:5]  # the first five test images
predictions = model.predict(sample_images)
# Each row of `predictions` holds one score per class; the highest-scoring
# index is the predicted digit.
predicted_classes = predictions.argmax(axis=1)

print("\nOriginal Labels for first 5 test samples:", y_test[:5])
print("Predicted Classes for first 5 test samples:", predicted_classes)

# Visualize some predictions: one panel per sample, titled with the
# predicted and true digit.
fig, axes = plt.subplots(1, 5, figsize=(10, 5))
for idx, ax in enumerate(axes):
    # Undo the 784-pixel flattening so the image can be displayed as 28x28.
    ax.imshow(x_test[idx].reshape(28, 28), cmap='gray')
    ax.set_title(f"Pred: {predicted_classes[idx]}\nTrue: {y_test[idx]}")
    ax.axis('off')
fig.suptitle("Sample Predictions")
plt.show()
