# Project 9: Generative Adversarial Network (GAN) for Image Generation

"""
Description:
This project implements a basic Generative Adversarial Network (GAN) using TensorFlow/Keras
to generate new images that resemble a training dataset. A GAN consists of two neural networks:

1.  **Generator:** Learns to create realistic-looking images from random noise.
2.  **Discriminator:** Learns to distinguish between real images from the training set and fake images generated by the generator.

These two networks are trained simultaneously in a competitive process, where the generator tries to fool the discriminator,
and the discriminator tries to get better at identifying fakes. This project provides a hands-on understanding of GAN architecture and training.

Use Case:
Generating realistic images (e.g., faces, landscapes), data augmentation, creating art, style transfer.

Concepts Covered:
- Generative Adversarial Networks (GANs) architecture
- Generator and Discriminator networks
- Adversarial training process
- Loss functions for GANs (binary cross-entropy)
- Custom training loops with `tf.GradientTape`
- Image preprocessing and visualization
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os

print("TensorFlow Version:", tf.__version__)

# --- Configuration --- #
# Size of the shuffle buffer used by the input pipeline.
BUFFER_SIZE = 60000
# Number of images per training batch.
BATCH_SIZE = 256
# Number of passes over the training set (more epochs usually help, but take longer).
EPOCHS = 50
# Length of the random noise vector fed to the generator.
NOISE_DIM = 100
# Number of sample images rendered for progress snapshots and the final figure.
NUM_EXAMPLES_TO_GENERATE = 16

# --- 1. Load and Preprocess the Dataset (MNIST) --- #
# MNIST keeps the example small and quick to train on.
print("\nLoading MNIST dataset...")
# Only the training images are used; labels and the test split are discarded.
x_train = keras.datasets.mnist.load_data()[0][0]

# Cast to float32 and rescale pixel values to [-1, 1], the same range as the
# generator's tanh output.
x_train = (x_train.astype('float32') - 127.5) / 127.5

# MNIST images are grayscale: append a trailing single-channel axis.
x_train = np.expand_dims(x_train, axis=-1)

# Build the input pipeline step by step: slice -> shuffle -> batch.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

print(f"x_train shape: {x_train.shape}")
print(f"Number of batches: {len(train_dataset)}")

# --- 2. Define the Generator Model --- #
# The generator takes random noise as input and outputs an image.
# It typically uses Dense layers, Batch Normalization, and Conv2DTranspose (deconvolutional) layers.

def make_generator_model():
    """Build the generator: a NOISE_DIM noise vector -> 28x28x1 image.

    The noise is projected to a 7x7x256 feature map, refined at 7x7, then
    upsampled twice with stride-2 transposed convolutions (7 -> 14 -> 28).
    The final tanh keeps output values in [-1, 1].

    Returns:
        An uncompiled `keras.Sequential` model whose output shape is
        (None, 28, 28, 1).
    """
    model = keras.Sequential()

    # Declare the input explicitly rather than passing `input_shape=` to the
    # first layer (Keras 3 warns about the latter).
    model.add(keras.Input(shape=(NOISE_DIM,)))

    # Project the noise vector to the 7x7x256 foundation.
    model.add(layers.Dense(7*7*256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # None is for batch size

    # Stride 1: stays at 7x7, only reduces channels 256 -> 128.
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Stride 2: upsample 7x7 -> 14x14.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Output layer, stride 2: upsample 14x14 -> 28x28x1, tanh for values in [-1, 1].
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

generator = make_generator_model()
generator.summary()

# Smoke-test the untrained generator on a single random noise vector.
noise = tf.random.normal(shape=(1, NOISE_DIM))
generated_image = generator(noise, training=False)

# Map the tanh output from [-1, 1] back to [0, 1] before plotting.
plt.imshow(0.5 * generated_image[0, ..., 0] + 0.5, cmap='gray')
plt.title("Initial Generated Image (Noise)")
plt.axis('off')
plt.show()

# --- 3. Define the Discriminator Model --- #
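# The discriminator takes an image as input and outputs a single raw logit (higher means "more likely real").
# The sigmoid is applied inside the loss (BinaryCrossentropy with from_logits=True), not in the model.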
# It typically uses Conv2D layers and LeakyReLU activation.

def make_discriminator_model():
    """Build the discriminator: 28x28x1 image -> one real/fake logit.

    Two stride-2 convolutions downsample the image (28 -> 14 -> 7), each
    followed by LeakyReLU and dropout. The final Dense(1) has no activation,
    so the model returns a raw logit rather than a probability.

    Returns:
        An uncompiled `keras.Sequential` model.
    """
    model = keras.Sequential()

    # Declare the input explicitly rather than passing `input_shape=` to the
    # first layer (Keras 3 warns about the latter).
    model.add(keras.Input(shape=(28, 28, 1)))

    # Downsample to 14x14
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Downsample to 7x7
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Output layer: single un-activated unit (logit) for real/fake classification
    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

discriminator = make_discriminator_model()
discriminator.summary()

# Smoke-test the untrained discriminator on the image produced by the generator test above.
# The printed value is the raw output of the final Dense(1) layer (a logit, not a probability).
decision = discriminator(generated_image)
print(f"Discriminator's initial decision on fake image: {decision}")

# --- 4. Define Loss Functions and Optimizers --- #
# We use Binary Cross-Entropy for both generator and discriminator losses.
# The generator's loss is based on how well it fools the discriminator.
# The discriminator's loss is based on how well it distinguishes real from fake.

# from_logits=True because the discriminator's final Dense(1) layer has no
# activation: the loss receives raw logits and applies the sigmoid itself.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator outputs for real images (target label 1).
        fake_output: Discriminator outputs for generated images (target label 0).

    Returns:
        The sum of the cross-entropy on the real outputs and on the fake outputs.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator succeeds when the discriminator labels its images as real,
    so the discriminator's outputs on fakes are scored against all-ones targets.

    Args:
        fake_output: Discriminator outputs for generated images.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

# Each network gets its own Adam optimizer because the two are updated
# separately, each from its own loss.
generator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)

# --- 5. Define the Training Step --- #
# This is the core of GAN training, where both networks are updated.

# Fixed noise batch reused for every progress snapshot, so images from
# different epochs are directly comparable.
seed = tf.random.normal(shape=(NUM_EXAMPLES_TO_GENERATE, NOISE_DIM))

@tf.function # Compile into a TensorFlow graph for speed
def train_step(images):
    """Run one adversarial update of the generator and the discriminator.

    Uses the module-level `generator`, `discriminator` and their optimizers.

    Args:
        images: One batch of real training images.

    Returns:
        A `(gen_loss, disc_loss)` tuple with the losses computed for this batch.
    """
    # Generate as many fakes as there are real images in this batch, so the
    # final (smaller) batch of an epoch is not paired with a full BATCH_SIZE
    # of generated images.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, NOISE_DIM])

    # Two tapes: each network's gradients are taken from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# --- 6. Training Loop --- #
print("\nStarting GAN training...")
def train(dataset, epochs):
    """Train the GAN and save a snapshot image after every epoch.

    Relies on the module-level `train_step`, `generator`, `seed` and
    `generate_and_save_images`.

    Args:
        dataset: Iterable of image batches; each batch is passed to `train_step`.
        epochs: Number of full passes over `dataset`.
    """
    for epoch in range(epochs):
        start = tf.timestamp()

        gen_losses = []
        disc_losses = []
        for image_batch in dataset:
            g_loss, d_loss = train_step(image_batch)
            gen_losses.append(g_loss)
            disc_losses.append(d_loss)

        # Snapshot from the fixed seed so progress is comparable across epochs.
        generate_and_save_images(generator, epoch + 1, seed)

        print (f'Time for epoch {epoch + 1} is {tf.timestamp()-start:.2f} sec')
        print(f'  Generator Loss: {tf.reduce_mean(gen_losses):.4f}')
        print(f'  Discriminator Loss: {tf.reduce_mean(disc_losses):.4f}')

    # The last loop iteration already saved the image for epoch == epochs, so
    # rendering it again would only overwrite that file with the same image.
    # A snapshot is still written when the loop never ran.
    if epochs < 1:
        generate_and_save_images(generator, epochs, seed)

# Helper function to generate and save images
def generate_and_save_images(model, epoch, test_input):
    """Render the model's output for `test_input` as an image grid and save it.

    The figure is written to `gan_image_at_epoch_<epoch>.png` (epoch zero-padded
    to four digits) in the current working directory.

    Args:
        model: Callable generator; invoked as `model(test_input, training=False)`.
        epoch: Epoch number used in the figure title and the file name.
        test_input: Batch of noise vectors, one per image to render.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]

    # Keep the 4x4 layout for up to 16 images; grow the square grid beyond
    # that so larger batches do not overflow the subplot index.
    grid_side = 4
    while grid_side * grid_side < num_images:
        grid_side += 1

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i+1)
        plt.imshow(predictions[i, :, :, 0] * 0.5 + 0.5, cmap='gray') # Denormalize for display
        plt.axis('off')

    plt.suptitle(f"Epoch {epoch}")
    plt.savefig(f'gan_image_at_epoch_{epoch:04d}.png')
    plt.close(fig) # Close figure to prevent it from displaying in non-interactive environments

# Create the output directory for generated images (no error if it already exists)
os.makedirs('./generated_gan_images', exist_ok=True)

# generate_and_save_images writes to the current working directory, so switch
# into the output directory for the duration of training.
original_cwd = os.getcwd()
os.chdir('./generated_gan_images')
try:
    train(train_dataset, EPOCHS)
finally:
    # Always restore the working directory, even if training raises.
    os.chdir(original_cwd)

print("\nGAN training complete. Generated images saved in './generated_gan_images' directory.")

# --- 7. Final Image Generation --- #
# Unlike the per-epoch snapshots, this uses freshly sampled noise rather than the fixed seed.
print("\nGenerating final set of images...")
final_noise = tf.random.normal([NUM_EXAMPLES_TO_GENERATE, NOISE_DIM])
final_generated_images = generator(final_noise, training=False)

# Keep the 4x4 layout for up to 16 images; grow the square grid beyond that so
# a larger NUM_EXAMPLES_TO_GENERATE does not overflow the subplot index.
grid_side = 4
while grid_side * grid_side < final_generated_images.shape[0]:
    grid_side += 1

fig = plt.figure(figsize=(8, 8))
for i in range(final_generated_images.shape[0]):
    plt.subplot(grid_side, grid_side, i+1)
    plt.imshow(final_generated_images[i, :, :, 0] * 0.5 + 0.5, cmap='gray') # Denormalize for display
    plt.axis('off')
plt.suptitle("Final Generated Images")
plt.show()
