Commit: Record gradients every epoch

takuseno committed Oct 15, 2024
1 parent 7951f3d commit 19064fa
Showing 10 changed files with 127 additions and 134 deletions.
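In short: this commit removes the step-based `gradient_logging_steps` option from `fit`, `fitter`, and `fit_online`, and instead has the logger record gradient statistics once per epoch when metrics are committed. A hedged before/after sketch of the public API (the dataset and algorithm below are illustrative, not part of the commit):

```python
import d3rlpy

# Illustrative setup; any dataset/algorithm combination behaves the same way.
dataset, _ = d3rlpy.datasets.get_cartpole()
dqn = d3rlpy.algos.DQNConfig().create()

# Before this commit, gradient logging was opt-in and step-based:
# dqn.fit(dataset, n_steps=10000, gradient_logging_steps=100)

# After this commit the argument is gone; gradients are recorded
# automatically at every epoch boundary by the logger adapter.
dqn.fit(dataset, n_steps=10000, n_steps_per_epoch=1000)
```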
59 changes: 22 additions & 37 deletions d3rlpy/algos/qlearning/base.py
@@ -378,7 +378,6 @@ def fit(
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
@@ -404,7 +403,6 @@ def fit(
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
@@ -427,7 +425,6 @@
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logging_steps=logging_steps,
gradient_logging_steps=gradient_logging_steps,
logging_strategy=logging_strategy,
logger_adapter=logger_adapter,
show_progress=show_progress,
Expand All @@ -445,7 +442,6 @@ def fitter(
n_steps: int,
n_steps_per_epoch: int = 10000,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
@@ -476,7 +472,6 @@ def fitter(
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
@@ -499,15 +494,6 @@ def fitter(
# initialize scalers
build_scalers_with_transition_picker(self, dataset)

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)

# instantiate implementation
if self._impl is None:
LOG.debug("Building models...")
@@ -522,11 +508,20 @@ def fitter(
else:
LOG.warning("Skip building models since they're already built.")

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
algo=self,
adapter_factory=logger_adapter,
experiment_name=experiment_name,
n_steps_per_epoch=n_steps_per_epoch,
with_timestamp=with_timestamp,
)

# save hyperparameters
save_config(self, logger)

logger.watch_model(0, 0, gradient_logging_steps, self)

# training loop
n_epochs = n_steps // n_steps_per_epoch
total_step = 0
@@ -566,10 +561,6 @@ def fitter(

total_step += 1

logger.watch_model(
epoch, total_step, gradient_logging_steps, self
)

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
@@ -619,7 +610,6 @@ def fit_online(
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
@@ -648,7 +638,6 @@ def fit_online(
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
@@ -663,15 +652,6 @@ def fit_online(
# check action-space
assert_action_space_with_env(self, env)

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__ + "_online"
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)

# initialize algorithm parameters
build_scalers_with_env(self, env)

@@ -683,11 +663,20 @@ def fit_online(
else:
LOG.warning("Skip building models since they're already built.")

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__ + "_online"
logger = D3RLPyLogger(
algo=self,
adapter_factory=logger_adapter,
experiment_name=experiment_name,
n_steps_per_epoch=n_steps_per_epoch,
with_timestamp=with_timestamp,
)

# save hyperparameters
save_config(self, logger)

logger.watch_model(0, 0, gradient_logging_steps, self)

# switch based on show_progress flag
xrange = trange if show_progress else range

@@ -756,10 +745,6 @@ def fit_online(
for name, val in loss.items():
logger.add_metric(name, val)

logger.watch_model(
epoch, total_step, gradient_logging_steps, self
)

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
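Note the ordering change inside `fitter` and `fit_online` (and, below, the transformer trainer): logger construction moves after model building because `D3RLPyLogger` now takes the algorithm itself, and an adapter created by the factory may want to reach `algo.impl` (for example, to watch gradients) from the moment it exists. A condensed sketch of the new flow, assuming the surrounding names from the diff above:

```python
# Condensed from the diff above: models must exist before the logger,
# since adapters now receive the algorithm at construction time.
if self._impl is None:
    self.create_impl(observation_shape, action_size)  # build models first

logger = D3RLPyLogger(
    algo=self,                            # new: handle for adapters
    adapter_factory=logger_adapter,
    experiment_name=experiment_name,
    n_steps_per_epoch=n_steps_per_epoch,  # new: step/epoch conversion for adapters
    with_timestamp=with_timestamp,
)
save_config(self, logger)
```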
20 changes: 11 additions & 9 deletions d3rlpy/algos/transformer/base.py
@@ -418,15 +418,6 @@ def fit(
# initialize scalers
build_scalers_with_trajectory_slicer(self, dataset)

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)

# instantiate implementation
if self._impl is None:
LOG.debug("Building models...")
@@ -441,6 +432,17 @@
else:
LOG.warning("Skip building models since they're already built.")

# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
algo=self,
adapter_factory=logger_adapter,
experiment_name=experiment_name,
n_steps_per_epoch=n_steps_per_epoch,
with_timestamp=with_timestamp,
)

# save hyperparameters
save_config(self, logger)

38 changes: 20 additions & 18 deletions d3rlpy/logging/file_adapter.py
@@ -1,7 +1,7 @@
import json
import os
from enum import Enum, IntEnum
from typing import Any, Dict, Optional
from typing import Any, Dict

import numpy as np

@@ -36,12 +36,15 @@ class FileAdapter(LoggerAdapter):
models as d3 files.
Args:
algo: Algorithm.
logdir (str): Log directory.
"""

_algo: AlgProtocol
_logdir: str

def __init__(self, logdir: str):
def __init__(self, algo: AlgProtocol, logdir: str):
self._algo = algo
self._logdir = logdir
if not os.path.exists(self._logdir):
os.makedirs(self._logdir)
@@ -86,21 +89,18 @@ def watch_model(
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: AlgProtocol,
) -> None:
assert algo.impl
if logging_steps is not None and step % logging_steps == 0:
for name, grad in algo.impl.modules.get_gradients():
path = os.path.join(self._logdir, f"{name}_grad.csv")
with open(path, "a") as f:
min_grad = grad.min()
max_grad = grad.max()
mean_grad = grad.mean()
print(
f"{epoch},{step},{name},{min_grad},{max_grad},{mean_grad}",
file=f,
)
assert self._algo.impl
for name, grad in self._algo.impl.modules.get_gradients():
path = os.path.join(self._logdir, f"{name}_grad.csv")
with open(path, "a") as f:
min_grad = grad.min()
max_grad = grad.max()
mean_grad = grad.mean()
print(
f"{epoch},{step},{min_grad},{max_grad},{mean_grad}",
file=f,
)


class FileAdapterFactory(LoggerAdapterFactory):
@@ -118,6 +118,8 @@ class FileAdapterFactory(LoggerAdapterFactory):
def __init__(self, root_dir: str = "d3rlpy_logs"):
self._root_dir = root_dir

def create(self, experiment_name: str) -> FileAdapter:
def create(
self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
) -> FileAdapter:
logdir = os.path.join(self._root_dir, experiment_name)
return FileAdapter(logdir)
return FileAdapter(algo, logdir)
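The `FileAdapter` side of the change: `watch_model` loses its step-interval check (the logger now calls it exactly once per epoch), takes the algorithm from its constructor rather than as an argument, and drops the redundant `name` column from each row, since every parameter group already gets its own `<name>_grad.csv`. A hedged sketch of reading the output back; the file and experiment directory names below are illustrative:

```python
import pandas as pd

# One row per epoch: epoch, step, and the min/max/mean of the gradient.
# The file name depends on your model's module names; "q_funcs_grad.csv"
# and the experiment directory are illustrative.
df = pd.read_csv(
    "d3rlpy_logs/DQN_20241015120000/q_funcs_grad.csv",
    names=["epoch", "step", "min_grad", "max_grad", "mean_grad"],
)
print(df.set_index("epoch")[["min_grad", "max_grad", "mean_grad"]])
```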
30 changes: 15 additions & 15 deletions d3rlpy/logging/logger.py
@@ -112,52 +112,58 @@ def watch_model(
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: AlgProtocol,
) -> None:
r"""Watch model parameters / gradients during training.
Args:
epoch: Epoch.
step: Training step.
logging_steps: Training step.
algo: Algorithm.
"""


class LoggerAdapterFactory(Protocol):
r"""Interface of LoggerAdapterFactory."""

def create(self, experiment_name: str) -> LoggerAdapter:
def create(
self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
) -> LoggerAdapter:
r"""Creates LoggerAdapter.
This method instantiates ``LoggerAdapter`` with a given
``experiment_name``.
This method is usually called at the beginning of training.
Args:
algo: Algorithm.
experiment_name: Experiment name.
n_steps_per_epoch: Number of steps per epoch.
"""
raise NotImplementedError


class D3RLPyLogger:
_algo: AlgProtocol
_adapter: LoggerAdapter
_experiment_name: str
_metrics_buffer: DefaultDict[str, List[float]]

def __init__(
self,
algo: AlgProtocol,
adapter_factory: LoggerAdapterFactory,
experiment_name: str,
n_steps_per_epoch: int,
with_timestamp: bool = True,
):
if with_timestamp:
date = datetime.now().strftime("%Y%m%d%H%M%S")
self._experiment_name = experiment_name + "_" + date
else:
self._experiment_name = experiment_name
self._adapter = adapter_factory.create(self._experiment_name)
self._algo = algo
self._adapter = adapter_factory.create(
algo, self._experiment_name, n_steps_per_epoch
)
self._metrics_buffer = defaultdict(list)

def add_params(self, params: Dict[str, Any]) -> None:
@@ -185,6 +191,9 @@ def commit(self, epoch: int, step: int) -> Dict[str, float]:

self._adapter.after_write_metric(epoch, step)

# save model parameter metrics
self._adapter.watch_model(epoch, step)

# initialize metrics buffer
self._metrics_buffer.clear()
return metrics
@@ -207,12 +216,3 @@ def measure_time(self, name: str) -> Iterator[None]:
@property
def adapter(self) -> LoggerAdapter:
return self._adapter

def watch_model(
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: AlgProtocol,
) -> None:
self._adapter.watch_model(epoch, step, logging_steps, algo)
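On the logger side, the public `D3RLPyLogger.watch_model` wrapper disappears: `commit`, which already runs at each epoch boundary, now calls `self._adapter.watch_model(epoch, step)` itself. Custom adapters only need to match the slimmer signatures. A minimal sketch of an adapter under the new interface (hypothetical, not part of this commit; unused protocol methods are stubbed out):

```python
from typing import Any, Dict

from d3rlpy.logging.logger import (
    AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
)


class GradPrintAdapter(LoggerAdapter):
    """Hypothetical adapter: print mean gradients once per epoch."""

    def __init__(self, algo: AlgProtocol):
        self._algo = algo

    def watch_model(self, epoch: int, step: int) -> None:
        # Called by D3RLPyLogger.commit at every epoch boundary.
        assert self._algo.impl
        for name, grad in self._algo.impl.modules.get_gradients():
            print(f"epoch={epoch} {name}: mean grad={grad.mean():.3g}")

    # Remaining LoggerAdapter methods are no-ops in this sketch.
    def write_params(self, params: Dict[str, Any]) -> None: ...
    def before_write_metric(self, epoch: int, step: int) -> None: ...
    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None: ...
    def after_write_metric(self, epoch: int, step: int) -> None: ...
    def save_model(self, epoch: int, algo: Any) -> None: ...
    def close(self) -> None: ...


class GradPrintAdapterFactory(LoggerAdapterFactory):
    def create(
        self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
    ) -> GradPrintAdapter:
        return GradPrintAdapter(algo)
```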
8 changes: 4 additions & 4 deletions d3rlpy/logging/noop_adapter.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict

from .logger import (
AlgProtocol,
@@ -41,8 +41,6 @@ def watch_model(
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: AlgProtocol,
) -> None:
pass

@@ -53,5 +51,7 @@ class NoopAdapterFactory(LoggerAdapterFactory):
This class instantiates ``NoopAdapter`` object.
"""

def create(self, experiment_name: str) -> NoopAdapter:
def create(
self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
) -> NoopAdapter:
return NoopAdapter()
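`NoopAdapter` illustrates the minimal cost of the interface change: `create` simply accepts and ignores the new `algo` and `n_steps_per_epoch` arguments. Passing the factory to `fit` still disables all logging, now including the per-epoch gradient records; a usage sketch (setup is illustrative):

```python
import d3rlpy

dataset, _ = d3rlpy.datasets.get_cartpole()
dqn = d3rlpy.algos.DQNConfig().create()

# NoopAdapterFactory satisfies the new
# create(algo, experiment_name, n_steps_per_epoch) signature
# while discarding everything, gradient records included.
dqn.fit(
    dataset,
    n_steps=1000,
    n_steps_per_epoch=100,
    logger_adapter=d3rlpy.logging.NoopAdapterFactory(),
)
```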