| Model ID | WER (test-clean) |
|---|---|
| speech_jax_wav2vec2-large-lv60_960h | 3.38% |
| speech_jax_wav2vec2-large-lv60_100h | 5.5% |
# NOTE(review): the repository URL was missing from this command; it is
# presumably https://github.com/vasudevgupta7/speech-jax — confirm before use.
git clone https://github.com/vasudevgupta7/speech-jax
# the editable install below must run from the repository root
cd speech-jax
pip3 install -e .
# JAX and TensorFlow must be installed by you, choosing the builds that match
# your hardware (CPU/GPU).
# You don't need to install JAX and TensorFlow if you are training on Cloud TPUs.
Converting LibriSpeech data to TFRecords
# there are many librispeech splits available
# you can set `-c` & `-s` flags appropriately to download and convert those splits into tfrecords
python3 src/speech_jax/ -c clean -s train.100 -n 100
Uploading TFRecords to a GCS bucket
# Recursively copy the converted split directory into the bucket.
# `-m` lets gsutil run the transfer in parallel.
gsutil -m cp -r \
  "clean.train.100" \
  "gs://librispeech_jax/"
# Repeat the same copy for any other split directory you converted.
Launching Cloud TPUs
# setting env variables for later use by the TPU VM commands
export TPU_NAME=jax-models
export ZONE=us-central1-a
# ACCELERATOR_TYPE is required by `gcloud ... tpu-vm create --accelerator-type`.
# v3-8 is only an example; choose a type available in ${ZONE} for your project/quota.
export ACCELERATOR_TYPE=v3-8
export RUNTIME_VERSION=v2-alpha
# create TPU VM
# ACCELERATOR_TYPE must be exported beforehand (e.g. `export ACCELERATOR_TYPE=v3-8`).
# The `:?` expansion aborts with a message instead of passing an empty value.
gcloud alpha compute tpus tpu-vm create "${TPU_NAME}" \
  --zone "${ZONE}" \
  --accelerator-type "${ACCELERATOR_TYPE:?set ACCELERATOR_TYPE, e.g. v3-8}" \
  --version "${RUNTIME_VERSION}"
# ssh TPU VM
gcloud alpha compute tpus tpu-vm ssh "${TPU_NAME}" --zone "${ZONE}"
Starting training
# switch to relevant directory
# NOTE(review): on a fresh TPU VM the repository must be cloned and installed
# there first; `projects` is presumably relative to the repository root — confirm.
cd projects
# following command will finetune Wav2Vec2-large model on librispeech-960h dataset
# NOTE(review): the conversion step above only prepares train.100; 960h
# fine-tuning presumably needs every training split converted and uploaded.
# NOTE(review): `python3` cannot execute a YAML file directly. This config is
# presumably meant to be passed to a training script
# (e.g. `python3 <entry_script>.py configs/wav2vec2_asr.yaml`) — confirm the entry point.
python3 configs/wav2vec2_asr.yaml
# final model is saved in the huggingface format
# => you can load it directly using `FlaxAutoModel.from_pretrained`
# NOTE(review): FlaxAutoModel loads the bare Wav2Vec2 encoder without the CTC
# head. For transcription, `FlaxWav2Vec2ForCTC.from_pretrained` is likely
# what you want — confirm.
- Google Developer Experts program for providing GCP credits
- TPU Research Cloud for providing free TPU resources