diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 84cf435..ae2fd5c 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -259,7 +259,7 @@ def _get_parameter_overview( def get_parameter_overview( params: _ParamsContainer, *, - include_stats: bool = True, + include_stats: bool | str = True, max_lines: int | None = None, ) -> str: """Returns a string with variables names, their shapes, count. @@ -267,7 +267,9 @@ def get_parameter_overview( Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. - include_stats: If True, add columns with mean and std for each variable. + include_stats: If True, add columns with mean and std for each variable. If + the string "global", params are sharded global arrays and this function + assumes it is called on every host, i.e. can use collectives. max_lines: If not `None`, the maximum number of variables to include. Returns: