Skip to content

Commit

Permalink
upgrade pyright to 1.1.342 ; improve no_grad typing
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Dec 29, 2023
1 parent 12eef9c commit 5af016e
Show file tree
Hide file tree
Showing 31 changed files with 136 additions and 94 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)

with torch.no_grad():
with no_grad():
# Note: default text prompts for IP-Adapter
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"pyright == 1.1.333",
"pyright == 1.1.342",
"ruff>=0.0.292",
"docformatter>=1.7.5",
"pytest>=7.4.2",
Expand Down
1 change: 1 addition & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ six==1.16.0
smmap==5.0.1
sympy==1.12
tokenizers==0.15.0
tomli==2.0.1
torch==2.1.1
torchvision==0.16.1
tqdm==4.66.1
Expand Down
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import (
DPMSolver,
SD1ControlnetAdapter,
Expand All @@ -20,7 +20,7 @@ class Args(argparse.Namespace):
output_path: str | None


@torch.no_grad()
@no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

Expand All @@ -37,7 +37,7 @@ class Args(argparse.Namespace):
verbose: bool


@torch.no_grad()
@no_grad()
def process(args: Args) -> None:
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
Expand Down
5 changes: 4 additions & 1 deletion src/refiners/fluxion/layers/sampling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn.functional import pad

Expand Down Expand Up @@ -40,7 +42,8 @@ def __init__(
),
)
if padding == 0:
self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
self.insert(0, Lambda(zero_pad))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))

Expand Down
6 changes: 3 additions & 3 deletions src/refiners/fluxion/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle

from refiners.fluxion.utils import norm, save_to_safetensors
from refiners.fluxion.utils import no_grad, norm, save_to_safetensors

TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,
Expand Down Expand Up @@ -512,7 +512,7 @@ def _verify_missing_basic_layers(self) -> bool:

return True

@torch.no_grad()
@no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,
Expand Down Expand Up @@ -603,7 +603,7 @@ def _convert_state_dict(

return converted_state_dict

@torch.no_grad()
@no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:
Expand Down
16 changes: 14 additions & 2 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from pathlib import Path
from typing import Iterable, Literal, TypeVar
from typing import Any, Iterable, Literal, TypeVar

import torch
from jaxtyping import Float
from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore
from torch import (
Tensor,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
no_grad as _no_grad, # type: ignore
norm as _norm, # type: ignore
)
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore

T = TypeVar("T")
Expand All @@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
_manual_seed(seed)


class no_grad(_no_grad):
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
return object.__new__(cls)


def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore

Expand Down
7 changes: 5 additions & 2 deletions src/refiners/foundationals/clip/image_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from torch import device as Device, dtype as DType
from typing import Callable

from torch import Tensor, device as Device, dtype as DType

import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
Expand Down Expand Up @@ -126,6 +128,7 @@ def __init__(
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
select_first_embedding: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
super().__init__(
ViTEmbeddings(
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
Expand All @@ -142,7 +145,7 @@ def __init__(
)
for _ in range(num_layers)
),
fl.Lambda(func=lambda x: x[:, 0, :]),
fl.Lambda(func=select_first_embedding),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)
Expand Down
5 changes: 3 additions & 2 deletions src/refiners/foundationals/latent_diffusion/freeu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -54,9 +54,10 @@ def forward(self, x: Tensor) -> Tensor:

class FreeUSkipFeatures(fl.Chain):
def __init__(self, n: int, skip_scale: float) -> None:
apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
super().__init__(
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
fl.Lambda(apply_filter),
)


Expand Down
2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def from_safetensors(
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)

sub_targets = {}
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from torch import Tensor

from refiners.fluxion.adapters.adapter import Adapter
Expand Down Expand Up @@ -45,8 +47,9 @@ def __init__(
)

with self.setup_adapter(target):
slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
super().__init__(
Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
Lambda(self.compute_averaged_unconditioned_x),
)

Expand Down
6 changes: 3 additions & 3 deletions src/refiners/foundationals/segment_anything/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor, device as Device, dtype as DType

import refiners.fluxion.layers as fl
from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)

@torch.no_grad()
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
Expand All @@ -48,7 +48,7 @@ def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_image_size=original_size,
)

@torch.no_grad()
@no_grad()
def predict(
self,
input: Image.Image | ImageEmbedding,
Expand Down
2 changes: 1 addition & 1 deletion src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
self.timestep_bins[bin_index].append(loss_value)

def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
log_data = {}
log_data: dict[str, WandbLoggable] = {}
for bin_index, losses in self.timestep_bins.items():
if losses:
avg_loss = sum(losses) / len(losses)
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/training_utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
from loguru import logger
from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
from torch.autograd import backward
from torch.nn import Parameter
from torch.optim import Optimizer
Expand All @@ -26,7 +26,7 @@
from torch.utils.data import DataLoader, Dataset

from refiners.fluxion import layers as fl
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.training_utils.callback import (
Callback,
ClockCallback,
Expand Down
Loading

0 comments on commit 5af016e

Please sign in to comment.