-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathfsapi.py
executable file
·96 lines (82 loc) · 2.71 KB
/
fsapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import json
import torch
import numpy as np
from fs_two.model import FastSpeech2
class FSTWOapi:
    """Inference wrapper around a ``FastSpeech2`` model.

    Builds the model from a configuration object, optionally restores its
    weights from a checkpoint, and exposes :meth:`generate` to run the model
    on a phoneme sequence.
    """

    def __init__(self, config, device=0):
        """Build the model and load the checkpoint named in ``config``.

        Args:
            config: Configuration object. This constructor reads
                ``config.tts.weights_path``, ``config.preprocess_config`` and
                ``config.model_config``. When a weights path is set,
                ``config.preprocess_config.path.preprocessed_path`` is
                overwritten with the directory containing the checkpoint.
            device: Target passed to ``.to(...)`` for the model and, later,
                for the input tensors.
        """
        weights_path = config.tts.weights_path
        if weights_path is not None:
            # speakers.json is looked up in the same folder as the checkpoint.
            # The guard keeps a missing weights path from crashing here; in
            # that case the preprocessed path already in the config is used.
            config.preprocess_config.path.preprocessed_path = os.path.dirname(
                weights_path
            )
        self.speakers_dict, self.speaker_names = load_speakers_json(
            config.preprocess_config.path.preprocessed_path
        )
        self.model = FastSpeech2(
            config.preprocess_config,
            config.model_config,
            len(self.speaker_names),
        ).to(device)

        # Load checkpoint if one was configured.
        self.weights_path = weights_path
        if weights_path is not None:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            checkpoint = torch.load(weights_path, map_location="cpu")
            state = checkpoint["model"]
            # The checkpoint keeps the speaker embedding under its own key;
            # copy it into the state dict before loading.
            state["speaker_emb.weight"] = checkpoint["embedding"]
            self.model.load_state_dict(state)

        self.cfg = config
        self.device = device
        # TODO: get the right restore step
        self.restore_step = 0

    def generate(
        self,
        phonemes,
        duration_control=1.0,
        pitch_control=1.0,
        energy_control=1.0,
        speaker_name=None,
    ):
        """Run the model on ``phonemes`` and return its post-net output.

        Args:
            phonemes: NumPy array of phoneme ids. ``len(phonemes[0])`` is
                used as the source length, so this presumably has shape
                ``(1, seq_len)`` -- TODO confirm against callers.
            duration_control: Forwarded to the model as ``d_control``.
            pitch_control: Forwarded to the model as ``p_control``.
            energy_control: Forwarded to the model as ``e_control``.
            speaker_name: Key into the loaded speakers mapping. When ``None``
                the model receives ``None`` as its speaker argument.
                NOTE(review): confirm the model accepts ``None`` here.

        Returns:
            The tenth element of the model's output tuple
            (``postnet_output``).

        Raises:
            ValueError: If ``speaker_name`` is given but is not a known
                speaker.
        """
        # Default to "no speaker" so the call below never sees an unbound name.
        speaker = None
        if speaker_name is not None:
            if speaker_name not in self.speakers_dict:
                raise ValueError(
                    f"Speaker {speaker_name} was not found in speakers.json"
                )
            speaker_id = self.speakers_dict[speaker_name]
            speaker = torch.tensor(speaker_id).long().unsqueeze(0)
            speaker = speaker.to(self.device)

        self.model.eval()
        src_len = np.array([len(phonemes[0])])
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            result = self.model(
                speaker,
                torch.from_numpy(phonemes).long().to(self.device),
                torch.from_numpy(src_len).to(self.device),
                max(src_len),
                d_control=duration_control,
                p_control=pitch_control,
                e_control=energy_control,
            )
        # The model returns a 12-tuple; only the post-net output is exposed.
        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            src_masks,
            mel_masks,
            src_lens,
            mel_lens,
            postnet_output,
            pitch_mean,
            pitch_std,
        ) = result
        return postnet_output
def load_speakers_json(dir_path):
    """Load the speaker-name mapping stored in ``<dir_path>/speakers.json``.

    Args:
        dir_path: Directory expected to contain ``speakers.json``.

    Returns:
        A ``(speakers, names)`` tuple: the parsed JSON object and a list of
        its keys. If the file does not exist, a message is printed and an
        empty mapping and empty list are returned.
    """
    json_path = os.path.join(dir_path, "speakers.json")
    if not os.path.exists(json_path):
        print(f'Did not find speakers.json at {dir_path}')
        # Fall back to "no speakers" so the return below has a value to use.
        return {}, []

    with open(json_path, "r", encoding="utf-8") as f:
        speakers = json.load(f)
    return speakers, list(speakers.keys())