Skip to content

Commit

Permalink
add a get_path helper to modules
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Jan 26, 2024
1 parent 0ee2d5e commit ce0339b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/refiners/fluxion/layers/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def _show_only_tag(self) -> bool:
"""
return False

def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
"""Helper for debugging purpose only.
Returns the path of the module in the chain as a string.
If `top` is set then the path will be relative to `top`,
otherwise it will be relative to the root of the chain.
"""
if (parent is None) or (self == top):
return self.__class__.__name__
for k, m in parent._modules.items(): # type: ignore
if m is self:
return parent.get_path(parent=parent.parent, top=top) + "." + k
raise ValueError(f"{self} not found in {parent}")


class ContextModule(Module):
# we store parent into a one element list to avoid pytorch thinking it's a submodule
Expand Down Expand Up @@ -154,6 +169,9 @@ def structural_copy(self: TContextModule) -> TContextModule:

return clone

def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
return super().get_path(parent=parent or self.parent, top=top)


class WeightedModule(Module):
@property
Expand Down
14 changes: 14 additions & 0 deletions tests/fluxion/layers/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,17 @@ def test_debug_print() -> None:
)

assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]


def test_module_get_path() -> None:
chain = fl.Chain(
fl.Sum(
fl.Linear(1, 1),
fl.Linear(1, 1),
),
fl.Sum(),
)

assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"

0 comments on commit ce0339b

Please sign in to comment.