From 8e6ec3641bea5fdde84b1260abcebc8997266bbc Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 16 Oct 2024 20:57:56 +0900 Subject: [PATCH] Add header --- d3rlpy/logging/file_adapter.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py index e6d58a80..637402c1 100644 --- a/d3rlpy/logging/file_adapter.py +++ b/d3rlpy/logging/file_adapter.py @@ -42,10 +42,12 @@ class FileAdapter(LoggerAdapter): _algo: AlgProtocol _logdir: str + _is_model_watched: bool def __init__(self, algo: AlgProtocol, logdir: str): self._algo = algo self._logdir = logdir + self._is_model_watched = False if not os.path.exists(self._logdir): os.makedirs(self._logdir) LOG.info(f"Directory is created at {self._logdir}") @@ -91,14 +93,29 @@ def watch_model( step: int, ) -> None: assert self._algo.impl + + # write header at the first call + if not self._is_model_watched: + self._is_model_watched = True + for name, grad in self._algo.impl.modules.get_gradients(): + path = os.path.join(self._logdir, f"{name}_grad.csv") + with open(path, "w") as f: + print( + ",".join( + ["epoch", "step", "min", "max", "mean", "std"] + ), + file=f, + ) + for name, grad in self._algo.impl.modules.get_gradients(): path = os.path.join(self._logdir, f"{name}_grad.csv") with open(path, "a") as f: min_grad = grad.min() max_grad = grad.max() - mean_grad = grad.mean() + mean = grad.mean() + std = grad.std() print( - f"{epoch},{step},{min_grad},{max_grad},{mean_grad}", + f"{epoch},{step},{min_grad},{max_grad},{mean},{std}", file=f, )