mms_ars.py
import os
import subprocess

from dotenv import dotenv_values
from git import Repo
from pydub import AudioSegment

config = dotenv_values(".env")

"""
Configuration Guide

To set up this script properly, you need to provide certain environment variables
through a `.env` file. Below is a template of what it should contain:

CURRENT_DIR=/path/to/current/dir
AUDIO_SAMPLES_DIR=/path/to/audio_samples
FAIRSEQ_DIR=/path/to/fairseq
VIDEO_FILE=/path/to/video/file
AUDIO_FILE=/path/to/audio/file
RESAMPLED_AUDIO_FILE=/path/to/resampled/audio/file
TMPDIR=/path/to/tmp
PYTHONPATH=.
PREFIX=INFER
HYDRA_FULL_ERROR=1
USER=micro
MODEL=/path/to/fairseq/models_new/mms1b_all.pt  # Use the full path here
LANG=eng

Additionally, you need to configure the file
fairseq/examples/mms/asr/config/infer_common.yaml. In that YAML file, use a full
path for the checkpoint field, like this:

checkpoint: /path/to/checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}

Without this change, you might encounter permission issues unless you are running
the application in a container.

If you plan to use a CPU for computation, you also need to add the following
top-level directive to the YAML file:

common:
  cpu: true
"""


def git_clone(url, path):
    """
    Clones a git repository.

    Parameters:
        url (str): The URL of the git repository
        path (str): The local path where the git repository will be cloned
    """
    if not os.path.exists(path):
        Repo.clone_from(url, path)


def create_dirs(*dir_paths):
    """
    Creates directories.

    Parameters:
        *dir_paths (str): Directory paths to be created
    """
    for dir_path in dir_paths:
        os.makedirs(dir_path, exist_ok=True)


def install_requirements(requirements):
    """
    Installs pip packages.

    Parameters:
        requirements (list): List of packages to install
    """
    subprocess.check_call(["pip", "install"] + requirements)


def download_file(url, path):
    """
    Downloads a file with wget.

    Parameters:
        url (str): URL of the file to be downloaded
        path (str): Directory the file will be saved into (wget's -P prefix)
    """
    subprocess.check_call(["wget", "-P", path, url])


def convert_video_to_audio(video_path, audio_path):
    """
    Converts a video file to a 16 kHz audio file using ffmpeg.

    Parameters:
        video_path (str): Path to the video file
        audio_path (str): Path to the output audio file
    """
    subprocess.check_call(["ffmpeg", "-i", video_path, "-ar", "16000", audio_path])


def run_inference(model, lang, audio):
    """
    Runs the MMS ASR inference.

    Parameters:
        model (str): Path to the model file
        lang (str): Language of the audio file
        audio (str): Path to the audio file
    """
    subprocess.check_call(
        [
            "python",
            "examples/mms/asr/infer/mms_infer.py",
            "--model",
            model,
            "--lang",
            lang,
            "--audio",
            audio,
        ]
    )


def resample_audio(audio_path, new_audio_path, new_sample_rate):
    """
    Resamples an audio file.

    Parameters:
        audio_path (str): Path to the current audio file
        new_audio_path (str): Path to the output audio file
        new_sample_rate (int): New sample rate in Hz
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(new_sample_rate)
    audio.export(new_audio_path, format='wav')


if __name__ == "__main__":
    current_dir = config['CURRENT_DIR']
    tmp_dir = config['TMPDIR']
    fairseq_dir = config['FAIRSEQ_DIR']
    video_file = config['VIDEO_FILE']
    audio_file = config['AUDIO_FILE']
    audio_file_resampled = config['RESAMPLED_AUDIO_FILE']
    model_path = config['MODEL']
    lang = config['LANG']
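
    # The calls below look like one-time setup and preprocessing steps that were
    # commented out after being run once: cloning fairseq, creating the tmp dir,
    # installing fairseq in editable mode, downloading the MMS checkpoint, and
    # converting/resampling the input media. Uncomment them as needed on a first
    # run; the order shown matches the functions defined above.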
    # git_clone('https://github.com/pytorch/fairseq', 'fairseq')
    # create_dirs(tmp_dir)
    # install_requirements(['--editable', './'])
    # download_file('https://dl.fbaipublicfiles.com/mms/asr/mms1b_all.pt', './models_new')
    # convert_video_to_audio(video_file, audio_file)
    # resample_audio(audio_file, audio_file_resampled, 16000)

    os.chdir(fairseq_dir)
    run_inference(model_path, lang, audio_file_resampled)
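
    # Usage sketch (an assumption, not stated in the original file): with the
    # .env populated and the infer_common.yaml edits from the Configuration
    # Guide in place, run the script directly, e.g. `python mms_ars.py`.
    # run_inference() then invokes examples/mms/asr/infer/mms_infer.py from
    # inside FAIRSEQ_DIR on the resampled 16 kHz audio.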