First draft for modular Hindsight Experience Replay Transform #2667

Draft · dtsaras wants to merge 1 commit into main

Conversation


@dtsaras dtsaras commented Dec 19, 2024

Description

I have added the Hindsight Experience Replay Transform, implementing the `future` and `last` strategies described in the paper. The transform is a combination of three transforms (a usage sketch follows this list):

  • HERSubGoalSampler: samples the indices of the subgoals; it can be swapped for another subgoal sampling method to fit a specific use case.
  • HERSubGoalAssigner: creates the new trajectories given the subgoal indices.
  • HERRewardTransform: reassigns the rewards to the newly generated trajectories (it may not need to be a separate transform yet).
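
Roughly, usage looks like this (a minimal sketch: the import path is an assumption and the constructor arguments simply mirror this draft, so both may change):

from torchrl.envs.transforms import (  # assumed location, not final
    HERRewardTransform,
    HERSubGoalAssigner,
    HERSubGoalSampler,
    HindsightExperienceReplayTransform,
)

# compose the three pieces; each can be swapped out independently
her = HindsightExperienceReplayTransform(
    SubGoalSampler=HERSubGoalSampler(num_samples=4, strategy="future"),
    SubGoalAssigner=HERSubGoalAssigner(),
    RewardTransform=HERRewardTransform(),
)

# `trajectories` would be a TensorDict of shape [batch, time];
# her_augmentation produces the goal-relabelled trajectories with
# recomputed rewards
# augmented = her.her_augmentation(trajectories)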

Motivation and Context

This is a modular implementation of Hindsight Experience Replay, as requested in #1819.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.


pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2667

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 19, 2024
@vmoens vmoens added the enhancement New feature or request label Jan 8, 2025

@vmoens vmoens left a comment


Makes sense to me! Thanks for working on this.

I'd love to see some tests to understand better how this all works.
I left some comments here and there, mostly about formatting and high-level design decisions. Happy to give it a more thorough technical look later once there's an example of how to run it and/or some tests to rely upon.

Comment on lines +9271 to +9275
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".

Suggested change
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
==> NOT PRESENT: out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
subgoal_idx_key: TODO
strategy: TODO
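
For readers unfamiliar with HER, a self-contained sketch of what the two strategies boil down to (illustrative only, not the code in this PR):

import torch

# "last" uses the final state of each trajectory as the single subgoal;
# "future" draws up to `num_samples` subgoal indices per trajectory
# (here uniformly at random, for simplicity)
def sample_subgoal_idx(batch_size: int, trajectory_len: int,
                       num_samples: int = 4, strategy: str = "future") -> torch.Tensor:
    if strategy == "last":
        return torch.full((batch_size, 1), trajectory_len - 1)
    if strategy == "future":
        return torch.randint(0, trajectory_len, (batch_size, num_samples))
    raise ValueError(f"Unknown strategy {strategy!r}")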


I would add a .. seealso:: with other related classes.

Comment on lines +9280 to +9281
subgoal_idx_key: str = "subgoal_idx",
strategy: str = "future"

Suggested change
subgoal_idx_key: str = "subgoal_idx",
strategy: str = "future"
subgoal_idx_key: NestedKey = "subgoal_idx",
strategy: NestedKey = "future"


maybe let's create a dedicated file for these?

        self.strategy = strategy

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:

If the number of dimensions is 0 or greater than 2, raise an error?
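
Something along these lines (sketch; `trajectories` is the argument of forward above):

if trajectories.ndim == 0 or trajectories.ndim > 2:
    raise ValueError(
        f"Expected trajectories with 1 or 2 dimensions, got {trajectories.ndim}."
    )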

        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape

maybe

Suggested change
batch_size, trajectory_len = trajectories.shape
*batch_size, trajectory_len = trajectories.shape

to account for inputs with more than two dimensions
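
A quick standalone illustration of the star-unpack (nothing TorchRL-specific):

import torch

# star-unpacking keeps all leading dims as a list, so the same line handles
# [time], [batch, time] and higher-dimensional layouts alike
*batch_size, trajectory_len = torch.Size([8, 4, 100])
print(batch_size, trajectory_len)  # [8, 4] 100

*batch_size, trajectory_len = torch.Size([100])
print(batch_size, trajectory_len)  # [] 100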

Comment on lines +9366 to +9375
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )

Suggested change
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        SubGoalSampler: Transform | None = None,
        SubGoalAssigner: Transform | None = None,
        RewardTransform: Transform | None = None,
        assign_subgoal_idxs: bool = False,
    ):
        if SubGoalSampler is None:
            SubGoalSampler = HERSubGoalSampler()
        if SubGoalAssigner is None:
            SubGoalAssigner = HERSubGoalAssigner()
        if RewardTransform is None:
            RewardTransform = HERRewardTransform()
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )


No PascalCase but snake_case for instantiated classes

Comment on lines +9376 to +9378
self.SubGoalSampler = SubGoalSampler
self.SubGoalAssigner = SubGoalAssigner
self.RewardTransform = RewardTransform

ditto

    def her_augmentation(self, trajectories: TensorDictBase):
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape

maybe

Suggested change
batch_size, trajectory_length = trajectories.shape
*batch_size, trajectory_length = trajectories.shape

        # Create new trajectories
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):

Suggested change
for i in range(batch_size):
for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!
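
A standalone check of that claim:

import torch

# torch.Size.numel() is the product of the entries; the empty product is 1,
# so an unbatched tensordict still yields exactly one loop iteration
print(torch.Size([3, 2]).numel())  # 6
print(torch.Size([]).numel())      # 1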

        return trajectories


class HindsightExperienceReplayTransform(Transform):

Don't we need to modify the specs?
Does this work with replay buffers (static data) or only with envs? If the latter, we should not be using forward.

If you look at Compose, there are a bunch of things that need to be implemented when nesting transforms, like clone, cache erasure, etc.

Perhaps we could inherit from Compose and rewrite forward, _apply_transform, _call, _reset, etc. such that the logic holds but the extra features are included automatically?
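
A very rough skeleton of that idea (illustrative only; the real relabelling logic from this PR would replace the stub body):

from tensordict import TensorDictBase
from torchrl.envs.transforms import Compose, Transform


class HindsightExperienceReplayTransform(Compose):
    """Sketch: inherit from Compose so clone(), cache handling and repr come for free."""

    def __init__(self, sampler: Transform, assigner: Transform, reward: Transform):
        super().__init__(sampler, assigner, reward)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # static-data path (e.g. when used with a replay buffer): augment the
        # sampled trajectories instead of applying the sub-transforms one by one
        return self.her_augmentation(tensordict)

    def her_augmentation(self, trajectories: TensorDictBase) -> TensorDictBase:
        # the PR's relabelling logic would go here; the three sub-transforms are
        # reachable through the ModuleList that Compose stores in self.transforms
        raise NotImplementedError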
