-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaugment.py
executable file
·167 lines (145 loc) · 6.8 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import random
import torch as th
from torch import nn
from torch.nn import functional as F
"""
Augmentations from:
https://github.com/facebookresearch/denoiser/blob/e27bf5cdcda6e6ffc3a332763411d864210f94c8/denoiser/augment.py
"""
class Remix(nn.Module):
    """Remix.

    Re-pairs noises and clean speech inside a batch: the noise entries are
    shuffled along the batch axis while the clean entries keep their order.
    """

    def forward(self, sources):
        """Return the stacked `[noise, clean]` pair with the noise shuffled.

        `sources` must unpack into exactly two items, `(noise, clean)`; the
        first axis of `noise` is treated as the batch axis.
        """
        noise, clean = sources
        batch_size = noise.shape[0]
        # Sorting i.i.d. uniform keys yields a uniformly random permutation
        # of the batch indices, computed on the same device as the data.
        keys = th.rand(batch_size, device=noise.device)
        order = th.argsort(keys, dim=0)
        shuffled_noise = noise[order]
        return th.stack([shuffled_noise, clean])
class RevEcho(nn.Module):
    """
    Hacky Reverb but runs on GPU without slowing down training.
    This reverb adds a succession of attenuated echoes of the input
    signal to itself. Intuitively, the delay of the first echo will happen
    after roughly 2x the radius of the room and is controlled by `first_delay`.
    Then RevEcho keeps adding echoes with the same delay and further attenuation
    until the amplitude ratio between the last and first echo is 1e-3.
    The attenuation factor and the number of echoes to add are controlled
    by RT60 (measured in seconds). RT60 is the average time to get to -60dB
    (remember volume is measured over the squared amplitude so this matches
    the 1e-3 ratio).

    At each call to RevEcho, `first_delay`, `initial` and `RT60` are
    sampled from their range. Then, to prevent this reverb from being too regular,
    the delay time is resampled uniformly within `first_delay +- 10%`,
    as controlled by the `jitter` parameter. Finally, for a denser reverb,
    multiple trains of echoes are added with different jitter noises.

    Args:
        - proba: probability of applying the reverb on a given call; when
            the reverb is skipped the input is returned as is.
        - initial: amplitude of the first echo as a fraction
            of the input signal. For each sample, actually sampled from
            `[0, initial]`. Larger values mean louder reverb. Physically,
            this would depend on the absorption of the room walls.
        - rt60: range of values to sample the RT60 in seconds, i.e.
            after RT60 seconds, the echo amplitude is 1e-3 of the first echo.
            The default values follow the recommendations of
            https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, Section 2.4.
            Physically this would also be related to the absorption of the
            room walls and there is likely a relation between `RT60` and
            `initial`, which we ignore here.
        - first_delay: range of values to sample the first echo delay in seconds.
            The default values are equivalent to sampling a room of 3 to 10 meters.
        - repeat: how many trains of echoes with different jitters to add.
            Higher values mean a denser reverb.
        - jitter: jitter used to make each repetition of the reverb echo train
            slightly different. For instance a jitter of 0.1 means
            the delay between two echoes will be in the range `first_delay +- 10%`,
            with the jittering noise being resampled after each single echo.
        - keep_clean: fraction of the reverb of the clean speech to add back
            to the ground truth. 0 = dereverberation, 1 = no dereverberation.
        - sample_rate: sample rate of the input signals.
    """

    def __init__(self, proba=0.5, initial=0.3, rt60=(0.3, 1.3), first_delay=(0.01, 0.03),
                 repeat=3, jitter=0.1, keep_clean=0.1, sample_rate=16000):
        super().__init__()
        self.proba = proba
        self.initial = initial
        self.rt60 = rt60
        self.first_delay = first_delay
        self.repeat = repeat
        self.jitter = jitter
        self.keep_clean = keep_clean
        self.sample_rate = sample_rate

    def _reverb(self, source, initial, first_delay, rt60):
        """
        Return the reverb for a single source.

        `source` is indexed with three axes and the last one is treated as
        time. The result has the same shape as `source` and contains only the
        echoes (the dry signal is not included). `source` is not modified.
        """
        length = source.shape[-1]
        reverb = th.zeros_like(source)
        for _ in range(self.repeat):
            frac = 1  # what fraction of the first echo amplitude is still here
            echo = initial * source
            while frac > 1e-3:
                # First jitter noise for the delay
                jitter = 1 + self.jitter * random.uniform(-1, 1)
                # At least one sample of delay, at most the full signal length
                # (in which case the echo is entirely shifted out).
                delay = min(
                    1 + int(jitter * first_delay * self.sample_rate),
                    length)
                # Delay the echo in time by padding with zero on the left
                echo = F.pad(echo[:, :, :-delay], (delay, 0))
                reverb += echo

                # Second jitter noise for the attenuation
                jitter = 1 + self.jitter * random.uniform(-1, 1)
                # We want, with `d` the attenuation per echo,
                # d**(rt60 / first_delay) = 1e-3, i.e.
                # log10(d) = -3 * first_delay / rt60 (jittered below).
                attenuation = 10**(-3 * jitter * first_delay / rt60)
                echo *= attenuation
                frac *= attenuation
        return reverb

    def forward(self, wav):
        """
        Apply the reverb to `wav` with probability `proba`.

        `wav` must unpack into `(noise, clean)`. When the reverb is skipped,
        `wav` itself is returned; otherwise a new stacked `[noise, clean]`
        tensor is returned and the input `wav` is left untouched.
        """
        if random.random() >= self.proba:
            return wav
        noise, clean = wav
        # Sample characteristics for the reverb
        initial = random.random() * self.initial
        first_delay = random.uniform(*self.first_delay)
        rt60 = random.uniform(*self.rt60)

        reverb_noise = self._reverb(noise, initial, first_delay, rt60)
        reverb_clean = self._reverb(clean, initial, first_delay, rt60)

        # All updates below are out-of-place: `noise` and `clean` are views of
        # the caller's `wav`, so in-place `+=` would silently modify the input.
        # Reverb for the noise is always added back to the noise
        noise = noise + reverb_noise
        # Split clean reverb among the clean speech and noise
        noise = noise + (1 - self.keep_clean) * reverb_clean
        clean = clean + self.keep_clean * reverb_clean

        return th.stack([noise, clean])
