This project demonstrates the training and evaluation of a neural network on the MNIST dataset, a popular dataset of handwritten digits. The goal is to classify each image into one of the ten digit classes (0-9). The project covers data loading, dataset creation, model definition, and training with k-fold cross-validation.
Before running the code, make sure the required libraries are installed. You can install them with pip, ideally inside a Python virtual environment or a conda environment:
pip install torch scikit-learn imbalanced-learn matplotlib tqdm
The project is structured as follows:
- README.md: This document, which gives an overview of the project.
- mnist.py: The Python script containing the code for data loading, dataset creation, model definition, and training.
- data/: A directory that should contain the MNIST dataset files (training.pt and test.pt).
The project begins by importing the necessary libraries, including PyTorch, scikit-learn, imbalanced-learn, matplotlib, and tqdm. It also checks for the availability of a CUDA-compatible GPU and sets the device accordingly.
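The device check described above usually comes down to a single line. The sketch below shows the common PyTorch idiom; the variable name `device` is an assumption, not necessarily the name used in mnist.py.

```python
import torch

# Use a CUDA-capable GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The model and every batch must then be moved explicitly,
# e.g. model.to(device) and xb.to(device).
```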
The MNIST dataset is loaded from the data files training.pt and test.pt in the data/ directory.
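A minimal loading sketch follows. It assumes each .pt file holds an `(images, labels)` tuple saved with `torch.save`, as in torchvision's legacy processed-MNIST format (a uint8 tensor of shape (N, 28, 28) plus a label tensor of shape (N,)). The helper name `load_split` is hypothetical.

```python
import torch


def load_split(path):
    """Hypothetical helper: load one MNIST split stored as an (images, labels) tuple."""
    # Assumption: the file was written with torch.save((images, labels), path).
    images, labels = torch.load(path)
    # Sanity checks: one 28x28 image per label.
    if images.ndim != 3 or tuple(images.shape[1:]) != (28, 28):
        raise ValueError(f"expected images of shape (N, 28, 28), got {tuple(images.shape)}")
    if images.shape[0] != labels.shape[0]:
        raise ValueError("number of images and labels differ")
    return images, labels


# Usage, assuming the files sit in data/ as described in the project layout:
# x_train, y_train = load_split("data/training.pt")
# x_test, y_test = load_split("data/test.pt")
```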
A custom CTDataset class is defined to preprocess the dataset. It normalizes the pixel values to the range [0, 1] and provides a convenient way to access the data.
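A dataset class like this is typically a thin `torch.utils.data.Dataset` subclass. The sketch below is an assumption about how CTDataset might look: the constructor taking `(images, labels)` tensors and the integer-label output are illustrative choices, not confirmed by the project. Dividing by 255 maps MNIST's uint8 pixel values (0 to 255) into [0, 1].

```python
import torch
from torch.utils.data import Dataset


class CTDataset(Dataset):
    """Sketch: wraps MNIST tensors and rescales pixels to [0, 1] (constructor is assumed)."""

    def __init__(self, images, labels):
        # uint8 pixels in 0..255 become float32 values in [0, 1].
        self.x = images.float() / 255.0
        # Integer class indices, which nn.CrossEntropyLoss accepts directly.
        self.y = labels.long()

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


# Usage, assuming data/training.pt stores an (images, labels) tuple:
# train_ds = CTDataset(*torch.load("data/training.pt"))
```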
The neural network (MyNeuralNet) consists of three fully connected layers with ReLU activations. The input size is 28x28 = 784 pixels, and the output size is 10 (one per digit class).
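One way to write such a network in PyTorch is sketched below. Only the 784 inputs, the 10 outputs, the three linear layers, and the ReLU activations come from the description; the hidden sizes (100 and 50) are hypothetical placeholders.

```python
import torch
import torch.nn as nn


class MyNeuralNet(nn.Module):
    """Sketch of a 3-layer fully connected classifier; hidden sizes are assumptions."""

    def __init__(self, hidden1=100, hidden2=50):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.reshape(-1, 28 * 28)  # flatten each 28x28 image into a 784-vector
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Raw logits: nn.CrossEntropyLoss applies log-softmax internally.
        return self.fc3(x)
```

Returning logits rather than probabilities keeps the model compatible with `nn.CrossEntropyLoss`, and `argmax` over the logits still gives the predicted digit.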
The train_model_kfold function trains the model with k-fold cross-validation. It uses stochastic gradient descent (SGD) with a learning rate of 0.01 and a cross-entropy loss. For each fold:
- Train and validation subsets are created.
- Data oversampling is performed using RandomOverSampler to handle class imbalance.
- Dataloaders are created with a batch size of 20.
- The model is trained for a specified number of epochs.
- Validation accuracy is calculated and printed for each fold.
The script reports the validation accuracy of each fold during k-fold cross-validation. Taken together, these per-fold scores indicate how well the model generalizes to unseen data.
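A common way to condense per-fold scores is the mean and standard deviation across folds. The helper below is a hypothetical addition, not part of mnist.py; it uses only the standard library and accepts any list of accuracies, such as the one returned by a k-fold training loop.

```python
import statistics


def summarize_folds(accuracies):
    """Hypothetical helper: return (mean, sample standard deviation) of fold accuracies."""
    mean = statistics.mean(accuracies)
    # stdev needs at least two values; a single fold has no spread to report.
    std = statistics.stdev(accuracies) if len(accuracies) > 1 else 0.0
    return mean, std


# Usage: mean, std = summarize_folds(fold_accuracies)
#        print(f"accuracy = {mean:.4f} +/- {std:.4f}")
```

A small standard deviation suggests the model's performance does not depend heavily on which images land in the validation fold.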
A final plot displays the predicted digits for a subset of images from the MNIST dataset.
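A grid of images titled with their predicted labels can be drawn with matplotlib as sketched below. The function name `plot_predictions`, the 2x5 grid, and the figure size are assumptions for illustration; the model is expected to map a batch of (28, 28) images to 10 logits.

```python
import matplotlib.pyplot as plt
import torch


def plot_predictions(model, images, n_rows=2, n_cols=5):
    """Sketch: show the first n_rows*n_cols images, each titled with its predicted digit."""
    n = n_rows * n_cols
    model.eval()
    with torch.no_grad():
        # Assumes `images` is on the same device as `model`.
        preds = model(images[:n]).argmax(dim=1).cpu()
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for ax, img, pred in zip(axes.flat, images[:n], preds):
        ax.imshow(img.cpu().numpy(), cmap="gray")
        ax.set_title(f"Predicted: {pred.item()}")
        ax.axis("off")
    fig.tight_layout()
    return fig


# Usage: fig = plot_predictions(model, x_test_float); plt.show()
# (x_test_float is a hypothetical name for test images already scaled to [0, 1].)
```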