This repository contains a comparison and evaluation of two GAN architectures for conditional image generation tasks:
- Conditional GAN (CGAN)
- Conditional StyleGAN (cStyleGAN)
We focus on evaluating image realism, diversity, and style fidelity using the following datasets:
- MNIST: Handwritten digit dataset. Link
- Shoe vs Sandal vs Boot: A dataset containing images of shoes, sandals, and boots. Link
- Flowers Dataset: A dataset containing images of various types of flowers. Link
- CGAN: Contains all files related to the training and architecture of the Conditional GAN.
- CStyleGAN: Contains all files related to the training and architecture of the Conditional StyleGAN.
- Notebooks: Contains all Jupyter notebooks used for training and experimentation.
- Outputs: Contains model checkpoints, output images, and transition videos of image generation over the training process.
- Python 3.7+
- CUDA 10.1 or higher (for GPU support)
- pip (Python package installer)
- Clone the repository:

  ```bash
  git clone https://github.com/yourusername/GAN-Models-Evaluation.git
  cd GAN-Models-Evaluation
  ```

- Create a virtual environment:

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows use `venv\Scripts\activate`
  ```

- Install the required packages:

  ```bash
  pip install -r requirements.txt
  ```
Make sure the datasets are organized as follows:
```
datasets/
|-- MNIST/
|   |-- train/
|   |-- test/
|
|-- Shoe_vs_Sandal_vs_Boot/
|   |-- shoe/
|   |-- sandal/
|   |-- boot/
|
|-- Flowers/
    |-- daisy/
    |-- lavender/
    |-- rose/
    |-- lily/
    |-- sunflower/
```
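With this layout, the class-labeled datasets (e.g., Shoe_vs_Sandal_vs_Boot, Flowers) can be loaded with a generic folder-based loader such as `torchvision.datasets.ImageFolder`, which infers labels from subdirectory names. This is a minimal sketch for illustration; the notebooks and training scripts in this repository may load data differently:

```python
import torch
from torchvision import datasets, transforms

# Each class subdirectory (shoe/, sandal/, boot/) becomes one integer label.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # scale RGB images to [-1, 1]
])
dataset = datasets.ImageFolder("datasets/Shoe_vs_Sandal_vs_Boot", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
```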
- Navigate to the Notebooks directory:

  ```bash
  cd Notebooks
  ```

- Launch Jupyter Notebook:

  ```bash
  jupyter notebook
  ```

- Open the respective notebook for CGAN or CStyleGAN and run the cells to start training.
For running on a server or a local machine without Jupyter:

- Navigate to the respective model directory (CGAN or CStyleGAN):

  ```bash
  cd CGAN  # or `cd CStyleGAN`
  ```

- Run the training script:

  ```bash
  python train.py -d path/to/data_dir -m path/to/model_save_path -a path/to/animation_save_path -e path/to/eval_images_save_path -t path/to/training_plot_path
  ```
The `train.py` script accepts the following command-line arguments:

- `-d`, `--DATA_DIR`: Path to the training data directory (required).
- `-m`, `--MODEL_PATH`: Path to save the trained model (required).
- `-a`, `--ANIMATION_PATH`: Path to save the animation of the training process (required).
- `-e`, `--EVAL_PATH`: Path to save evaluation images (required).
- `-t`, `--TRAINING_PLOT_PATH`: Path to save training plots (required).
Example:

```bash
python train.py -d datasets/MNIST/train -m outputs/cgan_model.pth -a outputs/cgan_animation.mp4 -e outputs/cgan_eval_images/ -t outputs/cgan_training_plot.png
```
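For reference, these flags could be wired up with Python's standard `argparse` module roughly as below. This is an illustrative sketch of the documented interface, not the actual contents of `train.py`:

```python
import argparse

def parse_args():
    """Parse the command-line flags documented above (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="Train a conditional GAN.")
    parser.add_argument("-d", "--DATA_DIR", required=True,
                        help="Path to the training data directory.")
    parser.add_argument("-m", "--MODEL_PATH", required=True,
                        help="Path to save the trained model.")
    parser.add_argument("-a", "--ANIMATION_PATH", required=True,
                        help="Path to save the animation of the training process.")
    parser.add_argument("-e", "--EVAL_PATH", required=True,
                        help="Path to save evaluation images.")
    parser.add_argument("-t", "--TRAINING_PLOT_PATH", required=True,
                        help="Path to save training plots.")
    return parser.parse_args()
```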
You can set hyperparameters inside the `train.py` script. Typical hyperparameters you might want to adjust include (see the sketch after this list):
- Learning rate
- Batch size
- Number of epochs
- Noise dimension
- Specific architecture details
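For illustration, such settings often live as constants near the top of a training script. The names and values below are assumptions, not the actual defaults in `train.py`:

```python
# Hypothetical hyperparameter block; names and values are assumptions,
# not the actual settings in train.py.
LEARNING_RATE = 2e-4  # step size for the optimizers
BATCH_SIZE = 64       # images per training batch
NUM_EPOCHS = 100      # full passes over the training set
NOISE_DIM = 100       # dimensionality of the generator's input noise
NUM_CLASSES = 10      # number of class labels to condition on
```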
The CGAN model is a type of GAN where both the generator and discriminator are conditioned on some extra information, such as class labels. This enables the model to generate images that belong to a specific class.
- Architecture:
- Generator: Takes a noise vector and class label as input and generates an image.
- Discriminator: Takes an image and class label as input and outputs a probability of the image being real or fake.
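To make the conditioning concrete, below is a minimal PyTorch sketch of a CGAN generator/discriminator pair that injects the class label through an embedding concatenated to the inputs. It illustrates the general CGAN idea only and is not the architecture implemented in this repository:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps (noise vector, class label) to a flattened image."""
    def __init__(self, noise_dim=100, num_classes=10, img_dim=28 * 28):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, z, labels):
        # Condition on the class by concatenating the label embedding to the noise.
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(x)

class Discriminator(nn.Module):
    """Scores (image, class label) pairs as real or fake."""
    def __init__(self, num_classes=10, img_dim=28 * 28):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            nn.Linear(img_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the image is real
        )

    def forward(self, img, labels):
        x = torch.cat([img, self.label_emb(labels)], dim=1)
        return self.net(x)
```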
The cStyleGAN model builds upon the StyleGAN architecture, incorporating class labels to generate images of specific categories with high style fidelity and diversity.
- Architecture:
- Generator: Uses style vectors to control different aspects of the generated image and is conditioned on class labels.
- Discriminator: Evaluates the realism of images conditioned on class labels.
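The key difference from the plain CGAN is where the label enters: in a StyleGAN-style design, conditioning is commonly applied in the mapping network that turns the latent code into a style vector. The sketch below is an assumption for illustration, not this repository's implementation:

```python
import torch
import torch.nn as nn

class ConditionalMappingNetwork(nn.Module):
    """Maps a latent code z and a class label to a style vector w."""
    def __init__(self, z_dim=512, w_dim=512, num_classes=10, num_layers=4):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, z_dim)
        layers, in_dim = [], z_dim * 2   # z concatenated with the label embedding
        for _ in range(num_layers):      # StyleGAN's mapping network uses 8 layers
            layers += [nn.Linear(in_dim, w_dim), nn.LeakyReLU(0.2)]
            in_dim = w_dim
        self.mapping = nn.Sequential(*layers)

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        # w is later injected into each generator block (e.g., via AdaIN or
        # weight modulation) to control style at different scales.
        return self.mapping(x)
```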
The `Outputs` directory contains:
- Model checkpoints: Saved models during training.
- Output images: Generated images at different training stages.
- Transition videos: Videos showing the progression of image generation over the training process.
We evaluate the models based on:
Fréchet Inception Distance (FID): In FID, we use the Inception network to extract features from an intermediate layer. Then we model the data distribution for these features using a multivariate Gaussian distribution with mean $\mu$ and covariance $\Sigma$, computed separately for the real and generated images. The FID between the two distributions is:

$$\text{FID} = ||\mu_{\text{real}} - \mu_{\text{gen}}||^2 + \text{Tr}\left(\Sigma_{\text{real}} + \Sigma_{\text{gen}} - 2(\Sigma_{\text{real}}\Sigma_{\text{gen}})^{1/2}\right)$$

Where:

- $||\cdot||$ denotes the Euclidean distance between vectors.
- $\text{Tr}(\cdot)$ represents the trace operator, which computes the sum of the diagonal elements of a matrix.
- $(\Sigma_{\text{real}}\Sigma_{\text{gen}})^{1/2}$ denotes the matrix square root of the product of the covariance matrices.
This equation quantifies the dissimilarity between the feature distributions of the real and generated images, considering both their means and covariances. A lower FID score indicates a higher similarity between the distributions, suggesting that the generated images better match the characteristics of the real images.
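Given the feature means and covariances, the formula can be evaluated with NumPy and SciPy. This helper is a sketch for illustration and is not taken from this repository's code:

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu_real, sigma_real, mu_gen, sigma_gen):
    """Compute FID from the Gaussian statistics of Inception features."""
    diff = mu_real - mu_gen
    # Matrix square root of the product of the covariance matrices.
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_gen, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma_real + sigma_gen - 2.0 * covmean))
```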
| Dataset | Model | FID Score |
|---|---|---|
| MNIST | cGAN | 2.5 |
| MNIST | cStyleGAN | 3.8 |
| Shoe vs Sandal vs Boot | cGAN | 24.3 |
| Shoe vs Sandal vs Boot | cStyleGAN | 10.1 |
| Flowers | cGAN | 30.7 |
| Flowers | cStyleGAN | 20.5 |
The training timelapses for all the datasets can be found in `Outputs/CGAN/results`.
CGAN - MNIST Results:
CGAN - Shoe Results:
CGAN - Flower Results:
The training timelapses for all the datasets can be found in `Outputs/StyleGAN/results`.
CStyleGAN - MNIST Results:
CStyleGAN - Shoe Results:
CStyleGAN - Flower Results:
- This project uses datasets provided by M. Stephenson and others.
- Inspired by the work on GANs and StyleGAN by leading researchers in the field.