class Shift(nn.Module):
    """Shift.

    Crops `shift` samples off the last axis of a 4-axis input. In training
    mode the crop starts at a random offset in `[0, shift)`; in evaluation
    mode the crop always starts at 0.
    """

    def __init__(self, shift=8192, same=False):
        """__init__.

        :param shift: randomly shifts the signals up to a given factor
        :param same: shifts both clean and noisy files by the same factor
        """
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        """forward.

        :param wav: tensor with exactly four axes; they are unpacked as
            `(sources, batch, channels, length)`
        :return: `wav` unchanged when `shift <= 0`, otherwise a tensor whose
            last axis has `length - shift` entries
        :raises ValueError: if `shift > 0` and the last axis is shorter
            than `shift`
        """
        sources, batch, channels, length = wav.shape
        if self.shift <= 0:
            return wav
        if length < self.shift:
            # A negative output length would otherwise produce a wrongly
            # sized slice in eval mode and an opaque arange error in training.
            raise ValueError(
                f"Shift requires at least {self.shift} samples on the last axis, "
                f"got {length}.")
        length = length - self.shift

        if not self.training:
            # Deterministic crop: keep the first `length` samples.
            return wav[..., :length]

        # One random offset per (source, batch) item, or per batch item only
        # when `same` is set so that all sources share the offset.
        offsets = th.randint(
            self.shift,
            [1 if self.same else sources, batch, 1, 1], device=wav.device)
        offsets = offsets.expand(sources, -1, channels, -1)
        indexes = th.arange(length, device=wav.device)
        # Broadcasting `indexes + offsets` gives, for each item, the window
        # `[offset, offset + length)` to gather along the last axis.
        return wav.gather(3, indexes + offsets)
class Augment(nn.Module):
    """Augmentation pipeline applying Remix, RevEcho and Shift in that order."""

    def __init__(self, shift):
        """__init__.

        :param shift: forwarded as the `shift` argument of the Shift stage
        """
        super().__init__()
        stages = [
            Remix(),
            RevEcho(),
            Shift(shift=shift),
        ]
        self.augment = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through every stage of the pipeline and return the result."""
        augmented = self.augment(x)
        return augmented