Skip to content

Commit

Permalink
Add VBD predicted outputs to notebooks for inspection
Browse files Browse the repository at this point in the history
  • Loading branch information
daphne-cornelisse committed Nov 4, 2024
1 parent 0e9b33b commit 60abc1b
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 50 deletions.
93 changes: 69 additions & 24 deletions notebooks/00_align_simulators_vbd.ipynb

Large diffs are not rendered by default.

204 changes: 184 additions & 20 deletions notebooks/01_features_deepdive.ipynb

Large diffs are not rendered by default.

Binary file modified notebooks/gpudrive_vbd_sample_11671609ebfa3185.pkl
Binary file not shown.
Binary file modified notebooks/waymax_vbd_sample_11671609ebfa3185.pkl
Binary file not shown.
11 changes: 8 additions & 3 deletions vbd/model/VBD.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
#import lightning.pytorch as pl

# import lightning.pytorch as pl
import pytorch_lightning as pl

from .modules import Encoder, Denoiser, GoalPredictor
Expand Down Expand Up @@ -155,7 +156,11 @@ def forward(self, inputs, noised_actions_normalized, diffusion_step):
return output_dict

def forward_denoiser(
self, encoder_outputs, noised_actions_normalized, diffusion_step
self,
encoder_outputs,
noised_actions_normalized,
diffusion_step,
global_frame=True,
):
"""
Forward pass of the denoiser module.
Expand Down Expand Up @@ -185,7 +190,7 @@ def forward_denoiser(
current_states,
denoised_actions,
action_len=self.denoiser._action_len,
global_frame=True,
global_frame=global_frame,
)

return {
Expand Down
18 changes: 15 additions & 3 deletions vbd/sim_agent/sim_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,9 @@ def ibr_guidance(
return denoiser_output, x_t_prev, guide_history

################### Denoising ###################
def step_denoiser(self, x_t: torch.Tensor, c: dict, t: int):
def step_denoiser(
self, x_t: torch.Tensor, c: dict, t: int, global_frame: bool = True
):
"""
Perform a denoising step to sample x_{t-1} ~ P[x_{t-1} | x_t, D(x_t, c, t)].
Expand All @@ -538,7 +540,10 @@ def step_denoiser(self, x_t: torch.Tensor, c: dict, t: int):

# Denoise to reconstruct x_0 ~ D(x_t, c, t)
denoiser_output = self.forward_denoiser(
encoder_outputs=c, noised_actions_normalized=x_t, diffusion_step=t
encoder_outputs=c,
noised_actions_normalized=x_t,
diffusion_step=t,
global_frame=global_frame,
)

x_0 = denoiser_output["denoised_actions_normalized"]
Expand All @@ -554,7 +559,13 @@ def step_denoiser(self, x_t: torch.Tensor, c: dict, t: int):

@torch.no_grad()
def sample_denoiser(
self, batch, num_samples=1, x_t=None, use_tqdm=True, **kwargs
self,
batch,
num_samples=1,
x_t=None,
use_tqdm=True,
global_frame=True,
**kwargs
):
"""
Perform denoising inference on the given batch of data.
Expand Down Expand Up @@ -626,6 +637,7 @@ def sample_denoiser(
x_t=x_t,
c=encoder_outputs,
t=t,
global_frame=global_frame,
)
guide = None

Expand Down

0 comments on commit 60abc1b

Please sign in to comment.