Skip to content

Commit

Permalink
deprecate DDPM step which is unused for now
Browse files Browse the repository at this point in the history
  • Loading branch information
deltheil committed Dec 13, 2023
1 parent a7551e0 commit 82a2aa1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 54 deletions.
55 changes: 3 additions & 52 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch import Generator, Tensor, arange, device as Device, randn, tensor
from torch import Tensor, arange, device as Device

from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

Expand Down Expand Up @@ -30,54 +30,5 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Generate the next step in the diffusion process.
This method adjusts the input data using added noise and an estimate of the denoised data, based on the current
step in the diffusion process. This adjusted data forms the next step in the diffusion process.
1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original
data and noise to the new step.
2. An estimate of the denoised data (`estimated_denoised_data`) is generated.
3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and
`current_data_coeff`) that balance their contribution to the denoised data for the next step.
4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated
denoised data and current data, adjusted by their respective coefficients.
5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on
the cumulative scaling factor and the current factor.
The output is the new data step for the next stage in the diffusion process.
"""
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < len(self.timesteps) - 1
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
),
)
current_cumulative_factor, previous_cumulative_scale_factor = (
(self.scale_factors.cumprod(0))[timestep],
(
(self.scale_factors.cumprod(0))[previous_timestep]
if step < len(self.timesteps) - 1
else tensor(1, device=self.device)
),
)
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
1 - current_cumulative_factor
)
current_data_coeff = (
current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor)
)
denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x
if step < len(self.timesteps) - 1:
variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor)
denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn(
x.shape, device=x.device, dtype=x.dtype, generator=generator
)
return denoised_x
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
raise NotImplementedError
14 changes: 12 additions & 2 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
from warnings import warn

import pytest
from torch import Tensor, allclose, device as Device, randn
from torch import Tensor, allclose, device as Device, equal, randn

from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver


def test_ddpm_diffusers():
from diffusers import DDPMScheduler # type: ignore

diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(1000)
refiners_scheduler = DDPM(num_inference_steps=1000)

assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)


def test_dpm_solver_diffusers():
Expand Down

0 comments on commit 82a2aa1

Please sign in to comment.