From c6a6bf7cffd93cbf2575c6c373bf916c1bdbb0d1 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 19 Jan 2024 15:36:19 +0100 Subject: [PATCH] fix: summarize_tensor(tensor) when tensor.numel() == 0 --- src/refiners/fluxion/utils.py | 13 +++++++------ tests/fluxion/test_utils.py | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 6052ebbc1..dedf8e1c4 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str: if tensor.is_complex(): tensor_f = tensor.real.float() else: - info_list.extend( - [ - f"min={tensor.min():.2f}", # type: ignore - f"max={tensor.max():.2f}", # type: ignore - ] - ) + if tensor.numel() > 0: + info_list.extend( + [ + f"min={tensor.min():.2f}", # type: ignore + f"max={tensor.max():.2f}", # type: ignore + ] + ) tensor_f = tensor.float() info_list.extend( diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index a86d0cdda..6165eb74a 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -79,6 +79,7 @@ def test_summarize_tensor() -> None: assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512))) assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16()) assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool()) + assert summarize_tensor(torch.zeros(1, 0, 512, 512).int()) def test_no_grad() -> None: