From c56be279d14f7dafa8d31e02ede87f202eb884d8 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Thu, 11 Jan 2024 08:10:49 +0100 Subject: [PATCH 1/3] display module dtype and device --- src/refiners/fluxion/layers/module.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 1fc663e35..2c7210ef4 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -243,6 +243,15 @@ def _module_to_tree(self, module: Module) -> TreeNode: case (_, False): value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage] + dtype_str = getattr(module, "dtype", None).__repr__().replace("torch.", "") + + if hasattr(module, "device") and module.device is not None: + device_str = f"{module.device.type}:{module.device.index}" + else: + device_str = "cpu" + + value = f"{value} [dtype={dtype_str}, device={device_str}] " + class_name = module.__class__.__name__ node: TreeNode = {"value": value, "class_name": class_name, "children": []} From 2d652d1e3f08a49b6bbf72a08a07479b52cf2182 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Thu, 11 Jan 2024 12:27:29 +0100 Subject: [PATCH 2/3] WeightedModule _repr_ change for #173 --- src/refiners/fluxion/layers/module.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 2c7210ef4..22936c0d6 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -164,6 +164,10 @@ def device(self) -> Device: def dtype(self) -> DType: return self.weight.dtype + def __repr__(self) -> str: + str = super().__repr__() + return f"{str} [device={self.device}, dtype={self.dtype}]" + class TreeNode(TypedDict): value: str @@ -243,15 +247,6 @@ def _module_to_tree(self, module: Module) -> TreeNode: case (_, False): value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage] - dtype_str = getattr(module, "dtype", None).__repr__().replace("torch.", "") - - if hasattr(module, "device") and module.device is not None: - device_str = f"{module.device.type}:{module.device.index}" - else: - device_str = "cpu" - - value = f"{value} [dtype={dtype_str}, device={device_str}] " - class_name = module.__class__.__name__ node: TreeNode = {"value": value, "class_name": class_name, "children": []} From cfea67c81237cee7ec51da163b573d4a30b37a17 Mon Sep 17 00:00:00 2001 From: Colle Date: Thu, 11 Jan 2024 18:09:19 +0100 Subject: [PATCH 3/3] Update src/refiners/fluxion/layers/module.py Co-authored-by: Benjamin Trom --- src/refiners/fluxion/layers/module.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 22936c0d6..bb612fd7a 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -164,9 +164,8 @@ def device(self) -> Device: def dtype(self) -> DType: return self.weight.dtype - def __repr__(self) -> str: - str = super().__repr__() - return f"{str} [device={self.device}, dtype={self.dtype}]" + def __str__(self) -> str: + return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})" class TreeNode(TypedDict):