Skip to content

Commit

Permalink
Support sharding for get_parameter_overview
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612412904
  • Loading branch information
Conchylicultor authored and copybara-github committed Mar 4, 2024
1 parent eed40a1 commit 3cc7394
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,17 @@ def _get_parameter_overview(
def get_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool = True,
include_stats: bool | str = True,
max_lines: int | None = None,
) -> str:
"""Returns a string with variables names, their shapes, count.
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
include_stats: If True, add columns with mean and std for each variable. If
the string "global", params are sharded global arrays and this function
assumes it is called on every host, i.e. can use collectives.
max_lines: If not `None`, the maximum number of variables to include.
Returns:
Expand Down

0 comments on commit 3cc7394

Please sign in to comment.