diff --git a/xtts-streaming/Dockerfile b/xtts-streaming/Dockerfile
new file mode 100644
index 00000000..798ad460
--- /dev/null
+++ b/xtts-streaming/Dockerfile
@@ -0,0 +1,12 @@
+FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED True
+ENV NVIDIA_DISABLE_REQUIRE=0
+
+RUN apt-get update && \
+    apt-get install --no-install-recommends -y sox libsox-fmt-all curl wget gcc git git-lfs build-essential libaio-dev libsndfile1 ssh ffmpeg && \
+    apt-get clean && apt-get -y autoremove
+
+COPY requirements.txt .
+RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
+    && python -m pip cache purge
diff --git a/xtts-streaming/config.yaml b/xtts-streaming/config.yaml
index 07996915..a53b035a 100644
--- a/xtts-streaming/config.yaml
+++ b/xtts-streaming/config.yaml
@@ -1,14 +1,14 @@
+base_image:
+  image: htrivedi05/xtts-streaming
+  python_executable_path: /opt/conda/bin/python
 environment_variables:
   COQUI_TOS_AGREED: '1'
 external_package_dirs: []
 model_metadata: {}
 model_name: XTTS Streaming
-python_version: py310
-requirements_file: ./requirements.txt
 resources:
-  accelerator: T4
+  accelerator: H100
   cpu: '3'
   memory: 10Gi
   use_gpu: true
 secrets: {}
-system_packages: []
diff --git a/xtts-streaming/model/model.py b/xtts-streaming/model/model.py
index 3c7d10b6..b59a87d9 100644
--- a/xtts-streaming/model/model.py
+++ b/xtts-streaming/model/model.py
@@ -1,8 +1,5 @@
-import base64
-import io
 import logging
 import os
-import wave
 
 import numpy as np
 import torch
@@ -23,7 +20,7 @@ def __init__(self, **kwargs):
     def load(self):
         device = "cuda"
         model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-        logging.info("⏳Downloading model")
+        logging.info("⏳ Downloading model")
         ModelManager().download_model(model_name)
         model_path = os.path.join(
             get_user_data_dir("tts"), model_name.replace("/", "--")
@@ -32,8 +29,12 @@ def load(self):
         config = XttsConfig()
         config.load_json(os.path.join(model_path, "config.json"))
         self.model = Xtts.init_from_config(config)
-        self.model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
+        # self.model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
+        self.model.load_checkpoint(
+            config, checkpoint_dir=model_path, eval=True, use_deepspeed=True
+        )
         self.model.to(device)
+        # self.compiled_model = torch.compile(self.model.inference_stream)
 
         self.speaker = {
             "speaker_embedding": self.model.speaker_manager.speakers[SPEAKER_NAME][
@@ -51,7 +52,18 @@ def load(self):
             .half()
             .tolist(),
         }
-        logging.info("🔥Model Loaded")
+
+        self.speaker_embedding = (
+            torch.tensor(self.speaker.get("speaker_embedding"))
+            .unsqueeze(0)
+            .unsqueeze(-1)
+        )
+        self.gpt_cond_latent = (
+            torch.tensor(self.speaker.get("gpt_cond_latent"))
+            .reshape((-1, 1024))
+            .unsqueeze(0)
+        )
+        logging.info("🔥 Model Loaded")
 
     def wav_postprocess(self, wav):
         """Post process the output waveform"""
@@ -66,32 +78,21 @@ def predict(self, model_input):
         text = model_input.get("text")
         language = model_input.get("language", "en")
         chunk_size = int(
-            model_input.get("chunk_size", 150)
+            model_input.get("chunk_size", 20)
         )  # Ensure chunk_size is an integer
         add_wav_header = False
 
-        speaker_embedding = (
-            torch.tensor(self.speaker.get("speaker_embedding"))
-            .unsqueeze(0)
-            .unsqueeze(-1)
-        )
-        gpt_cond_latent = (
-            torch.tensor(self.speaker.get("gpt_cond_latent"))
-            .reshape((-1, 1024))
-            .unsqueeze(0)
-        )
-
         streamer = self.model.inference_stream(
             text,
             language,
-            gpt_cond_latent,
-            speaker_embedding,
+            self.gpt_cond_latent,
+            self.speaker_embedding,
             stream_chunk_size=chunk_size,
             enable_text_splitting=True,
+            temperature=0.2,
         )
 
         for chunk in streamer:
-            print(type(chunk))
             processed_chunk = self.wav_postprocess(chunk)
             processed_bytes = processed_chunk.tobytes()
             yield processed_bytes
diff --git a/xtts-streaming/requirements.txt b/xtts-streaming/requirements.txt
index a5d44b0f..c2e4ca22 100644
--- a/xtts-streaming/requirements.txt
+++ b/xtts-streaming/requirements.txt
@@ -1 +1,9 @@
 git+https://github.com/coqui-ai/TTS@fa28f99f1508b5b5366539b2149963edcb80ba62
+deepspeed==0.10.3
+python-multipart==0.0.6
+typing-extensions>=4.8.0
+numpy==1.24.3
+cutlet
+mecab-python3==1.0.6
+unidic-lite==1.0.8
+unidic==1.1.0
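
Below is a minimal client sketch for exercising the streaming `predict()` endpoint this diff modifies. It is illustrative only: the Baseten model ID and API key are placeholders, and the output format (mono 16-bit PCM at XTTS v2's 24 kHz rate) is an assumption based on `wav_postprocess` converting chunks before `tobytes()`, not something stated in this diff.

```python
# Hypothetical usage sketch -- not part of this change.
# <MODEL_ID> and <API_KEY> are placeholders for a real deployment.
import wave

import requests

resp = requests.post(
    "https://model-<MODEL_ID>.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key <API_KEY>"},
    json={"text": "Streaming lets playback start early.", "language": "en", "chunk_size": 20},
    stream=True,  # consume audio chunks as they are generated instead of waiting for the full clip
)
resp.raise_for_status()

# predict() yields raw PCM bytes, so wrap them in a WAV container for playback.
# Assumes mono 16-bit samples at 24000 Hz (the XTTS v2 output rate).
with wave.open("output.wav", "wb") as f:
    f.setnchannels(1)
    f.setsampwidth(2)      # 16-bit samples
    f.setframerate(24000)  # assumed XTTS v2 sample rate
    for chunk in resp.iter_content(chunk_size=4096):
        f.writeframes(chunk)
```

In a real-time pipeline you would feed each chunk straight to an audio device rather than a file; the smaller `chunk_size` default of 20 in this diff trades some throughput for lower time-to-first-audio.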