Skip to content

Commit

Permalink
improve DPM solver test
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Jan 18, 2024
1 parent 999e429 commit ce30359
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,21 @@ def test_ddpm_diffusers():
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)


def test_dpm_solver_diffusers():
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore

manual_seed(0)

diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = DPMSolver(num_inference_steps=30)
diffusers_scheduler = DiffuserScheduler(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
lower_order_final=False,
euler_at_final=last_step_first_order,
)
diffusers_scheduler.set_timesteps(n_steps)
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)

sample = randn(1, 3, 32, 32)
noise = randn(1, 3, 32, 32)
Expand Down

0 comments on commit ce30359

Please sign in to comment.