[Feature] Pass replay buffers to SyncDataCollector
ghstack-source-id: d4949410af9604e64c4d179608ebec7377710758
Pull Request resolved: #2384
vmoens committed Aug 9, 2024
1 parent a6310ae commit 6d3421b
Showing 1 changed file with 38 additions and 12 deletions.
torchrl/collectors/collectors.py: 38 additions & 12 deletions
@@ -50,6 +50,7 @@
     VERBOSE,
 )
 from torchrl.collectors.utils import split_trajectories
+from torchrl.data import ReplayBuffer
 from torchrl.data.tensor_specs import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import _do_nothing, EnvBase
@@ -357,6 +358,8 @@ class SyncDataCollector(DataCollectorBase):
         use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
             This isn't compatible with environments with dynamic specs. Defaults to ``True``
             for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
+            but populate the buffer instead. Defaults to ``None``.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -446,6 +449,7 @@ def __init__(
         interruptor=None,
         set_truncated: bool = False,
         use_buffers: bool | None = None,
+        replay_buffer: ReplayBuffer | None = None,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase

@@ -538,9 +542,17 @@ def __init__(

         self.env: EnvBase = env
         del env
+        self.replay_buffer = replay_buffer
+        if self.replay_buffer is not None:
+            if postproc is not None:
+                raise TypeError("postproc must be None when a replay buffer is passed.")
+            if use_buffers:
+                raise TypeError("replay_buffer is exclusive with use_buffers.")
         if use_buffers is None:
-            use_buffers = not self.env._has_dynamic_specs
+            use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
         self._use_buffers = use_buffers
+        self.replay_buffer = replay_buffer
+
         self.closed = False
         if not reset_when_done:
             raise ValueError("reset_when_done is deprectated.")
@@ -873,6 +885,11 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
         """
         return self.env.set_seed(seed, static_seed=static_seed)

+    def _increment_frames(self, numel):
+        self._frames += numel
+        if self._frames >= self.total_frames:
+            self.env.close()
+
     def iterator(self) -> Iterator[TensorDictBase]:
         """Iterates through the DataCollector.
@@ -917,14 +934,15 @@ def cuda_check(tensor: torch.Tensor):
             for stream in streams:
                 stack.enter_context(torch.cuda.stream(stream))

-            total_frames = self.total_frames
-
             while self._frames < self.total_frames:
                 self._iter += 1
                 tensordict_out = self.rollout()
-                self._frames += tensordict_out.numel()
-                if self._frames >= total_frames:
-                    self.env.close()
+                if tensordict_out is None:
+                    # if a replay buffer is passed, there is no tensordict_out
+                    # frames are updated within the rollout function
+                    yield
+                    continue
+                self._increment_frames(tensordict_out.numel())

                 if self.split_trajs:
                     tensordict_out = split_trajectories(
@@ -1053,13 +1071,17 @@ def rollout(self) -> TensorDictBase:
                     next_data.clear_device_()
                 self._shuttle.set("next", next_data)

-                if self.storing_device is not None:
-                    tensordicts.append(
-                        self._shuttle.to(self.storing_device, non_blocking=True)
-                    )
-                    self._sync_storage()
+                if self.replay_buffer is not None:
+                    self.replay_buffer.add(self._shuttle)
+                    self._increment_frames(self._shuttle.numel())
                 else:
-                    tensordicts.append(self._shuttle)
+                    if self.storing_device is not None:
+                        tensordicts.append(
+                            self._shuttle.to(self.storing_device, non_blocking=True)
+                        )
+                        self._sync_storage()
+                    else:
+                        tensordicts.append(self._shuttle)

                 # carry over collector data without messing up devices
                 collector_data = self._shuttle.get("collector").copy()
@@ -1074,6 +1096,8 @@ def rollout(self) -> TensorDictBase:
                     self.interruptor is not None
                     and self.interruptor.collection_stopped()
                 ):
+                    if self.replay_buffer is not None:
+                        return
                     result = self._final_rollout
                     if self._use_buffers:
                         try:
@@ -1109,6 +1133,8 @@ def rollout(self) -> TensorDictBase:
                     self._final_rollout.ndim - 1,
                     out=self._final_rollout,
                 )
+        elif self.replay_buffer is not None:
+            return
         else:
             result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
             result.refine_names(..., "time")
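For context, a minimal usage sketch of the new ``replay_buffer`` argument (not part of the commit): the environment, buffer size, and frame counts below are illustrative assumptions, and ``policy=None`` relies on the collector's documented fallback to a random policy. With a buffer attached, the collector writes frames into the buffer via ``replay_buffer.add`` and yields ``None`` on each iteration, as introduced by this change.

    # Illustrative sketch of passing a ReplayBuffer to SyncDataCollector.
    from torchrl.collectors import SyncDataCollector
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs.libs.gym import GymEnv

    rb = ReplayBuffer(storage=LazyTensorStorage(10_000))
    collector = SyncDataCollector(
        lambda: GymEnv("CartPole-v1"),
        policy=None,            # None falls back to a random policy
        frames_per_batch=64,
        total_frames=1_024,
        replay_buffer=rb,       # new argument added by this commit
    )
    for batch in collector:
        # With a replay buffer attached, the collector yields None;
        # the collected data lands in `rb` instead of being returned.
        assert batch is None
    print(len(rb))              # the buffer has been populated during collection
    collector.shutdown()

Note that, per the new validation in ``__init__``, combining ``replay_buffer`` with a ``postproc`` or with ``use_buffers=True`` raises a ``TypeError``.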
