-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathaugment.py
101 lines (75 loc) · 3.71 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
# Ref: https://arxiv.org/pdf/1904.08779.pdf
import random
import numpy as np
import tensorflow as tf
from tensorflow_addons.image import sparse_image_warp
class SpecAugment():
    '''
    SpecAugment (time warping, frequency masking, time masking) for one mel
    spectrogram.

    The constructor takes the spectrogram as a 2-D array laid out as
    (mel bins, time frames). A 4-D array shaped [batch, mel bins, time frames, 1]
    is also accepted. The array is copied, so the caller's data is never
    modified. Each method converts the working copy to the 4-D layout, updates
    ``self.mel_spectrogram`` and returns the result. The methods can therefore
    be chained: time_warp() -> freq_mask() -> time_mask().

    Augmentation Parameters for policies
    -----------------------------------------
    Policy | W  | F  | m_F |  T  |  p  | m_T
    -----------------------------------------
    None   |  0 |  0 |  -  |  0  |  -  |  -
    -----------------------------------------
    LB     | 80 | 27 |  1  | 100 | 1.0 |  1
    -----------------------------------------
    LD     | 80 | 27 |  2  | 100 | 1.0 |  2
    -----------------------------------------
    SM     | 40 | 15 |  2  |  70 | 0.2 |  2
    -----------------------------------------
    SS     | 40 | 27 |  2  |  70 | 0.2 |  2
    -----------------------------------------

    None : no augmentation (the string 'None' or Python None)
    LB   : LibriSpeech basic
    LD   : LibriSpeech double
    SM   : Switchboard mild
    SS   : Switchboard strong

    W   : Time Warp parameter (warp distance is drawn from [-W, W))
    F   : Frequency Mask parameter (mask width is drawn from [0, F))
    m_F : Number of frequency masks
    T   : Time Mask parameter (mask width is drawn from [0, T))
    p   : Upper bound on a time mask's width as a fraction of the number of
          time frames (width < p * tau)
    m_T : Number of time masks
    '''

    # policy -> (W, F, m_F, T, p, m_T), as listed in the table above.
    # 'None' has no masks (m_F = m_T = 0); its p is irrelevant and set to 1.0.
    _POLICIES = {
        'None': (0, 0, 0, 0, 1.0, 0),
        'LB': (80, 27, 1, 100, 1.0, 1),
        'LD': (80, 27, 2, 100, 1.0, 2),
        'SM': (40, 15, 2, 70, 0.2, 2),
        'SS': (40, 27, 2, 70, 0.2, 2),
    }

    def __init__(self, mel_spectrogram, policy, zero_mean_normalized=True):
        '''
        Args:
            mel_spectrogram: array-like (mel bins, time frames), or 4-D
                [batch, mel bins, time frames, 1]. It is copied.
            policy: one of 'None' (or None), 'LB', 'LD', 'SM', 'SS'.
            zero_mean_normalized: if True, masked regions are set to 0;
                otherwise they are set to the spectrogram's mean value.

        Raises:
            ValueError: if ``policy`` is not a known policy name.
        '''
        # Private copy: masking writes in place and must not touch the caller's array.
        self.mel_spectrogram = np.array(mel_spectrogram)
        self.policy = policy
        self.zero_mean_normalized = zero_mean_normalized
        key = 'None' if policy is None else policy
        if key not in self._POLICIES:
            raise ValueError(
                'Unknown policy %r; expected one of %s' % (policy, sorted(self._POLICIES)))
        self.W, self.F, self.m_F, self.T, self.p, self.m_T = self._POLICIES[key]

    def _as_4d(self):
        '''Return self.mel_spectrogram as a [batch, mel bins, time frames, 1] array.'''
        spec = np.asarray(self.mel_spectrogram)
        if spec.ndim == 2:
            spec = spec.reshape(1, spec.shape[0], spec.shape[1], 1)
        elif spec.ndim != 4:
            raise ValueError(
                'mel_spectrogram must be 2-D (mel bins, time frames) or 4-D '
                '[batch, mel bins, time frames, 1], got shape %s' % (spec.shape,))
        self.mel_spectrogram = spec
        return spec

    def _mask_value(self, spec):
        '''Value written into masked regions: 0 for zero-mean input, else the mean.'''
        return 0 if self.zero_mean_normalized else spec.mean()

    def time_warp(self):
        '''
        Warp the spectrogram along the time axis.

        A control point sits on the horizontal line through the middle mel bin,
        at a random time frame in [W, tau - W). It is moved along the time axis
        by a distance drawn from [-W, W) using ``sparse_image_warp``.

        Returns:
            The Tensor produced by ``sparse_image_warp``, shaped
            [batch, mel bins, time frames, 1]. ``self.mel_spectrogram`` is set
            to a NumPy copy of it, so that freq_mask()/time_mask() can follow.
            If W is 0, or the input has tau <= 2 * W frames, no warp is
            applied. In that case the unchanged spectrogram is returned as a
            Tensor.
        '''
        spec = self._as_4d()
        batch, v, tau = spec.shape[0], spec.shape[1], spec.shape[2]
        if self.W <= 0 or tau <= 2 * self.W:
            # No valid warp point exists: randrange(W, tau - W) would be empty.
            return tf.convert_to_tensor(spec)
        dtype = spec.dtype if np.issubdtype(spec.dtype, np.floating) else np.float32
        image = spec.astype(dtype, copy=False)
        # The frame *index* is the warp coordinate (not the spectrogram value there).
        t = random.randrange(self.W, tau - self.W)
        w = np.random.uniform(-self.W, self.W)  # warp distance in [-W, W)
        # Control points are (row, col) = (mel bin, time frame), one per batch item,
        # built with the image's dtype so sparse_image_warp sees consistent types.
        src_points = np.tile(np.array([[[v // 2, t]]], dtype=dtype), (batch, 1, 1))
        dest_points = np.tile(np.array([[[v // 2, t + w]]], dtype=dtype), (batch, 1, 1))
        warped, _ = sparse_image_warp(image, src_points, dest_points, num_boundary_points=2)
        # Keep a writable NumPy copy so the masking methods can assign into it.
        self.mel_spectrogram = np.array(warped)
        return warped

    def freq_mask(self):
        '''
        Apply m_F frequency masks. Each mask covers f consecutive mel bins.

        The width f is drawn from [0, F) and clamped to the number of mel bins
        v. The start f0 is drawn from [0, v - f], inclusive. The same mask is
        applied to every batch item.

        Returns:
            The 4-D NumPy array [batch, mel bins, time frames, 1]; it is also
            stored in ``self.mel_spectrogram``.
        '''
        spec = self._as_4d()
        v = spec.shape[1]  # no. of mel bins
        fill = self._mask_value(spec)
        for _ in range(self.m_F):
            f = min(int(np.random.uniform(0, self.F)), v)  # [0, F), at most v
            f0 = random.randint(0, v - f)  # [0, v - f], inclusive
            spec[:, f0:f0 + f, :, :] = fill
        return spec

    def time_mask(self):
        '''
        Apply m_T time masks. Each mask covers t consecutive time frames.

        The width t is drawn from [0, min(T, p * tau, tau)), so a mask is never
        wider than p times the number of frames. The start t0 is drawn from
        [0, tau - t], inclusive. The same mask is applied to every batch item.

        Returns:
            The 4-D NumPy array [batch, mel bins, time frames, 1]; it is also
            stored in ``self.mel_spectrogram``.
        '''
        spec = self._as_4d()
        tau = spec.shape[2]  # time frames
        upper = min(self.T, int(self.p * tau), tau)  # enforce the p * tau bound
        fill = self._mask_value(spec)
        for _ in range(self.m_T):
            t = int(np.random.uniform(0, upper))  # [0, upper)
            t0 = random.randint(0, tau - t)  # [0, tau - t], inclusive
            spec[:, :, t0:t0 + t, :] = fill
        return spec