diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py
index e2517961..1dc4988f 100644
--- a/d3rlpy/algos/qlearning/base.py
+++ b/d3rlpy/algos/qlearning/base.py
@@ -378,7 +378,6 @@ def fit(
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
         logging_steps: int = 500,
-        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
@@ -404,7 +403,6 @@ def fit(
                 directory name.
             logging_steps: Number of steps to log metrics. This will be
                 ignored if logging_strategy is EPOCH.
-            gradient_logging_steps: Number of steps to log gradients.
             logging_strategy: Logging strategy to use.
             logger_adapter: LoggerAdapterFactory object.
             show_progress: Flag to show progress bar for iterations.
@@ -427,7 +425,6 @@ def fit(
             experiment_name=experiment_name,
             with_timestamp=with_timestamp,
             logging_steps=logging_steps,
-            gradient_logging_steps=gradient_logging_steps,
             logging_strategy=logging_strategy,
             logger_adapter=logger_adapter,
             show_progress=show_progress,
@@ -445,7 +442,6 @@ def fitter(
         n_steps: int,
         n_steps_per_epoch: int = 10000,
         logging_steps: int = 500,
-        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
@@ -476,7 +472,6 @@ def fitter(
                 directory name.
             logging_steps: Number of steps to log metrics. This will be
                 ignored if logging_strategy is EPOCH.
-            gradient_logging_steps: Number of steps to log gradients.
             logging_strategy: Logging strategy to use.
             logger_adapter: LoggerAdapterFactory object.
             show_progress: Flag to show progress bar for iterations.
@@ -499,15 +494,6 @@ def fitter(
         # initialize scalers
         build_scalers_with_transition_picker(self, dataset)
 
-        # setup logger
-        if experiment_name is None:
-            experiment_name = self.__class__.__name__
-        logger = D3RLPyLogger(
-            adapter_factory=logger_adapter,
-            experiment_name=experiment_name,
-            with_timestamp=with_timestamp,
-        )
-
         # instantiate implementation
         if self._impl is None:
             LOG.debug("Building models...")
@@ -522,11 +508,20 @@ def fitter(
         else:
             LOG.warning("Skip building models since they're already built.")
 
+        # setup logger
+        if experiment_name is None:
+            experiment_name = self.__class__.__name__
+        logger = D3RLPyLogger(
+            algo=self,
+            adapter_factory=logger_adapter,
+            experiment_name=experiment_name,
+            n_steps_per_epoch=n_steps_per_epoch,
+            with_timestamp=with_timestamp,
+        )
+
         # save hyperparameters
         save_config(self, logger)
 
-        logger.watch_model(0, 0, gradient_logging_steps, self)
-
         # training loop
         n_epochs = n_steps // n_steps_per_epoch
         total_step = 0
@@ -566,10 +561,6 @@ def fitter(
 
                 total_step += 1
 
-                logger.watch_model(
-                    epoch, total_step, gradient_logging_steps, self
-                )
-
                 if (
                     logging_strategy == LoggingStrategy.STEPS
                     and total_step % logging_steps == 0
@@ -619,7 +610,6 @@ def fit_online(
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
         logging_steps: int = 500,
-        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
@@ -648,7 +638,6 @@ def fit_online(
                 directory name.
             logging_steps: Number of steps to log metrics. This will be
                 ignored if logging_strategy is EPOCH.
-            gradient_logging_steps: Number of steps to log gradients.
             logging_strategy: Logging strategy to use.
             logger_adapter: LoggerAdapterFactory object.
             show_progress: Flag to show progress bar for iterations.
@@ -663,15 +652,6 @@ def fit_online(
         # check action-space
         assert_action_space_with_env(self, env)
 
-        # setup logger
-        if experiment_name is None:
-            experiment_name = self.__class__.__name__ + "_online"
-        logger = D3RLPyLogger(
-            adapter_factory=logger_adapter,
-            experiment_name=experiment_name,
-            with_timestamp=with_timestamp,
-        )
-
         # initialize algorithm parameters
         build_scalers_with_env(self, env)
 
@@ -683,11 +663,20 @@ def fit_online(
         else:
             LOG.warning("Skip building models since they're already built.")
 
+        # setup logger
+        if experiment_name is None:
+            experiment_name = self.__class__.__name__ + "_online"
+        logger = D3RLPyLogger(
+            algo=self,
+            adapter_factory=logger_adapter,
+            experiment_name=experiment_name,
+            n_steps_per_epoch=n_steps_per_epoch,
+            with_timestamp=with_timestamp,
+        )
+
         # save hyperparameters
         save_config(self, logger)
 
-        logger.watch_model(0, 0, gradient_logging_steps, self)
-
         # switch based on show_progress flag
         xrange = trange if show_progress else range
 
@@ -756,10 +745,6 @@ def fit_online(
                     for name, val in loss.items():
                         logger.add_metric(name, val)
 
-                logger.watch_model(
-                    epoch, total_step, gradient_logging_steps, self
-                )
-
                 if (
                     logging_strategy == LoggingStrategy.STEPS
                     and total_step % logging_steps == 0
diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py
index 0aff030d..7834f502 100644
--- a/d3rlpy/algos/transformer/base.py
+++ b/d3rlpy/algos/transformer/base.py
@@ -418,15 +418,6 @@ def fit(
         # initialize scalers
         build_scalers_with_trajectory_slicer(self, dataset)
 
-        # setup logger
-        if experiment_name is None:
-            experiment_name = self.__class__.__name__
-        logger = D3RLPyLogger(
-            adapter_factory=logger_adapter,
-            experiment_name=experiment_name,
-            with_timestamp=with_timestamp,
-        )
-
         # instantiate implementation
         if self._impl is None:
             LOG.debug("Building models...")
@@ -441,6 +432,17 @@ def fit(
         else:
             LOG.warning("Skip building models since they're already built.")
 
+        # setup logger
+        if experiment_name is None:
+            experiment_name = self.__class__.__name__
+        logger = D3RLPyLogger(
+            algo=self,
+            adapter_factory=logger_adapter,
+            experiment_name=experiment_name,
+            n_steps_per_epoch=n_steps_per_epoch,
+            with_timestamp=with_timestamp,
+        )
+
         # save hyperparameters
         save_config(self, logger)
 
diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py
index cc87d8a4..e6d58a80 100644
--- a/d3rlpy/logging/file_adapter.py
+++ b/d3rlpy/logging/file_adapter.py
@@ -1,7 +1,7 @@
 import json
 import os
 from enum import Enum, IntEnum
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 import numpy as np
 
@@ -36,12 +36,15 @@ class FileAdapter(LoggerAdapter):
     models as d3 files.
 
     Args:
+        algo: Algorithm.
         logdir (str): Log directory.
""" + _algo: AlgProtocol _logdir: str - def __init__(self, logdir: str): + def __init__(self, algo: AlgProtocol, logdir: str): + self._algo = algo self._logdir = logdir if not os.path.exists(self._logdir): os.makedirs(self._logdir) @@ -86,21 +89,18 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: - assert algo.impl - if logging_steps is not None and step % logging_steps == 0: - for name, grad in algo.impl.modules.get_gradients(): - path = os.path.join(self._logdir, f"{name}_grad.csv") - with open(path, "a") as f: - min_grad = grad.min() - max_grad = grad.max() - mean_grad = grad.mean() - print( - f"{epoch},{step},{name},{min_grad},{max_grad},{mean_grad}", - file=f, - ) + assert self._algo.impl + for name, grad in self._algo.impl.modules.get_gradients(): + path = os.path.join(self._logdir, f"{name}_grad.csv") + with open(path, "a") as f: + min_grad = grad.min() + max_grad = grad.max() + mean_grad = grad.mean() + print( + f"{epoch},{step},{min_grad},{max_grad},{mean_grad}", + file=f, + ) class FileAdapterFactory(LoggerAdapterFactory): @@ -118,6 +118,8 @@ class FileAdapterFactory(LoggerAdapterFactory): def __init__(self, root_dir: str = "d3rlpy_logs"): self._root_dir = root_dir - def create(self, experiment_name: str) -> FileAdapter: + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> FileAdapter: logdir = os.path.join(self._root_dir, experiment_name) - return FileAdapter(logdir) + return FileAdapter(algo, logdir) diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py index a0cb152f..c5c42d5d 100644 --- a/d3rlpy/logging/logger.py +++ b/d3rlpy/logging/logger.py @@ -112,23 +112,21 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: r"""Watch model parameters / gradients during training. Args: epoch: Epoch. step: Training step. - logging_steps: Training step. - algo: Algorithm. """ class LoggerAdapterFactory(Protocol): r"""Interface of LoggerAdapterFactory.""" - def create(self, experiment_name: str) -> LoggerAdapter: + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> LoggerAdapter: r"""Creates LoggerAdapter. This method instantiates ``LoggerAdapter`` with a given @@ -136,20 +134,25 @@ def create(self, experiment_name: str) -> LoggerAdapter: This method is usually called at the beginning of training. Args: + algo: Algorithm. experiment_name: Experiment name. + steps_per_epoch: Number of steps per epoch. 
""" raise NotImplementedError class D3RLPyLogger: + _algo: AlgProtocol _adapter: LoggerAdapter _experiment_name: str _metrics_buffer: DefaultDict[str, List[float]] def __init__( self, + algo: AlgProtocol, adapter_factory: LoggerAdapterFactory, experiment_name: str, + n_steps_per_epoch: int, with_timestamp: bool = True, ): if with_timestamp: @@ -157,7 +160,10 @@ def __init__( self._experiment_name = experiment_name + "_" + date else: self._experiment_name = experiment_name - self._adapter = adapter_factory.create(self._experiment_name) + self._algo = algo + self._adapter = adapter_factory.create( + algo, self._experiment_name, n_steps_per_epoch + ) self._metrics_buffer = defaultdict(list) def add_params(self, params: Dict[str, Any]) -> None: @@ -185,6 +191,9 @@ def commit(self, epoch: int, step: int) -> Dict[str, float]: self._adapter.after_write_metric(epoch, step) + # save model parameter metrics + self._adapter.watch_model(epoch, step) + # initialize metrics buffer self._metrics_buffer.clear() return metrics @@ -207,12 +216,3 @@ def measure_time(self, name: str) -> Iterator[None]: @property def adapter(self) -> LoggerAdapter: return self._adapter - - def watch_model( - self, - epoch: int, - step: int, - logging_steps: Optional[int], - algo: AlgProtocol, - ) -> None: - self._adapter.watch_model(epoch, step, logging_steps, algo) diff --git a/d3rlpy/logging/noop_adapter.py b/d3rlpy/logging/noop_adapter.py index de01eec0..18bae16f 100644 --- a/d3rlpy/logging/noop_adapter.py +++ b/d3rlpy/logging/noop_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict from .logger import ( AlgProtocol, @@ -41,8 +41,6 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: pass @@ -53,5 +51,7 @@ class NoopAdapterFactory(LoggerAdapterFactory): This class instantiates ``NoopAdapter`` object. """ - def create(self, experiment_name: str) -> NoopAdapter: + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> NoopAdapter: return NoopAdapter() diff --git a/d3rlpy/logging/tensorboard_adapter.py b/d3rlpy/logging/tensorboard_adapter.py index ba248785..3112830f 100644 --- a/d3rlpy/logging/tensorboard_adapter.py +++ b/d3rlpy/logging/tensorboard_adapter.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional +from typing import Any, Dict import numpy as np @@ -27,16 +27,18 @@ class TensorboardAdapter(LoggerAdapter): experiment_name (str): Experiment name. 
""" + _algo: AlgProtocol _experiment_name: str _params: Dict[str, Any] _metrics: Dict[str, float] - def __init__(self, root_dir: str, experiment_name: str): + def __init__(self, algo: AlgProtocol, root_dir: str, experiment_name: str): try: from tensorboardX import SummaryWriter except ImportError as e: raise ImportError("Please install tensorboardX") from e + self._algo = algo self._experiment_name = experiment_name logdir = os.path.join(root_dir, "runs", experiment_name) self._writer = SummaryWriter(logdir=logdir) @@ -73,15 +75,10 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: - assert algo.impl - if logging_steps is not None and step % logging_steps == 0: - for name, grad in algo.impl.modules.get_gradients(): - self._writer.add_histogram( - f"histograms/{name}_grad", grad, epoch - ) + assert self._algo.impl + for name, grad in self._algo.impl.modules.get_gradients(): + self._writer.add_histogram(f"histograms/{name}_grad", grad, epoch) class TensorboardAdapterFactory(LoggerAdapterFactory): @@ -98,5 +95,7 @@ class TensorboardAdapterFactory(LoggerAdapterFactory): def __init__(self, root_dir: str = "tensorboard_logs"): self._root_dir = root_dir - def create(self, experiment_name: str) -> TensorboardAdapter: - return TensorboardAdapter(self._root_dir, experiment_name) + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> TensorboardAdapter: + return TensorboardAdapter(algo, self._root_dir, experiment_name) diff --git a/d3rlpy/logging/utils.py b/d3rlpy/logging/utils.py index 81412395..cdb47e26 100644 --- a/d3rlpy/logging/utils.py +++ b/d3rlpy/logging/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, Sequence from .logger import ( AlgProtocol, @@ -53,11 +53,9 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: for adapter in self._adapters: - adapter.watch_model(epoch, step, logging_steps, algo) + adapter.watch_model(epoch, step) class CombineAdapterFactory(LoggerAdapterFactory): @@ -75,10 +73,12 @@ class CombineAdapterFactory(LoggerAdapterFactory): def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]): self._adapter_factories = adapter_factories - def create(self, experiment_name: str) -> CombineAdapter: + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> CombineAdapter: return CombineAdapter( [ - factory.create(experiment_name) + factory.create(algo, experiment_name, n_steps_per_epoch) for factory in self._adapter_factories ] ) diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index 4e388a26..cf4ba6a4 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -16,19 +16,30 @@ class WanDBAdapter(LoggerAdapter): This class logs data to Weights & Biases (WandB) for experiment tracking. Args: + algo: Algorithm. experiment_name (str): Name of the experiment. + n_steps_per_epoch: Number of steps per epoch. + project: Project name. 
""" def __init__( self, + algo: AlgProtocol, experiment_name: str, + n_steps_per_epoch: int, project: Optional[str] = None, ): try: import wandb except ImportError as e: raise ImportError("Please install wandb") from e + assert algo.impl self.run = wandb.init(project=project, name=experiment_name) + self.run.watch( + tuple(algo.impl.modules.get_torch_modules().values()), + log="gradients", + log_freq=n_steps_per_epoch, + ) self._is_model_watched = False def write_params(self, params: Dict[str, Any]) -> None: @@ -62,17 +73,8 @@ def watch_model( self, epoch: int, step: int, - logging_steps: Optional[int], - algo: AlgProtocol, ) -> None: - if not self._is_model_watched: - assert algo.impl - self.run.watch( - tuple(algo.impl.modules.get_torch_modules().values()), - log="gradients", - log_freq=logging_steps, - ) - self._is_model_watched = True + pass class WanDBAdapterFactory(LoggerAdapterFactory): @@ -80,28 +82,22 @@ class WanDBAdapterFactory(LoggerAdapterFactory): This class creates instances of the WandB Logger Adapter for experiment tracking. + + Args: + project (Optional[str], optional): The name of the WandB project. Defaults to None. """ _project: Optional[str] def __init__(self, project: Optional[str] = None) -> None: - """Initialize the WandB Logger Adapter Factory. - - Args: - project (Optional[str], optional): The name of the WandB project. Defaults to None. - """ self._project = project - def create(self, experiment_name: str) -> LoggerAdapter: - """Creates a WandB Logger Adapter instance. - - Args: - experiment_name (str): Name of the experiment. - - Returns: - Instance of the WandB Logger Adapter. - """ + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> LoggerAdapter: return WanDBAdapter( + algo=algo, experiment_name=experiment_name, + n_steps_per_epoch=n_steps_per_epoch, project=self._project, ) diff --git a/tests/base_test.py b/tests/base_test.py index b337c317..dec99da4 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -60,7 +60,12 @@ def from_json_tester( algo.create_impl(observation_shape, action_size) # save params.json adapter_factory = FileAdapterFactory("test_data") - logger = D3RLPyLogger(adapter_factory, experiment_name="test") + logger = D3RLPyLogger( + algo=algo, + adapter_factory=adapter_factory, + n_steps_per_epoch=1, + experiment_name="test", + ) # save parameters to test_data/test/params.json save_config(algo, logger) # load params.json diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py index 9b3ec0f0..cf484a67 100644 --- a/tests/logging/test_logger.py +++ b/tests/logging/test_logger.py @@ -45,14 +45,14 @@ def watch_model( self, epoch: int, step: int, - logging_step: int, - algo: AlgProtocol, ) -> None: self.is_watch_model_called = True class StubLoggerAdapterFactory: - def create(self, experiment_name: str) -> StubLoggerAdapter: + def create( + self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int + ) -> StubLoggerAdapter: return StubLoggerAdapter(experiment_name) @@ -74,7 +74,13 @@ def save(self, fname: str) -> None: @pytest.mark.parametrize("with_timestamp", [False, True]) def test_d3rlpy_logger(with_timestamp: bool) -> None: - logger = D3RLPyLogger(StubLoggerAdapterFactory(), "test", with_timestamp) # type: ignore + logger = D3RLPyLogger( + algo=StubAlgo(), # type: ignore + adapter_factory=StubLoggerAdapterFactory(), + experiment_name="test", + n_steps_per_epoch=1, + with_timestamp=with_timestamp, + ) # check experiment_name adapter = logger.adapter @@ -95,12 +101,14 @@ 
     assert not adapter.is_before_write_metric_called
     assert not adapter.is_write_metric_called
     assert not adapter.is_after_write_metric_called
+    assert not adapter.is_watch_model_called
     metrics = logger.commit(1, 1)
     assert "test" in metrics
     assert "time_test" in metrics
     assert adapter.is_before_write_metric_called
     assert adapter.is_write_metric_called
     assert adapter.is_after_write_metric_called
+    assert adapter.is_watch_model_called
 
     assert not adapter.is_save_model_called
     logger.save_model(1, StubAlgo())
@@ -109,7 +117,3 @@ def test_d3rlpy_logger(with_timestamp: bool) -> None:
     assert not adapter.is_close_called
     logger.close()
     assert adapter.is_close_called
-
-    assert not adapter.is_watch_model_called
-    logger.watch_model(1, 1, 1, StubAlgo())
-    assert adapter.is_watch_model_called
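
Note for custom adapter implementers (an illustrative sketch, not part of the patch above): after this change, gradient logging is driven by D3RLPyLogger.commit(), which calls LoggerAdapter.watch_model(epoch, step) once per metrics flush, and adapters receive the algorithm plus n_steps_per_epoch through LoggerAdapterFactory.create(). The example below shows a minimal adapter written against the new interface. The class names are hypothetical, and the import path and the remaining protocol methods (write_params, before_write_metric, write_metric, after_write_metric, save_model, close) are assumed to match the existing LoggerAdapter protocol in d3rlpy/logging/logger.py.

    from typing import Any, Dict

    from d3rlpy.logging.logger import (
        AlgProtocol,
        LoggerAdapter,
        LoggerAdapterFactory,
    )


    class PrintGradientAdapter(LoggerAdapter):
        """Hypothetical adapter that prints mean gradients at every commit()."""

        def __init__(self, algo: AlgProtocol):
            # the algorithm handle is now injected at construction time
            self._algo = algo

        def write_params(self, params: Dict[str, Any]) -> None:
            pass

        def before_write_metric(self, epoch: int, step: int) -> None:
            pass

        def write_metric(
            self, epoch: int, step: int, name: str, value: float
        ) -> None:
            pass

        def after_write_metric(self, epoch: int, step: int) -> None:
            pass

        def watch_model(self, epoch: int, step: int) -> None:
            # called from D3RLPyLogger.commit(); the old logging_steps/algo
            # arguments no longer exist in this signature
            assert self._algo.impl
            for name, grad in self._algo.impl.modules.get_gradients():
                print(f"epoch={epoch} step={step} {name} mean_grad={grad.mean()}")

        def save_model(self, epoch: int, algo: Any) -> None:
            pass

        def close(self) -> None:
            pass


    class PrintGradientAdapterFactory(LoggerAdapterFactory):
        def create(
            self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
        ) -> LoggerAdapter:
            # algo and n_steps_per_epoch are the new arguments; this adapter
            # only needs algo, so n_steps_per_epoch is ignored here
            return PrintGradientAdapter(algo)

Such a factory can be passed to fit(..., logger_adapter=PrintGradientAdapterFactory()) in the same way as FileAdapterFactory.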