From 2d448238a14f14ef1b5079be00646778604924da Mon Sep 17 00:00:00 2001 From: L0SG Date: Tue, 9 Jul 2024 22:58:19 -0700 Subject: [PATCH] BigVGAN-v2 release --- .gitignore | 6 + LICENSE | 2 +- README.md | 128 +++- alias_free_cuda/__init__.py | 0 alias_free_cuda/activation1d.py | 63 ++ alias_free_cuda/anti_alias_activation.cpp | 48 ++ alias_free_cuda/anti_alias_activation_cuda.cu | 314 +++++++++ alias_free_cuda/compat.h | 31 + alias_free_cuda/load.py | 72 +++ alias_free_cuda/test_activation.py | 55 ++ alias_free_cuda/test_activation_snake_beta.py | 55 ++ alias_free_cuda/type_shim.h | 97 +++ configs/bigvgan_22khz_80band.json | 2 +- configs/bigvgan_24khz_100band.json | 2 +- configs/bigvgan_base_22khz_80band.json | 2 +- configs/bigvgan_base_24khz_100band.json | 2 +- configs/bigvgan_v2_22khz_80band_256x.json | 61 ++ .../bigvgan_v2_22khz_80band_fmax8k_256x.json | 61 ++ configs/bigvgan_v2_24khz_100band_256x.json | 61 ++ configs/bigvgan_v2_44khz_128band_256x.json | 61 ++ configs/bigvgan_v2_44khz_128band_512x.json | 61 ++ incl_licenses/LICENSE_6 | 21 + incl_licenses/LICENSE_7 | 21 + incl_licenses/LICENSE_8 | 21 + inference.py | 5 +- inference_e2e.py | 3 +- meldataset.py | 11 +- models.py | 608 +++++++++++++++++- nv-modelcard++/.gitkeep | 0 nv-modelcard++/bias.md | 4 + nv-modelcard++/explainability.md | 13 + nv-modelcard++/overview.md | 115 ++++ nv-modelcard++/privacy.md | 14 + nv-modelcard++/safety.md | 6 + parse_scripts/parse_libritts.py | 2 +- requirements.txt | 6 +- test_cuda_vs_torch_model.py | 161 +++++ train.py | 90 ++- 38 files changed, 2196 insertions(+), 89 deletions(-) create mode 100644 .gitignore create mode 100644 alias_free_cuda/__init__.py create mode 100644 alias_free_cuda/activation1d.py create mode 100644 alias_free_cuda/anti_alias_activation.cpp create mode 100644 alias_free_cuda/anti_alias_activation_cuda.cu create mode 100644 alias_free_cuda/compat.h create mode 100644 alias_free_cuda/load.py create mode 100644 alias_free_cuda/test_activation.py create mode 100644 alias_free_cuda/test_activation_snake_beta.py create mode 100644 alias_free_cuda/type_shim.h create mode 100644 configs/bigvgan_v2_22khz_80band_256x.json create mode 100644 configs/bigvgan_v2_22khz_80band_fmax8k_256x.json create mode 100644 configs/bigvgan_v2_24khz_100band_256x.json create mode 100644 configs/bigvgan_v2_44khz_128band_256x.json create mode 100644 configs/bigvgan_v2_44khz_128band_512x.json create mode 100644 incl_licenses/LICENSE_6 create mode 100644 incl_licenses/LICENSE_7 create mode 100644 incl_licenses/LICENSE_8 create mode 100644 nv-modelcard++/.gitkeep create mode 100644 nv-modelcard++/bias.md create mode 100644 nv-modelcard++/explainability.md create mode 100644 nv-modelcard++/overview.md create mode 100644 nv-modelcard++/privacy.md create mode 100644 nv-modelcard++/safety.md create mode 100644 test_cuda_vs_torch_model.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f29f694 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +__pycache__/ +*/__pycache__/ +alias_free_cuda/build/ +exp/ +tmp/ \ No newline at end of file diff --git a/LICENSE b/LICENSE index e966359..45b7741 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 NVIDIA CORPORATION. +Copyright (c) 2024 NVIDIA CORPORATION. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index a6cff37..c62f00f 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,32 @@
-### [Paper](https://arxiv.org/abs/2206.04658) -### [Audio demo](https://bigvgan-demo.github.io/) +### [Paper](https://arxiv.org/abs/2206.04658)   [Project page](https://research.nvidia.com/labs/adlr/projects/bigvgan/)   [Audio demo](https://bigvgan-demo.github.io/) + +## News +[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights: +* Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU. +* Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546). +* Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments. +* We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio. ## Installation -Clone the repository and install dependencies. +The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment: +```shell +conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia +conda activate bigvgan +``` + +Clone the repository and install dependencies: ```shell -# the codebase has been tested on Python 3.8 / 3.10 with PyTorch 1.12.1 / 1.13 conda binaries git clone https://github.com/NVIDIA/BigVGAN +cd BigVGAN pip install -r requirements.txt ``` -Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset. + + +Create a symbolic link to the root of the dataset. The codebase uses file lists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset: ``` shell cd LibriTTS && \ ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \ @@ -29,24 +43,25 @@ cd .. ``` ## Training -Train BigVGAN model. Below is an example command for training BigVGAN using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input. +Train the BigVGAN model. Below is an example command for training BigVGAN-v2 using the LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input: ```shell python train.py \ ---config configs/bigvgan_24khz_100band.json \ +--config configs/bigvgan_v2_24khz_100band_256x.json \ --input_wavs_dir LibriTTS \ --input_training_file LibriTTS/train-full.txt \ --input_validation_file LibriTTS/val-full.txt \ --list_input_unseen_wavs_dir LibriTTS LibriTTS \ --list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \ ---checkpoint_path exp/bigvgan +--checkpoint_path exp/bigvgan_v2_24khz_100band_256x ``` + ## Synthesis Synthesize from BigVGAN model. Below is an example command for generating audio from the model. It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
```shell python inference.py \ ---checkpoint_file exp/bigvgan/g_05000000 \ +--checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \ --input_wavs_dir /path/to/your/input_wav \ --output_dir /path/to/your/output_wav ``` @@ -57,39 +72,98 @@ It loads mel spectrograms from `--input_mels_dir` and saves the generated audio Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model. ```shell python inference_e2e.py \ ---checkpoint_file exp/bigvgan/g_05000000 \ +--checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \ --input_mels_dir /path/to/your/input_mel \ --output_dir /path/to/your/output_wav ``` -## Pretrained Models -We provide the [pretrained models](https://drive.google.com/drive/folders/1e9wdM29d-t3EHUpBb8T4dcHrkYGAXTgq). -One can download the checkpoints of generator (e.g., g_05000000) and discriminator (e.g., do_05000000) within the listed folders. +## Using Custom CUDA Kernel for Synthesis +You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN: + +```python +generator = BigVGAN(h, use_cuda_kernel=True) +``` + +You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature. -|Folder Name|Sampling Rate|Mel band|fmax|Params.|Dataset|Fine-Tuned| -|------|---|---|---|---|------|---| -|bigvgan_24khz_100band|24 kHz|100|12000|112M|LibriTTS|No| -|bigvgan_base_24khz_100band|24 kHz|100|12000|14M|LibriTTS|No| -|bigvgan_22khz_80band|22 kHz|80|8000|112M|LibriTTS + VCTK + LJSpeech|No| -|bigvgan_base_22khz_80band|22 kHz|80|8000|14M|LibriTTS + VCTK + LJSpeech|No| +When the kernel is applied for the first time, it is built using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_cuda/build` and the model automatically loads it. The codebase has been tested using CUDA `12.1`. -The paper results are based on 24kHz BigVGAN models trained on LibriTTS dataset. +Please make sure that both are installed on your system and that the `nvcc` version matches the CUDA version your PyTorch build is using. + +We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command below and its output, which ends with `[Success] test CUDA fused vs. plain torch BigVGAN inference`: + +```shell +python test_cuda_vs_torch_model.py \ +--checkpoint_file /path/to/your/bigvgan/g_03000000 +``` + +```shell +loading plain Pytorch BigVGAN +... +loading CUDA kernel BigVGAN with auto-build +Detected CUDA files, patching ldflags +Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja... +Building extension module anti_alias_activation_cuda... +... +Loading extension module anti_alias_activation_cuda... +... +Loading '/path/to/your/bigvgan/g_03000000' +... +[Success] test CUDA fused vs. plain torch BigVGAN inference + > mean_difference=0.0007238413265440613 +... +``` + +If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check whether the `nvcc` installed on your system is compatible with your PyTorch version. + + +## Pretrained Models +We provide the [pretrained models](https://drive.google.com/drive/folders/1L2RDeJMBE7QAI8qV51n0QAf4mkSgUUeE?usp=sharing). +One can download the checkpoints of the generator weight (e.g., `g_(training_steps)`) and its discriminator/optimizer states (e.g., `do_(training_steps)`) within the listed folders.
+ +|Folder Name|Sampling Rate|Mel band|fmax|Upsampling Ratio|Params.|Dataset|Fine-Tuned| +|------|---|---|---|---|---|------|---| +|bigvgan_v2_44khz_128band_512x|44 kHz|128|22050|512|122M|Large-scale Compilation|No| +|bigvgan_v2_44khz_128band_256x|44 kHz|128|22050|256|112M|Large-scale Compilation|No| +|bigvgan_v2_24khz_100band_256x|24 kHz|100|12000|256|112M|Large-scale Compilation|No| +|bigvgan_v2_22khz_80band_256x|22 kHz|80|11025|256|112M|Large-scale Compilation|No| +|bigvgan_v2_22khz_80band_fmax8k_256x|22 kHz|80|8000|256|112M|Large-scale Compilation|No| +|bigvgan_24khz_100band|24 kHz|100|12000|256|112M|LibriTTS|No| +|bigvgan_base_24khz_100band|24 kHz|100|12000|256|14M|LibriTTS|No| +|bigvgan_22khz_80band|22 kHz|80|8000|256|112M|LibriTTS + VCTK + LJSpeech|No| +|bigvgan_base_22khz_80band|22 kHz|80|8000|256|14M|LibriTTS + VCTK + LJSpeech|No| + +The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset. We also provide 22kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications. -Note that, the latest checkpoints use ``snakebeta`` activation with log scale parameterization, which have the best overall quality. +Note that the checkpoints use ``snakebeta`` activation with log scale parameterization, which have the best overall quality. +You can fine-tune the models by downloading the checkpoints (both the generator weight and its discriminator/optimizer states) and resuming training using your audio dataset (see the fine-tuning sketch below). -## TODO +## Training Details of BigVGAN-v2 +Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 were trained using `batch_size=32`, a longer `segment_size=65536`, and 8 A100 GPUs. -Current codebase only provides a plain PyTorch implementation for the filtered nonlinearity. We are working on a fast CUDA kernel implementation, which will be released in the future. +Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` by default so that training fits on a single A100 GPU. You can adjust `batch_size` depending on your GPUs. +When training BigVGAN-v2 from scratch with a small batch size, training can encounter the early divergence problem mentioned in the paper. In such cases, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. the first 20000 steps) and then increasing it back to the default `500`. + +## Evaluation Results of BigVGAN-v2 +Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements in these metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio. + +|Model|Dataset|Steps|PESQ(↑)|M-STFT(↓)|MCD(↓)|Periodicity(↓)|V/UV F1(↑)| +|-------|-----|-----|-----|-----|-----|-----|-----| +|BigVGAN|LibriTTS|1M|4.027|0.7997|0.3745|0.1018|0.9598| +|BigVGAN|LibriTTS|5M|4.256|0.7409|0.2988|0.0809|0.9698| +|BigVGAN-v2|Large-scale Compilation|3M|**4.359**|**0.7134**|0.3060|**0.0621**|**0.9777**| + +## Acknowledgements +We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
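## Fine-Tuning Example
The fine-tuning workflow mentioned in the Pretrained Models section (download the generator and discriminator/optimizer checkpoint pair, then resume training on your own data) can be sketched as a single `train.py` run. The command below is a minimal sketch rather than an official recipe: it assumes the downloaded `g_<steps>` and `do_<steps>` files are copied into the `--checkpoint_path` directory beforehand and that `train.py` resumes from the latest checkpoint pair it finds there, as in the HiFi-GAN-style training loop this codebase adapts; all dataset paths and file lists are placeholders for your own data.

```shell
# Hypothetical fine-tuning run of the 24 kHz BigVGAN-v2 model.
# 1) Prepare the experiment directory and copy the downloaded checkpoints into it:
#      mkdir -p exp/finetune_bigvgan_v2_24khz_100band_256x
#      cp /path/to/downloaded/g_03000000 /path/to/downloaded/do_03000000 exp/finetune_bigvgan_v2_24khz_100band_256x/
# 2) Resume training on your own file lists (paths below are placeholders):
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir /path/to/your/dataset \
--input_training_file /path/to/your/train_filelist.txt \
--input_validation_file /path/to/your/val_filelist.txt \
--list_input_unseen_wavs_dir /path/to/your/dataset /path/to/your/dataset \
--list_input_unseen_validation_file /path/to/your/unseen_filelist_1.txt /path/to/your/unseen_filelist_2.txt \
--checkpoint_path exp/finetune_bigvgan_v2_24khz_100band_256x
```

As noted in the Training Details of BigVGAN-v2 section, small-batch runs may benefit from temporarily lowering `clip_grad_norm` during the early iterations.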
## References * [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator) - * [Snake](https://github.com/EdwardDixon/snake) (for periodic activation) - * [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing) - * [Julius](https://github.com/adefossez/julius) (for low-pass filter) +* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) +* [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss) +* [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator) -* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) \ No newline at end of file diff --git a/alias_free_cuda/__init__.py b/alias_free_cuda/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alias_free_cuda/activation1d.py b/alias_free_cuda/activation1d.py new file mode 100644 index 0000000..99c51f4 --- /dev/null +++ b/alias_free_cuda/activation1d.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +from alias_free_torch.resample import UpSample1d, DownSample1d +# load fused CUDA kernel: this enables importing anti_alias_activation_cuda +from alias_free_cuda import load +load.load() + +class FusedAntiAliasActivation(torch.autograd.Function): + """ + Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs + """ + @staticmethod + def forward(ctx, inputs, ftr, alpha, beta): + import anti_alias_activation_cuda + activation_results = anti_alias_activation_cuda.forward(inputs, ftr, alpha, beta) + return activation_results + + @staticmethod + def backward(ctx, output_grads): + # TODO: implement bwd pass + raise NotImplementedError + return output_grads, None, None + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + fused: bool = True + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + self.fused = fused # whether to use fused CUDA kernel or not + + + def forward(self, x): + if not self.fused: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + else: + if self.act.__class__.__name__ == "Snake": + beta = self.act.alpha.data # snake uses same params for alpha and beta + else: + beta = self.act.beta.data # snakebeta uses different params for alpha and beta + alpha = self.act.alpha.data + if not self.act.alpha_logscale: # exp baked into cuda kernel, cancel it out with a log + alpha = torch.log(alpha) + beta = torch.log(beta) + x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta) + x = self.downsample(x) + return x diff --git a/alias_free_cuda/anti_alias_activation.cpp b/alias_free_cuda/anti_alias_activation.cpp new file mode 100644 index 0000000..c68556e --- /dev/null +++ b/alias_free_cuda/anti_alias_activation.cpp @@ -0,0 +1,48 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace anti_alias_activation { + +torch::Tensor fwd_cuda(torch::Tensor const& input, + torch::Tensor const& filter, + torch::Tensor const& alpha, + torch::Tensor const& beta + ); + +torch::Tensor fwd(torch::Tensor const& input, + torch::Tensor const& filter, + torch::Tensor const& alpha, + torch::Tensor const& beta + ) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + // (input.scalar_type() == at::ScalarType::BFloat16), + // "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, filter, alpha, beta); +} + +} // end namespace anti_alias_activation + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &anti_alias_activation::fwd, + "Anti Alias Activation -- Forward."); +} diff --git a/alias_free_cuda/anti_alias_activation_cuda.cu b/alias_free_cuda/anti_alias_activation_cuda.cu new file mode 100644 index 0000000..143ceea --- /dev/null +++ b/alias_free_cuda/anti_alias_activation_cuda.cu @@ -0,0 +1,314 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include "type_shim.h" +#include +#include +#include +#include +#include + +namespace { + + /* +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} +*/ + +template +__global__ void anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 
1 : 4; + constexpr int BUFFER_SIZE = 32; + constexpr int FILTER_SIZE = 12; + constexpr int HALF_FILTER_SIZE = 6; + constexpr int REPLICATION_PAD = 5; // 5 on each side + + // blockDim/threadIdx = (128, 1, 1) + // gridDim/blockIdx = (seq_blocks, channels, batches) + int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + int local_offset = threadIdx.x * BUFFER_SIZE; + int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; + + + //int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD; + //int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + //int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; + + int output_seq_len = seq_len * 2 ; // + int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + int output_local_offset = threadIdx.x * BUFFER_SIZE * 2; + int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE *2 + output_local_offset; + // get values needed for replication padding before moving pointer + const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + input_t seq_left_most_value = right_most_pntr[0]; + input_t seq_right_most_value = right_most_pntr[seq_len - 1]; + + src += block_offset + local_offset; + dst += output_block_offset + output_local_offset ; + alpha = alpha + blockIdx.y; + input_t alpha_val = expf(alpha[0]); + beta = beta + blockIdx.y; + input_t beta_val = expf(beta[0]); + // load data from global memory + input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0}; + input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0}; + //output_t output[2*BUFFER_SIZE]; + input_t filter[FILTER_SIZE]; + //input_t temp_data[ELEMENTS_PER_LDG_STG]; + //uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + + #pragma unroll + for (int it = 0; it < FILTER_SIZE; it+=1) { + filter[it] = ftr[it]; + } + + + #pragma unroll + for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) { + int element_index = seq_offset + it; + if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) { + elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value; + } + if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) { + elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value; + } + if ((element_index >= 0) && (element_index < seq_len)) { + elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it]; + } + } + + + + // apply filter + #pragma unroll + for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) { + input_t acc = 0.0; + + int element_index = output_seq_offset + it; // index for output + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){ + if ((element_index + f_idx) >= 0){ + acc += filter[f_idx] * elements[it+f_idx]; + } + } + intermediates[it] = acc; + } + + double no_div_by_zero = 0.000000001; + #pragma unroll + for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) { + intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val); + } + + + // now copy to output + #pragma unroll + for (int it = 0; it < 2*BUFFER_SIZE; it+=1){ + int element_index = output_seq_offset + it; + if (element_index < output_seq_len) { + dst[it] = intermediates[it+6]; + } + } + + + + // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) { + // int element_index = seq_offset + it; + // if (element_index < seq_len) { + // dst[it] = output[it]; + // } + 
// } + + + // // Upsample convolution + // for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) { + // input_t acc = 0.0; + + // for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){ + // acc += filter[f_idx] * elements[it+f_idx]; + // } + // intermediates[it] = acc; + // } + + // // correct the corners of intermediates + // if (seq_offset == 0) { + // for (int it = 0; it < 6; it+=1) + // intermediates[it] = 0; + // } + + // if (seq_offset + 32 >= seq_len) { + // int offset = seq_len % 32 == 0 ? 32 : seq_len % 32; + + // for (int it = 0; it < 6; it++) { + // intermediates[6+2*offset+it] = 0; + // } + // } + + + + + // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) { + // int element_index = seq_offset + it; + // if (element_index < seq_len) { + // dst[it] = output[it]; + // } + // } +} + +template +void dispatch_anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) +{ + if (seq_len == 0) { + return; + } else { + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + constexpr int seq_len_per_block = 4096; + int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; + dim3 blocks(blocks_per_seq_len, channels, batch_size); + dim3 threads(threads_per_block, 1, 1); + + anti_alias_activation_forward + <<>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len); + } +} +} + +namespace anti_alias_activation { + + torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int channels = input.size(1); + const int seq_len = input.size(2); + + // Output + auto act_options = input.options().requires_grad(false); + int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros + + torch::Tensor anti_alias_activation_results = + torch::empty({batches, channels, output_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* filter_ptr = static_cast(filter.data_ptr()); + void* alpha_ptr = static_cast(alpha.data_ptr()); + void* beta_ptr = static_cast(beta.data_ptr()); + void* anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch anti alias activation_forward", + dispatch_anti_alias_activation_forward( + reinterpret_cast(anti_alias_activation_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(filter_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(beta_ptr), + batches, + channels, + seq_len); + ); + return anti_alias_activation_results; +} +} diff --git a/alias_free_cuda/compat.h b/alias_free_cuda/compat.h new file mode 100644 index 0000000..92e7eb7 --- /dev/null +++ b/alias_free_cuda/compat.h @@ -0,0 +1,31 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/alias_free_cuda/load.py b/alias_free_cuda/load.py new file mode 100644 index 0000000..9d0db42 --- /dev/null +++ b/alias_free_cuda/load.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3',], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=True + ) + + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + sources=[srcpath / 'anti_alias_activation.cpp', + srcpath / 'anti_alias_activation_cuda.cu'] + anti_alias_activation_cuda = _cpp_extention_load_helper( + "anti_alias_activation_cuda", sources, extra_cuda_flags) + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") \ No newline at end of file diff --git a/alias_free_cuda/test_activation.py b/alias_free_cuda/test_activation.py new file mode 100644 index 0000000..b8518a6 --- /dev/null +++ b/alias_free_cuda/test_activation.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. 
+ +import math +import torch +import alias_free_cuda +from alias_free_cuda import activation1d +from activations import Snake, SnakeBeta + +def test_load_fused_kernels(): + try: + import alias_free_cuda + import torch + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + +def test_anti_alias_activation(): + data = torch.rand((10, 10, 50000), device='cuda') + + # check activations.Snake cuda vs. torch + fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda() + fused_activation_output = fused_anti_alias_activation(data) + + torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda() + torch_activation_output = torch_anti_alias_activation(data) + + test_result = (fused_activation_output - torch_activation_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}" + f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, " + f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}" + ) + +if __name__ == "__main__": + from alias_free_cuda import load + load.load() + test_load_fused_kernels() + test_anti_alias_activation() diff --git a/alias_free_cuda/test_activation_snake_beta.py b/alias_free_cuda/test_activation_snake_beta.py new file mode 100644 index 0000000..49bf55c --- /dev/null +++ b/alias_free_cuda/test_activation_snake_beta.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import math +import torch +import alias_free_cuda +from alias_free_cuda import activation1d +from activations import Snake, SnakeBeta + +def test_load_fused_kernels(): + try: + import alias_free_cuda + import torch + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + +def test_anti_alias_activation(): + data = torch.rand((10, 10, 50000), device='cuda') + + # check activations.Snake cuda vs. 
torch + fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda() + fused_activation_output = fused_anti_alias_activation(data) + + torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda() + torch_activation_output = torch_anti_alias_activation(data) + + test_result = (fused_activation_output - torch_activation_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}" + f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, " + f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}" + ) + +if __name__ == "__main__": + from alias_free_cuda import load + load.load() + test_load_fused_kernels() + test_anti_alias_activation() diff --git a/alias_free_cuda/type_shim.h b/alias_free_cuda/type_shim.h new file mode 100644 index 0000000..d30a7e4 --- /dev/null +++ b/alias_free_cuda/type_shim.h @@ -0,0 +1,97 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include "compat.h" + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/configs/bigvgan_22khz_80band.json b/configs/bigvgan_22khz_80band.json index 9ebdf92..64bca78 100644 --- a/configs/bigvgan_22khz_80band.json +++ b/configs/bigvgan_22khz_80band.json @@ -5,7 +5,7 @@ "learning_rate": 0.0001, "adam_b1": 0.8, "adam_b2": 0.99, - "lr_decay": 0.999, + "lr_decay": 0.9999996, "seed": 1234, "upsample_rates": [4,4,2,2,2,2], diff --git a/configs/bigvgan_24khz_100band.json b/configs/bigvgan_24khz_100band.json index d9988a9..e7f7ff0 100644 --- a/configs/bigvgan_24khz_100band.json +++ b/configs/bigvgan_24khz_100band.json @@ -5,7 +5,7 @@ "learning_rate": 0.0001, "adam_b1": 0.8, "adam_b2": 0.99, - "lr_decay": 0.999, + "lr_decay": 0.9999996, "seed": 1234, "upsample_rates": [4,4,2,2,2,2], diff --git a/configs/bigvgan_base_22khz_80band.json b/configs/bigvgan_base_22khz_80band.json index 32979f5..fd24484 100644 --- a/configs/bigvgan_base_22khz_80band.json +++ b/configs/bigvgan_base_22khz_80band.json @@ -5,7 +5,7 @@ "learning_rate": 0.0001, "adam_b1": 0.8, "adam_b2": 0.99, - "lr_decay": 0.999, + "lr_decay": 0.9999996, "seed": 1234, "upsample_rates": [8,8,2,2], diff --git a/configs/bigvgan_base_24khz_100band.json b/configs/bigvgan_base_24khz_100band.json index 889a77c..0911508 100644 --- a/configs/bigvgan_base_24khz_100band.json +++ b/configs/bigvgan_base_24khz_100band.json @@ -5,7 +5,7 @@ "learning_rate": 0.0001, "adam_b1": 0.8, "adam_b2": 0.99, - "lr_decay": 0.999, + "lr_decay": 0.9999996, "seed": 1234, "upsample_rates": [8,8,2,2], diff --git a/configs/bigvgan_v2_22khz_80band_256x.json b/configs/bigvgan_v2_22khz_80band_256x.json new file mode 100644 index 0000000..e96bd5f --- /dev/null +++ b/configs/bigvgan_v2_22khz_80band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + 
"clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json b/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json new file mode 100644 index 0000000..a3c9699 --- /dev/null +++ b/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/configs/bigvgan_v2_24khz_100band_256x.json b/configs/bigvgan_v2_24khz_100band_256x.json new file mode 100644 index 0000000..8057ee2 --- /dev/null +++ b/configs/bigvgan_v2_24khz_100band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/configs/bigvgan_v2_44khz_128band_256x.json b/configs/bigvgan_v2_44khz_128band_256x.json new file mode 100644 index 0000000..b6999d3 --- /dev/null +++ 
b/configs/bigvgan_v2_44khz_128band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/configs/bigvgan_v2_44khz_128band_512x.json b/configs/bigvgan_v2_44khz_128band_512x.json new file mode 100644 index 0000000..2d7176c --- /dev/null +++ b/configs/bigvgan_v2_44khz_128band_512x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,4,2,2,2,2], + "upsample_kernel_sizes": [16,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 2049, + "n_fft": 2048, + "hop_size": 512, + "win_size": 2048, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/incl_licenses/LICENSE_6 b/incl_licenses/LICENSE_6 new file mode 100644 index 0000000..2569ec0 --- /dev/null +++ b/incl_licenses/LICENSE_6 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023-present, Descript + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this 
permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/incl_licenses/LICENSE_7 b/incl_licenses/LICENSE_7 new file mode 100644 index 0000000..c37bdaf --- /dev/null +++ b/incl_licenses/LICENSE_7 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Charactr Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/incl_licenses/LICENSE_8 b/incl_licenses/LICENSE_8 new file mode 100644 index 0000000..ab3d7ff --- /dev/null +++ b/incl_licenses/LICENSE_8 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Amphion + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/inference.py b/inference.py index 0769fbb..727080d 100644 --- a/inference.py +++ b/inference.py @@ -40,7 +40,7 @@ def scan_checkpoint(cp_dir, prefix): def inference(a, h): - generator = Generator(h).to(device) + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) state_dict_g = load_checkpoint(a.checkpoint_file, device) generator.load_state_dict(state_dict_g['generator']) @@ -54,7 +54,7 @@ def inference(a, h): with torch.no_grad(): for i, filname in enumerate(filelist): # load the ground truth audio and resample if necessary - wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), h.sampling_rate, mono=True) + wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True) wav = torch.FloatTensor(wav).to(device) # compute mel spectrogram from the ground truth audio x = get_mel(wav.unsqueeze(0)) @@ -77,6 +77,7 @@ def main(): parser.add_argument('--input_wavs_dir', default='test_files') parser.add_argument('--output_dir', default='generated_files') parser.add_argument('--checkpoint_file', required=True) + parser.add_argument('--use_cuda_kernel', action='store_true', default=False) a = parser.parse_args() diff --git a/inference_e2e.py b/inference_e2e.py index 9d2ad60..8dcfdc9 100644 --- a/inference_e2e.py +++ b/inference_e2e.py @@ -36,7 +36,7 @@ def scan_checkpoint(cp_dir, prefix): def inference(a, h): - generator = Generator(h).to(device) + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) state_dict_g = load_checkpoint(a.checkpoint_file, device) generator.load_state_dict(state_dict_g['generator']) @@ -73,6 +73,7 @@ def main(): parser.add_argument('--input_mels_dir', default='test_mel_files') parser.add_argument('--output_dir', default='generated_files_from_mel') parser.add_argument('--checkpoint_file', required=True) + parser.add_argument('--use_cuda_kernel', action='store_true', default=False) a = parser.parse_args() diff --git a/meldataset.py b/meldataset.py index 306f301..a40fc7a 100644 --- a/meldataset.py +++ b/meldataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 NVIDIA CORPORATION. +# Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 
@@ -16,7 +16,7 @@ import pathlib from tqdm import tqdm -MAX_WAV_VALUE = 32768.0 +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) def load_wav(full_path, sr_target): @@ -65,8 +65,9 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, global mel_basis, hann_window if fmax not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + str_key_mel_basis = str(fmax)+'_'+str(y.device) + mel_basis[str_key_mel_basis] = torch.from_numpy(mel).float().to(y.device) hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') @@ -78,7 +79,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, spec = torch.view_as_real(spec) spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) - spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = torch.matmul(mel_basis[str_key_mel_basis], spec) spec = spectral_normalize_torch(spec) return spec diff --git a/models.py b/models.py index f5022f1..637b0f5 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 NVIDIA CORPORATION. +# Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license. @@ -10,12 +10,18 @@ import torch.nn as nn from torch.nn import Conv1d, ConvTranspose1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from torchaudio.transforms import Spectrogram, Resample +from librosa.filters import mel as librosa_mel_fn +from scipy import signal import activations from utils import init_weights, get_padding -from alias_free_torch import * - -LRELU_SLOPE = 0.1 +from alias_free_torch.act import Activation1d as TorchActivation1d +import typing +from typing import List, Optional, Tuple +from collections import namedtuple +import math +import functools class AMPBlock1(torch.nn.Module): @@ -45,6 +51,14 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=No self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + # select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + # faster CUDA kernel implementation of Activation1d + from alias_free_cuda.activation1d import Activation1d as CudaActivation1d + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( @@ -93,6 +107,14 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None) self.num_layers = len(self.convs) # total number of conv layers + # select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + # faster CUDA kernel implementation of Activation1d + from alias_free_cuda.activation1d import Activation1d as CudaActivation1d + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( @@ -123,9 
+145,16 @@ def remove_weight_norm(self): class BigVGAN(torch.nn.Module): # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. - def __init__(self, h): + # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP. + # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported). + def __init__( + self, + h, + use_cuda_kernel: bool=False + ): super(BigVGAN, self).__init__() self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h) self.num_kernels = len(h.resblock_kernel_sizes) self.num_upsamples = len(h.upsample_rates) @@ -151,6 +180,14 @@ def __init__(self, h): ch = h.upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + # faster CUDA kernel implementation of Activation1d + from alias_free_cuda.activation1d import Activation1d as CudaActivation1d + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d # post conv if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing @@ -161,13 +198,20 @@ def __init__(self, h): self.activation_post = Activation1d(activation=activation_post) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # whether to use bias for the final conv_post. Defaults to True for backward compatibility + self.use_bias_at_final = h.get("use_bias_at_final", True) + self.conv_post = weight_norm(Conv1d( + ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final + )) # weight initialization for i in range(len(self.ups)): self.ups[i].apply(init_weights) self.conv_post.apply(init_weights) + + # final tanh activation. Defaults to True for backward compatibility + self.use_tanh_at_final = h.get("use_tanh_at_final", True) def forward(self, x): # pre conv @@ -189,7 +233,11 @@ def forward(self, x): # post conv x = self.activation_post(x) x = self.conv_post(x) - x = torch.tanh(x) + # final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1] return x @@ -232,7 +280,7 @@ def forward(self, x): for l in self.convs: x = l(x) - x = F.leaky_relu(x, LRELU_SLOPE) + x = F.leaky_relu(x, 0.1) fmap.append(x) x = self.conv_post(x) fmap.append(x) @@ -272,7 +320,7 @@ def __init__(self, cfg, resolution): self.resolution = resolution assert len(self.resolution) == 3, \ "MRD layer requires list with len=3, got {}".format(self.resolution) - self.lrelu_slope = LRELU_SLOPE + self.lrelu_slope = 0.1 norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm if hasattr(cfg, "mrd_use_spectral_norm"): @@ -345,17 +393,542 @@ def forward(self, y, y_hat): return y_d_rs, y_d_gs, fmap_rs, fmap_gs +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class DiscriminatorB(nn.Module): + def __init__( + self, + window_length: int, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] 
= ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F] + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor): + x_bands = self.spectrogram(x.squeeze(1)) + fmap = [] + x = [] + + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return x, fmap + +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiBandDiscriminator(nn.Module): + def __init__( + self, + h, + ): + """ + Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec. + and the modified code adapted from https://github.com/gemelo-ai/vocos. + """ + super().__init__() + # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h. + self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512]) + self.discriminators = nn.ModuleList( + [DiscriminatorB(window_length=w) for w in self.fft_sizes] + ) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license. +# LICENSE is in incl_licenses directory. 
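For readers skimming the patch, a minimal, standalone sketch (not part of the patch itself) of how the `DiscriminatorB` defined above maps its fractional band edges onto STFT bins, assuming the largest default window length from `mbd_fft_sizes` (2048):

```python
# Standalone sketch of DiscriminatorB's band splitting (assumes window_length=2048).
# The variable is named n_fft in the class above, but it is really the number of
# one-sided STFT bins: window_length // 2 + 1.
window_length = 2048
n_bins = window_length // 2 + 1  # 1025 bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
```

Each band is processed by its own convolution stack, and the per-band outputs are concatenated along the frequency axis before `conv_post`.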
+class DiscriminatorCQT(nn.Module): + def __init__(self, cfg, hop_length, n_octaves, bins_per_octave): + super().__init__() + self.cfg = cfg + + self.filters = cfg["cqtd_filters"] + self.max_filters = cfg["cqtd_max_filters"] + self.filters_scale = cfg["cqtd_filters_scale"] + self.kernel_size = (3, 9) + self.dilations = cfg["cqtd_dilations"] + self.stride = (1, 2) + + self.in_channels = cfg["cqtd_in_channels"] + self.out_channels = cfg["cqtd_out_channels"] + self.fs = cfg["sampling_rate"] + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + # lazy-load + from nnAudio import features + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for i in range(self.n_octaves): + self.conv_pres.append( + nn.Conv2d( + self.in_channels * 2, + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + self.convs = nn.ModuleList() + + self.convs.append( + nn.Conv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min( + (self.filters_scale ** (i + 1)) * self.filters, self.max_filters + ) + self.convs.append( + weight_norm(nn.Conv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=self.get_2d_padding(self.kernel_size, (dilation, 1)), + )) + ) + in_chs = out_chs + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + weight_norm(nn.Conv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + )) + ) + + self.conv_post = weight_norm(nn.Conv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + )) + + self.activation = torch.nn.LeakyReLU(negative_slope=0.1) + self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2) + + self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False) + if self.cqtd_normalize_volume: + print(f"INFO: cqtd_normalize_volume set to True. 
Will apply DC offset removal & peak volume normalization in CQTD!") + + def get_2d_padding( + self, kernel_size: typing.Tuple[int, int], dilation: typing.Tuple[int, int] = (1, 1) + ): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + def forward(self, x): + fmap = [] + + if self.cqtd_normalize_volume: + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + + x = self.resample(x) + + z = self.cqt_transform(x) + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + + z = torch.cat([z_amplitude, z_phase], dim=1) + z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W] + + latent_z = [] + for i in range(self.n_octaves): + latent_z.append( + self.conv_pres[i]( + z[ + :, + :, + :, + i * self.bins_per_octave : (i + 1) * self.bins_per_octave, + ] + ) + ) + latent_z = torch.cat(latent_z, dim=-1) + + for i, l in enumerate(self.convs): + latent_z = l(latent_z) + + latent_z = self.activation(latent_z) + fmap.append(latent_z) + + latent_z = self.conv_post(latent_z) + + return latent_z, fmap + + +class MultiScaleSubbandCQTDiscriminator(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + # Using get with defaults + self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32) + self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024) + self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1) + self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4]) + self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1) + self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1) + # multi-scale params to loop over + self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256]) + self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9]) + self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48]) + + self.discriminators = nn.ModuleList( + [ + DiscriminatorCQT( + self.cfg, + hop_length=self.cfg["cqtd_hop_lengths"][i], + n_octaves=self.cfg["cqtd_n_octaves"][i], + bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i], + ) + for i in range(len(self.cfg["cqtd_hop_lengths"])) + ] + ) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class CombinedDiscriminator(nn.Module): + # wrapper of chaining multiple discrimiantor architectures + # ex: combine mbd and cqtd as a single class + def __init__( + self, + list_discriminator: List[nn.Module] + ): + super().__init__() + self.discrimiantor = nn.ModuleList(list_discriminator) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discrimiantor: + y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat) + y_d_rs.extend(y_d_r) + fmap_rs.extend(fmap_r) + y_d_gs.extend(y_d_g) + fmap_gs.extend(fmap_g) + + return y_d_rs, y_d_gs, 
fmap_rs, fmap_gs + + +# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiScaleMelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320], + window_lengths : List[int], optional + Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part) + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 1.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + """ + + def __init__( + self, + sampling_rate: int, + n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], + window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 0.0, + log_weight: float = 1.0, + pow: float = 1.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], + mel_fmax: List[float] = [None, None, None, None, None, None, None], + window_type: str = 'hann', + ): + super().__init__() + self.sampling_rate = sampling_rate + + STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", "match_stride"], + ) + + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + @staticmethod + @functools.lru_cache(None) + def get_window( + window_type,window_length, + ): + return signal.get_window(window_type, window_length) + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters( + sr, n_fft, n_mels, fmin, fmax + ): + return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + + def mel_spectrogram( + self, wav, n_mels, fmin, fmax, window_length, hop_length, match_stride, window_type + ): + # mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: + # https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + B, C, T = wav.shape + + if match_stride: + assert ( + hop_length == window_length // 4 + ), "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(T / hop_length) * hop_length - T + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + wav = torch.nn.functional.pad( + wav, (pad, pad + right_pad), mode='reflect' + 
) + + window = self.get_window(window_type, window_length) + window = torch.from_numpy(window).to(wav.device).float() + + stft = torch.stft( + wav.reshape(-1, T), + n_fft=window_length, + hop_length=hop_length, + window=window, + return_complex=True, + center=True, + ) + _, nf, nt = stft.shape + stft = stft.reshape(B, C, nf, nt) + if match_stride: + # Drop first two and last two frames, which are added + # because of padding. Now num_frames * hop_length = num_samples. + stft = stft[..., 2:-2] + magnitude = torch.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax) + mel_basis = torch.from_numpy(mel_basis).to(wav.device) + mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose(-1, 2) + + return mel_spectrogram + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor + ) -> torch.Tensor: + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : torch.Tensor + Estimate signal + y : torch.Tensor + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "n_mels": n_mels, + "fmin": fmin, + "fmax": fmax, + "window_length": s.window_length, + "hop_length": s.hop_length, + "match_stride": s.match_stride, + "window_type": s.window_type, + } + + x_mels = self.mel_spectrogram(x, **kwargs) + y_mels = self.mel_spectrogram(y, **kwargs) + x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + + loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) + loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels) + + return loss + + +# loss functions +def feature_loss( + fmap_r: List[List[torch.Tensor]], + fmap_g: List[List[torch.Tensor]] +) -> torch.Tensor: -def feature_loss(fmap_r, fmap_g): loss = 0 for dr, dg in zip(fmap_r, fmap_g): for rl, gl in zip(dr, dg): loss += torch.mean(torch.abs(rl - gl)) - return loss*2 + return loss*2 # this equates to lambda=2.0 for the feature matching loss +def discriminator_loss( + disc_real_outputs: List[torch.Tensor], + disc_generated_outputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: -def discriminator_loss(disc_real_outputs, disc_generated_outputs): loss = 0 r_losses = [] g_losses = [] @@ -368,8 +941,10 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): return loss, r_losses, g_losses - -def generator_loss(disc_outputs): +def generator_loss( + disc_outputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + loss = 0 gen_losses = [] for dg in disc_outputs: @@ -377,5 +952,4 @@ def generator_loss(disc_outputs): gen_losses.append(l) loss += l - return loss, gen_losses - + return loss, gen_losses \ No newline at end of file diff --git a/nv-modelcard++/.gitkeep b/nv-modelcard++/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/nv-modelcard++/bias.md b/nv-modelcard++/bias.md new file mode 100644 index 0000000..aee310b --- /dev/null +++ b/nv-modelcard++/bias.md @@ -0,0 +1,4 @@ +Field | Response +:---------------------------------------------------------------------------------------------------|:--------------- +Participation considerations from adversely impacted groups protected classes in model design and 
testing: | None +Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. \ No newline at end of file diff --git a/nv-modelcard++/explainability.md b/nv-modelcard++/explainability.md new file mode 100644 index 0000000..d057add --- /dev/null +++ b/nv-modelcard++/explainability.md @@ -0,0 +1,13 @@ +Field | Response +:------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------- +Intended Application & Domain: | Generating waveform from mel spectrogram. +Model Type: | Convolutional Neural Network (CNN) +Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. +Output: | Audio Waveform +Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. +Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable +Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. +Verified to have met prescribed NVIDIA quality standards: | Yes +Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) +Potential Known Risks: | This model may generate low-quality or distorted soundwaves. +Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE \ No newline at end of file diff --git a/nv-modelcard++/overview.md b/nv-modelcard++/overview.md new file mode 100644 index 0000000..0c60375 --- /dev/null +++ b/nv-modelcard++/overview.md @@ -0,0 +1,115 @@ +# Model Overview + +## Description: +BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs. + +
+ +BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolutions, each followed by multiple residual dilated convolution layers. + +BigVGAN introduces a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP specializes in synthesizing high-frequency and periodic soundwaves, drawing inspiration from audio signal processing principles. + +It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture for generating periodic soundwaves, and anti-aliasing filters to reduce undesired artifacts in the generated waveforms.
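For intuition, the Snake nonlinearity mentioned above takes the form x + (1/α)·sin²(αx); below is a minimal sketch of the function itself (the repository's `activations.py` implements the full parameterized modules, including the SnakeBeta variant with learnable parameters):

```python
import torch

def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Snake activation: identity plus a bounded periodic term.
    # Its derivative, 1 + sin(2*alpha*x), is non-negative, so the function is monotone.
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2
```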
+ +This model is ready for commercial use.
+ + +## Reference(s): +* [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658)
+* [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/)
+* [Audio Demo](https://bigvgan-demo.github.io/)
+ +## Model Architecture: +**Architecture Type:** Convolutional Neural Network (CNN)
+**Network Architecture:** Details of the model are available at https://github.com/NVIDIA/BigVGAN, and the related paper can be found at https://arxiv.org/abs/2206.04658
+**Model Version:** 2.0
+ +## Input: +**Input Type:** Audio
+**Input Format:** Mel Spectrogram
+**Input Parameters:** None
+**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrarily long `frames`, limited only by available GPU memory. + +## Output: +**Output Type:** Audio
+**Output Format:** Audio Waveform
+**Output Parameters:** None
+**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channel and `time` refers to the temporal length. `time` is a fixed integer multiple of the input `frames`, determined by the upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consists of float values in the range `[-1, 1]`. + +## Software Integration: +**Runtime Engine(s):** PyTorch + +**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta
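To make the `time = upsampling ratio * frames` relationship above concrete, a small illustrative calculation (the per-stage rates below are hypothetical; the actual values come from `upsample_rates` in each config, e.g. the `*_256x` configs multiply out to 256):

```python
import math

# Hypothetical per-stage rates; the real values live in the config's upsample_rates.
upsample_rates = [4, 4, 2, 2, 2, 2]
upsampling_ratio = math.prod(upsample_rates)  # 256 -> one mel frame becomes 256 samples
frames = 128                                  # mel input:    [batch, num_mels, 128]
time = upsampling_ratio * frames              # waveform out: [batch, 1, 32768]
```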
+ + +## Preferred/Supported Operating System(s): +Linux + + +## Model Version(s): +v2.0 + +## Training, Testing, and Evaluation Datasets: + +### Training Dataset: +The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments. + +**Links:** +* [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629) +* [AudioCaps](https://audiocaps.github.io/) +* [AudioSet](https://research.google.com/audioset/index.html) +* [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent) +* [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440) +* [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection) +* [FSDnoisy18k](https://zenodo.org/records/2529934) +* [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384) +* [Greatest Hits dataset](https://andrewowens.com/vis/) +* [GTZAN](https://ieeexplore.ieee.org/document/1021072) +* [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus) +* [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194) +* [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/) +* [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench) +* [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps) +* [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset) +* [NSynth](https://magenta.tensorflow.org/datasets/nsynth) +* [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset) +* [Audio Piano Triads Dataset](https://zenodo.org/records/4740877) +* [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097) +* [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543) +* [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433) +* [WavText5K](https://github.com/microsoft/WavText5K) +* [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10) +* [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/) +* [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/) +* [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875) +* [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60) +* [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/) +* [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353) +* [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/) +* [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus) +* [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443) + +** Data Collection Method by dataset
+* Human
+ +** Labeling Method by dataset (for those with labels)
+* Hybrid: Automated, Human, Unknown
+ +### Evaluation Dataset: + +Properties: The audio generation quality of BigVGAN is evaluated using the `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and the [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets contain English speech with an equal balance of genders. + +** Data Collection Method by dataset
+* Human
+ +** Labeling Method by dataset
+* Automated
+ + +## Inference: +**Engine:** PyTorch
+**Test Hardware:** NVIDIA A100 GPU
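To complement the inference fields above, a minimal end-to-end sketch mirroring `test_cuda_vs_torch_model.py` from this patch (the checkpoint and config paths are placeholders; `use_cuda_kernel=True` can be passed instead to exercise the fused CUDA kernel, which is intended for inference only):

```python
import json
import torch
from env import AttrDict
from models import BigVGAN
from meldataset import mel_spectrogram, MAX_WAV_VALUE

# Placeholder paths: point these at a BigVGAN-v2 checkpoint directory.
with open("exp/bigvgan_v2/config.json") as f:
    h = AttrDict(json.load(f))

generator = BigVGAN(h, use_cuda_kernel=False).to("cuda")
state_dict = torch.load("exp/bigvgan_v2/g_checkpoint", map_location="cuda")  # placeholder file
generator.load_state_dict(state_dict["generator"])
generator.eval()
generator.remove_weight_norm()

# Stand-in waveform; in practice this is audio loaded at h.sampling_rate in [-1, 1].
wav = torch.randn(1, h.sampling_rate, device="cuda").clamp(-1.0, 1.0)
mel = mel_spectrogram(wav, h.n_fft, h.num_mels, h.sampling_rate,
                      h.hop_size, h.win_size, h.fmin, h.fmax)  # [1, num_mels, frames]

with torch.inference_mode():
    audio = generator(mel)  # [1, 1, frames * hop_size]
audio_int16 = (audio.squeeze() * MAX_WAV_VALUE).cpu().numpy().astype("int16")
```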
+ +## Ethical Considerations: +NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). + diff --git a/nv-modelcard++/privacy.md b/nv-modelcard++/privacy.md new file mode 100644 index 0000000..2d94d9e --- /dev/null +++ b/nv-modelcard++/privacy.md @@ -0,0 +1,14 @@ +Field | Response +:----------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------- +Generatable or reverse engineerable personal information? | None +Protected class data used to create this model? | None +Was consent obtained for any personal data used? | Not Applicable (No Personal Data) +How often is dataset reviewed? | Before Release +Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable +If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable +If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable +If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable +Is data in dataset traceable? | Yes +Is there provenance for all datasets used in training? | Yes +Does data labeling (annotation, metadata) comply with privacy laws? | Yes +Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. diff --git a/nv-modelcard++/safety.md b/nv-modelcard++/safety.md new file mode 100644 index 0000000..6baaaa7 --- /dev/null +++ b/nv-modelcard++/safety.md @@ -0,0 +1,6 @@ +Field | Response +:---------------------------------------------------|:---------------------------------- +Model Application(s): | Synethic Audio Generation +Describe the life critical impact (if present). | Not Applicable +Use Case Restrictions: | None +Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. diff --git a/parse_scripts/parse_libritts.py b/parse_scripts/parse_libritts.py index 0886ed0..d5e3bb0 100644 --- a/parse_scripts/parse_libritts.py +++ b/parse_scripts/parse_libritts.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 NVIDIA CORPORATION. +# Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. 
import os, glob diff --git a/requirements.txt b/requirements.txt index 2bad038..f555a4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ torch numpy -librosa==0.8.1 +librosa>=0.8.1 scipy tensorboard soundfile matplotlib pesq auraloss -tqdm \ No newline at end of file +tqdm +nnAudio +ninja \ No newline at end of file diff --git a/test_cuda_vs_torch_model.py b/test_cuda_vs_torch_model.py new file mode 100644 index 0000000..f560682 --- /dev/null +++ b/test_cuda_vs_torch_model.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import math +import torch +import json +from env import AttrDict +from models import BigVGAN +from time import time +from tqdm import tqdm +import os +from meldataset import mel_spectrogram, MAX_WAV_VALUE +import librosa +from scipy.io.wavfile import write +import numpy as np + +import argparse + +# for easier debugging +torch.set_printoptions( + linewidth=200, + threshold=10_000 +) + +def generate_soundwave(duration=5.0, sr=24000): + t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32) + + modulation = np.sin(2 * np.pi * t / duration) + + min_freq = 220 + max_freq = 1760 + frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2 + soundwave = np.sin(2 * np.pi * frequencies * t) + + soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95 + + return soundwave, sr + +def get_mel(x, h): + return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.") + parser.add_argument('--checkpoint_file', type=str, required=True, help="Path to the checkpoint file. Assumes config.json exists in the directory.") + + args = parser.parse_args() + + config_file = os.path.join(os.path.split(args.checkpoint_file)[0], 'config.json') + with open(config_file) as f: + config = f.read() + json_config = json.loads(config) + h = AttrDict({**json_config}) + + print("loading plain Pytorch BigVGAN") + generator_original = BigVGAN(h).to("cuda") + print("loading CUDA kernel BigVGAN with auto-build") + generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda") + + state_dict_g = load_checkpoint(args.checkpoint_file, "cuda") + generator_original.load_state_dict(state_dict_g['generator']) + generator_cuda_kernel.load_state_dict(state_dict_g['generator']) + + generator_original.eval() + generator_original.remove_weight_norm() + generator_cuda_kernel.eval() + generator_cuda_kernel.remove_weight_norm() + + toc_total_original = 0. + toc_total_cuda_kernel = 0. + audio_length_total = 0. + diff = 0. 
+ + num_sample = 10 + num_mel_frame = 128 + for i in tqdm(range(num_sample)): + # random mel: use large num_mel_frame to test peak gpu util performance + data = torch.rand((1, h.num_mels, num_mel_frame), device='cuda') + # original inference + torch.cuda.synchronize() + tic = time() + with torch.inference_mode(): + audio_original = generator_original(data) + torch.cuda.synchronize() + toc = time() - tic + toc_total_original += toc + # cuda kernel inference + torch.cuda.synchronize() + tic = time() + with torch.inference_mode(): + audio_cuda_kernel = generator_cuda_kernel(data) + torch.cuda.synchronize() + toc = time() - tic + toc_total_cuda_kernel += toc + audio_length_total += audio_cuda_kernel.shape[-1] + + # both outputs should be (almost) the same + test_result = (audio_original - audio_cuda_kernel).abs() + diff += test_result.mean(dim=-1).item() + + diff /= num_sample + if diff <= 2e-3: # we can expect a small difference (~1e-3) which does not affect perceptual quality + print( + f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference" + f"\n > mean_difference={diff}" + f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}" + f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" + ) + else: + print( + f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference" + f"\n > mean_difference={diff}" + f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, " + f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" + ) + + audio_second = audio_length_total / h.sampling_rate + khz_original = audio_length_total / toc_total_original / 1000 + khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000 + + print('Original BigVGAN: took {:.2f} seconds to generate {:.2f} seconds of audio, {:.1f}kHz, {:.1f} faster than realtime'.format(toc_total_original, audio_second, khz_original, audio_second / toc_total_original)) + print('CUDA kernel BigVGAN: took {:.2f} seconds to generate {:.2f} seconds of audio, {:.1f}kHz, {:.1f} faster than realtime'.format(toc_total_cuda_kernel, audio_second, khz_cuda_kernel, audio_second / toc_total_cuda_kernel)) + print('speedup of CUDA kernel: {}'.format(khz_cuda_kernel/khz_original)) + + # use artificial sine waves for inference test + audio_real, sr = generate_soundwave(duration=5., sr=h.sampling_rate) + audio_real = torch.tensor(audio_real).to("cuda") + # compute mel spectrogram from the ground truth audio + x = get_mel(audio_real.unsqueeze(0), h) + + with torch.inference_mode(): + y_g_hat_original = generator_original(x) + y_g_hat_cuda_kernel = generator_cuda_kernel(x) + + audio_real = audio_real.squeeze() + audio_real = audio_real * MAX_WAV_VALUE + audio_real = audio_real.cpu().numpy().astype('int16') + + audio_original = y_g_hat_original.squeeze() + audio_original = audio_original * MAX_WAV_VALUE + audio_original = audio_original.cpu().numpy().astype('int16') + + audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze() + audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE + audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype('int16') + + os.makedirs('tmp', exist_ok=True) + output_file_real = os.path.join('tmp', 'audio_real.wav') + output_file_original = os.path.join('tmp', 'audio_generated_original.wav') + output_file_cuda_kernel = os.path.join('tmp', 'audio_generated_cuda_kernel.wav') + write(output_file_real, h.sampling_rate, audio_real) + write(output_file_original, h.sampling_rate, audio_original) + write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel) + print("Example generated audios of original 
vs. fused CUDA kernel written to tmp!") + print("Done") \ No newline at end of file diff --git a/train.py b/train.py index cadee69..32a10bd 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 NVIDIA CORPORATION. +# Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license. @@ -21,8 +21,8 @@ from torch.nn.parallel import DistributedDataParallel from env import AttrDict, build_env from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE -from models import BigVGAN, MultiPeriodDiscriminator, MultiResolutionDiscriminator,\ - feature_loss, generator_loss, discriminator_loss +from models import BigVGAN, MultiPeriodDiscriminator, MultiResolutionDiscriminator, MultiBandDiscriminator, MultiScaleSubbandCQTDiscriminator, \ + feature_loss, generator_loss, discriminator_loss, MultiScaleMelSpectrogramLoss from utils import plot_spectrogram, plot_spectrogram_clipped, scan_checkpoint, load_checkpoint, save_checkpoint, save_audio import torchaudio as ta from pesq import pesq @@ -34,8 +34,12 @@ def train(rank, a, h): if h.num_gpus > 1: # initialize distributed - init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], - world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) + init_process_group( + backend=h.dist_config['dist_backend'], + init_method=h.dist_config['dist_url'], + world_size=h.dist_config['world_size'] * h.num_gpus, + rank=rank + ) # set seed and device torch.cuda.manual_seed(h.seed) @@ -44,19 +48,36 @@ def train(rank, a, h): # define BigVGAN generator generator = BigVGAN(h).to(device) - print("Generator params: {}".format(sum(p.numel() for p in generator.parameters()))) # define discriminators. MPD is used by default mpd = MultiPeriodDiscriminator(h).to(device) - print("Discriminator mpd params: {}".format(sum(p.numel() for p in mpd.parameters()))) - # define additional discriminators. BigVGAN uses MRD as default - mrd = MultiResolutionDiscriminator(h).to(device) - print("Discriminator mrd params: {}".format(sum(p.numel() for p in mrd.parameters()))) + # define additional discriminators. 
BigVGAN-v1 uses UnivNet's MRD as default + # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator + if h.get("use_mbd_instead_of_mrd", False): # switch to MBD + print("INFO: using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + mrd = MultiBandDiscriminator(h).to(device) # variable name is kept as "mrd" for backward compatibility & minimal code change + elif h.get("use_cqtd_instead_of_mrd", False): # switch CQTD + print("INFO: using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + mrd = MultiScaleSubbandCQTDiscriminator(h).to(device) + else: + mrd = MultiResolutionDiscriminator(h).to(device) # fallback to original MRD in BigVGAN-v1 + + # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss + if h.get("use_multiscale_melloss", False): + print("INFO: using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss") + fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(sampling_rate=h.sampling_rate) # NOTE: accepts waveform as input + else: + fn_mel_loss_singlescale = F.l1_loss - # create or scan the latest checkpoint from checkpoints directory + # print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory if rank == 0: print(generator) + print(mpd) + print(mrd) + print("Generator params: {}".format(sum(p.numel() for p in generator.parameters()))) + print("Discriminator mpd params: {}".format(sum(p.numel() for p in mpd.parameters()))) + print("Discriminator mrd params: {}".format(sum(p.numel() for p in mrd.parameters()))) os.makedirs(a.checkpoint_path, exist_ok=True) print("checkpoints directory : ", a.checkpoint_path) @@ -85,8 +106,7 @@ def train(rank, a, h): mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device) optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) - optim_d = torch.optim.AdamW(itertools.chain(mrd.parameters(), mpd.parameters()), - h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW(itertools.chain(mrd.parameters(), mpd.parameters()), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) if state_dict_do is not None: optim_g.load_state_dict(state_dict_do['optim_g']) @@ -191,7 +211,7 @@ def validate(rank, a, h, loader, mode="seen"): val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') # MRSTFT calculation - val_mrstft_tot += loss_mrstft(y_g_hat.squeeze(1), y).item() + val_mrstft_tot += loss_mrstft(y_g_hat, y).item() # log audio and figures to Tensorboard if j % a.eval_subsample == 0: # subsample every nth from validation set @@ -261,8 +281,7 @@ def validate(rank, a, h, loader, mode="seen"): y = y.unsqueeze(1) y_g_hat = generator(x) - y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, - h.fmin, h.fmax_for_loss) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax_for_loss) optim_d.zero_grad() @@ -276,11 +295,14 @@ def validate(rank, a, h, loader, mode="seen"): loss_disc_all = loss_disc_s + loss_disc_f + # set clip_grad_norm value + clip_grad_norm = h.get("clip_grad_norm", 1000.) # defaults to 1000 + # whether to freeze D for initial training steps if steps >= a.freeze_step: loss_disc_all.backward() - grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000.) 
- grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000.) + grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm) + grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm) optim_d.step() else: print("WARNING: skipping D training for the first {} steps".format(a.freeze_step)) @@ -291,7 +313,11 @@ def validate(rank, a, h, loader, mode="seen"): optim_g.zero_grad() # L1 Mel-Spectrogram Loss - loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 + lambda_melloss = h.get("lambda_melloss", 45.) # defaults to 45 in BigVGAN-v1 if not set + if h.get("use_multiscale_melloss", False): # uses wav for loss + loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss + else: # uses mel for loss + loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss # MPD loss y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) @@ -310,17 +336,21 @@ def validate(rank, a, h, loader, mode="seen"): loss_gen_all = loss_mel loss_gen_all.backward() - grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000.) + grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm) optim_g.step() if rank == 0: # STDOUT logging if steps % a.stdout_interval == 0: - with torch.no_grad(): - mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() - - print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. - format(steps, loss_gen_all, mel_error, time.time() - start_b)) + mel_error = loss_mel.item() / lambda_melloss # log training mel regression loss to stdout + print( + f"Steps: {steps:d}, " + f"Gen Loss Total: {loss_gen_all:4.3f}, " + f"Mel Error: {mel_error:4.3f}, " + f"s/b: {time.time() - start_b:4.3f} " + f"lr: {optim_g.param_groups[0]['lr']:4.7f} " + f"grad_norm_g: {grad_norm_g:4.3f}" + ) # checkpointing if steps % a.checkpoint_interval == 0 and steps != 0: @@ -338,7 +368,8 @@ def validate(rank, a, h, loader, mode="seen"): # Tensorboard summary logging if steps % a.summary_interval == 0: - sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) + mel_error = loss_mel.item() / lambda_melloss # log training mel regression loss to tensorboard + sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps) sw.add_scalar("training/mel_spec_error", mel_error, steps) sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) @@ -368,9 +399,10 @@ def validate(rank, a, h, loader, mode="seen"): validate(rank, a, h, list_unseen_validation_loader[i], mode="unseen_{}".format(list_unseen_validation_loader[i].dataset.name)) steps += 1 - - scheduler_g.step() - scheduler_d.step() + + # BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level + scheduler_g.step() + scheduler_d.step() if rank == 0: print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
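As a quick reference for the new optional hyperparameters read via `h.get(...)` in this patch, a hedged summary of the keys and the fallback defaults used when they are absent (omitting any of them keeps the v1 behavior; this is not a complete config file):

```python
# New BigVGAN-v2 options and their in-code defaults, collected from models.py and
# train.py above. Leaving a key out of the config falls back to the value shown.
bigvgan_v2_options = {
    # models.py (generator)
    "use_bias_at_final": True,         # bias on the final conv_post layer
    "use_tanh_at_final": True,         # tanh output; if False, output is clamped to [-1, 1]
    # train.py (discriminators and losses)
    "use_mbd_instead_of_mrd": False,   # replace MRD with MultiBandDiscriminator
    "use_cqtd_instead_of_mrd": False,  # replace MRD with MultiScaleSubbandCQTDiscriminator
    "use_multiscale_melloss": False,   # MultiScaleMelSpectrogramLoss computed on waveforms
    "lambda_melloss": 45.0,            # weight of the mel regression loss
    "clip_grad_norm": 1000.0,          # gradient clipping norm for generator and discriminators
}
```

Note that `use_cuda_kernel` is passed to the `BigVGAN` constructor rather than read from the config file.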