Skip to content

Commit

Permalink
finegrain-ai#171: clean assert & generic mean
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 10, 2024
1 parent 2a9665e commit aa232ce
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
12 changes: 8 additions & 4 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
]
)

if tensor.dtype == torch.float or tensor.is_complex():
info_list.extend([f"mean={tensor.mean():.2f}", f"std={tensor.std():.2f}", f"norm={norm(x=tensor):.2f}"])

info_list.extend([f"grad={tensor.requires_grad}"])
info_list.extend(
[
f"mean={tensor.float().mean():.2f}",
f"std={tensor.float().std():.2f}",
f"norm={norm(x=tensor.float()):.2f}",
f"grad={tensor.requires_grad}",
]
)

return "Tensor(" + ", ".join(info_list) + ")"
10 changes: 6 additions & 4 deletions tests/fluxion/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ def test_tensor_to_image() -> None:


def test_summarize_tensor() -> None:
assert type(summarize_tensor(torch.zeros(1, 3, 512, 512).int())) == str
assert type(summarize_tensor(torch.zeros(1, 3, 512, 512).float())) == str
assert type(summarize_tensor(torch.zeros(1, 3, 512, 512).double())) == str
assert type(summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))) == str
assert summarize_tensor(torch.zeros(1, 3, 512, 512).int())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).float())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).double())
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())


def test_no_grad() -> None:
Expand Down

0 comments on commit aa232ce

Please sign in to comment.