Skip to content

Commit

Permalink
change default behavior of end to None
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Dec 13, 2023
1 parent a7551e0 commit 349e3e3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
5 changes: 3 additions & 2 deletions src/refiners/fluxion/layers/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def forward(self, x: Tensor) -> Tensor:


class Slicing(Module):
def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
super().__init__()
self.dim = dim
self.start = start
Expand All @@ -94,7 +94,8 @@ def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -
def forward(self, x: Tensor) -> Tensor:
dim_size = x.shape[self.dim]
start = self.start if self.start >= 0 else dim_size + self.start
end = self.end if self.end >= 0 else dim_size + self.end
end = self.end or dim_size
end = end if end >= 0 else dim_size + end
start = max(min(start, dim_size), 0)
end = max(min(end, dim_size), 0)
if start >= end:
Expand Down
8 changes: 2 additions & 6 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,7 @@ def __init__(
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
),
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
Expand All @@ -286,9 +284,7 @@ def __init__(
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
),
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/foundationals/segment_anything/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
fl.Slicing(dim=1, start=1),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
)

Expand All @@ -183,7 +183,7 @@ def __init__(
device=device,
dtype=dtype,
),
fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
fl.Slicing(dim=-1, start=1),
)


Expand Down
18 changes: 13 additions & 5 deletions tests/fluxion/layers/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def test_slicing_negative_indices() -> None:
assert torch.equal(sliced, expected)


def test_none_end_slicing() -> None:
x = torch.randn(2, 1000, 400)
slicing = Slicing(dim=1, start=1)
sliced = slicing(x)
expected = x[:, 1:, :]
assert torch.equal(sliced, expected)


def test_slicing_step() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
Expand All @@ -27,39 +35,39 @@ def test_slicing_step() -> None:
assert torch.equal(sliced, expected)


def test_slicing_empty_slice():
def test_slicing_empty_slice() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=3, end=3)
sliced = slicing_layer(x)
expected = x[:, 3:3]
assert torch.equal(sliced, expected)


def test_slicing_full_dimension():
def test_slicing_full_dimension() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=2, start=0, end=5)
sliced = slicing_layer(x)
expected = x[:, :, :]
assert torch.equal(sliced, expected)


def test_slicing_step_greater_than_range():
def test_slicing_step_greater_than_range() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
sliced = slicing_layer(x)
expected = x[:, 1:3:4]
assert torch.equal(sliced, expected)


def test_slicing_reversed_start_end():
def test_slicing_reversed_start_end() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=4, end=2)
sliced = slicing_layer(x)
expected = x[:, 4:2]
assert torch.equal(sliced, expected)


def test_slicing_out_of_bounds_indices():
def test_slicing_out_of_bounds_indices() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=-10, end=10)
sliced = slicing_layer(x)
Expand Down

0 comments on commit 349e3e3

Please sign in to comment.