First draft for modular Hindsight Experience Replay Transform #2667
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2667
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Makes sense to me! Thanks for working on this.
I'd love to see some tests to understand better how this all works.
I left some comments here and there, mostly about formatting and high-level design decisions. Happy to give it a more thorough technical look later once there's an example of how to run it and/or some tests to rely upon.
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. | ||
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. | ||
Args: | ||
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. | ||
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". |
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. | |
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. | |
Args: | |
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. | |
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". | |
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. | |
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. | |
Args: | |
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. | |
==> NOT PRESENT: out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". | |
subgoal_idx_key: TODO | |
strategy: TODO |
I would add a `.. seealso::` with other related classes.
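For instance, a rough sketch of such a block in the sampler's docstring (class names taken from this PR, summary line hypothetical):

"""Samples subgoal indices from each trajectory.

.. seealso::
    :class:`HERSubGoalAssigner`, :class:`HERRewardTransform` and
    :class:`HindsightExperienceReplayTransform`, which together form the
    modular HER pipeline.
"""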
subgoal_idx_key: str = "subgoal_idx",
strategy: str = "future"
Suggested change:

-subgoal_idx_key: str = "subgoal_idx",
-strategy: str = "future"
+subgoal_idx_key: NestedKey = "subgoal_idx",
+strategy: NestedKey = "future"
maybe let's create a dedicated file for these?
self.strategy = strategy

def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
    if len(trajectories.shape) == 1:
if 0 or greater than 2, raise an error?
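A minimal sketch of that guard, assuming the hypothetical helper name `_check_trajectory_batch` and that only 1D or 2D inputs are meant to be supported:

from tensordict import TensorDictBase

def _check_trajectory_batch(trajectories: TensorDictBase) -> TensorDictBase:
    # Reject empty or >2D batch shapes instead of silently mishandling them.
    if trajectories.ndim == 0 or trajectories.ndim > 2:
        raise ValueError(
            f"Expected trajectories with 1 or 2 dimensions, got shape {trajectories.shape}."
        )
    if trajectories.ndim == 1:
        # Promote a single trajectory to a batch of one.
        trajectories = trajectories.unsqueeze(0)
    return trajectories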
if len(trajectories.shape) == 1:
    trajectories = trajectories.unsqueeze(0)

batch_size, trajectory_len = trajectories.shape
maybe

-batch_size, trajectory_len = trajectories.shape
+*batch_size, trajectory_len = trajectories.shape

to account for trajectories with more than one batch dimension.
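For reference, a quick sketch of how the starred unpacking behaves on a few shapes (plain torch.Size values, no tensordict needed):

import torch

*batch_size, trajectory_len = torch.Size([3, 10])     # batch_size == [3], trajectory_len == 10
*batch_size, trajectory_len = torch.Size([2, 4, 10])  # batch_size == [2, 4], trajectory_len == 10
*batch_size, trajectory_len = torch.Size([10])        # batch_size == [], trajectory_len == 10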
SubGoalSampler: Transform = HERSubGoalSampler(),
SubGoalAssigner: Transform = HERSubGoalAssigner(),
RewardTransform: Transform = HERRewardTransform(),
assign_subgoal_idxs: bool = False,
):
    super().__init__(
        in_keys=None,
        in_keys_inv=None,
        out_keys_inv=None,
    )
Suggested change:

-SubGoalSampler: Transform = HERSubGoalSampler(),
-SubGoalAssigner: Transform = HERSubGoalAssigner(),
-RewardTransform: Transform = HERRewardTransform(),
+SubGoalSampler: Transform | None = None,
+SubGoalAssigner: Transform | None = None,
+RewardTransform: Transform | None = None,
 assign_subgoal_idxs: bool = False,
 ):
+    if SubGoalSampler is None:
+        SubGoalSampler = HERSubGoalSampler()
+    if SubGoalAssigner is None:
+        SubGoalAssigner = HERSubGoalAssigner()
+    if RewardTransform is None:
+        RewardTransform = HERRewardTransform()
     super().__init__(
         in_keys=None,
         in_keys_inv=None,
         out_keys_inv=None,
     )
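For context, the None-default pattern avoids a classic Python pitfall: default argument values are evaluated once, at definition time, so a transform instantiated in the signature would be shared by every object built with the defaults. A generic illustration (not TorchRL-specific):

class Counter:
    def __init__(self):
        self.count = 0

def make(counter=Counter()):   # evaluated once, when `make` is defined
    counter.count += 1
    return counter

a, b = make(), make()
assert a is b and b.count == 2  # both calls mutate the same shared instance

def make_safe(counter=None):   # the pattern from the suggestion above
    if counter is None:
        counter = Counter()    # a fresh instance per call
    counter.count += 1
    return counter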
No PascalCase but snake_case for instantiated classes.
self.SubGoalSampler = SubGoalSampler
self.SubGoalAssigner = SubGoalAssigner
self.RewardTransform = RewardTransform
ditto
def her_augmentation(self, trajectories: TensorDictBase):
    if len(trajectories.shape) == 1:
        trajectories = trajectories.unsqueeze(0)
    batch_size, trajectory_length = trajectories.shape
maybe

-batch_size, trajectory_length = trajectories.shape
+*batch_size, trajectory_length = trajectories.shape
# Create new trajectories
augmented_trajectories = []
list_idxs = []
for i in range(batch_size):
Suggested change:

-for i in range(batch_size):
+for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!
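For reference, torch.Size.numel() returns the product of the dimensions, so:

import torch

torch.Size([]).numel()      # 1 -> a single, unbatched trajectory still loops once
torch.Size([4]).numel()     # 4
torch.Size([2, 3]).numel()  # 6 -> every trajectory in a 2x3 batch is visited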
return trajectories


class HindsightExperienceReplayTransform(Transform):
Don't we need to modify the specs?
Does this work with a replay buffer (static data) or only with envs? If the latter, we should not be using `forward`.
If you look at `Compose`, there are a bunch of things that need to be implemented when nesting transforms, like `clone`, cache erasing, etc. Perhaps we could inherit from `Compose` and rewrite `forward`, `_apply_transform`, `_call`, `_reset`, etc. such that the logic holds but the extra features are included automatically?
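A very rough sketch of that direction, assuming the sub-transform and `her_augmentation` names from this PR (the augmentation body is only a placeholder here):

from tensordict import TensorDictBase
from torchrl.envs.transforms import Compose, Transform

class HindsightExperienceReplayTransform(Compose):
    def __init__(self, subgoal_sampler: Transform, subgoal_assigner: Transform,
                 reward_transform: Transform):
        # Compose tracks the child transforms, so clone(), cache invalidation,
        # parent/container handling, etc. are inherited rather than re-implemented.
        super().__init__(subgoal_sampler, subgoal_assigner, reward_transform)

    def her_augmentation(self, trajectories: TensorDictBase) -> TensorDictBase:
        # Placeholder for the PR's augmentation (sample subgoals, relabel goals,
        # recompute rewards); returns the input unchanged in this sketch.
        return trajectories

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Augment the sampled trajectories, then let Compose apply each child
        # transform in order.
        tensordict = self.her_augmentation(tensordict)
        return super().forward(tensordict)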
Description
I have added the Hindsight Experience Replay Transform, specifically implementing the `future` and `last` strategies as described in the paper. The transform is a combination of 3 transforms: HERSubGoalSampler, HERSubGoalAssigner, and HERRewardTransform.

Motivation and Context
It's a modular implementation for hindsight experience replay as requested in #1819.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:

Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!