- Install PyTorch from http://pytorch.org
- Run the following command to install additional dependencies:
pip install -r requirements.txt
The dataset contains 200 classes of birds and is adapted from the CUB-200-2011 dataset. Download the training/validation/test images from here. Labels for the test images are not provided.
Run the script main.py to train the model:
python main.py --data [data_dir] --model-name [model_name]
Run python main.py -h for more information, such as the available model names.
As the model trains, checkpoints are saved to the current working directory in files such as model_x.pth.
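If many checkpoints accumulate, a small helper can pick the most recent one to pass to the evaluation step. The sketch below is not part of this repository: latest_checkpoint is a hypothetical name, and it assumes the x in model_x.pth is an integer (for example an epoch number), which this README does not state.

```python
import re
from pathlib import Path


def latest_checkpoint(directory="."):
    """Return the model_<n>.pth path with the largest integer <n>, or None.

    Assumption: checkpoints are named model_<integer>.pth. Files that do
    not match this pattern are ignored.
    """
    pattern = re.compile(r"model_(\d+)\.pth")
    best_path, best_index = None, -1
    for path in Path(directory).iterdir():
        match = pattern.fullmatch(path.name)
        if match is None:
            continue
        index = int(match.group(1))
        # Compare numerically so that model_10.pth ranks above model_9.pth.
        if index > best_index:
            best_path, best_index = path, index
    return best_path
```

Calling latest_checkpoint() in the directory where training ran returns a pathlib.Path that can be used as [model_file], or None if no matching file exists.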
You can take one of the checkpoints and run:
python evaluate.py --data [data_dir] --model-name [model_name] --model [model_file]
This generates a file kaggle.csv that you can upload to the private Kaggle competition website.
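Before uploading, it can help to glance at the generated file. The sketch below is only an illustration: preview_csv is a hypothetical helper, and it assumes nothing about kaggle.csv beyond it being a comma-separated file whose first row is a header, which this README does not confirm.

```python
import csv


def preview_csv(path, n_rows=5):
    """Print the header and first few rows; return (header, data_row_count).

    Assumption: the first row of the file is a header row.
    """
    with open(path, newline="") as f:
        rows = list(csv.reader(f))
    if not rows:
        raise ValueError(f"{path} is empty")
    header, body = rows[0], rows[1:]
    print("columns:", header)
    for row in body[:n_rows]:
        print(row)
    print("data rows:", len(body))
    return header, len(body)
```

For example, preview_csv("kaggle.csv") prints the column names and the number of predictions, which should match the number of test images you downloaded.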
Modifications by Arthur Cahu
Adapted from Rob Fergus and Soumith Chintala https://github.com/soumith/traffic-sign-detection-homework.
Adaptation done by Gul Varol: https://github.com/gulvarol