diff --git a/integrations/models/notebooks/03_vbd_output_comparison.ipynb b/integrations/models/notebooks/03_vbd_output_comparison.ipynb
new file mode 100644
index 00000000..0b739865
--- /dev/null
+++ b/integrations/models/notebooks/03_vbd_output_comparison.ipynb
@@ -0,0 +1,362 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Comparing VBD Outputs: Waymax vs GPUDrive"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-11-18 17:53:30.179672: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
+ "E0000 00:00:1731970410.198982 123447 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "E0000 00:00:1731970410.204524 123447 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2024-11-18 17:53:30.224247: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%capture\n",
+ "import waymax\n",
+ "import numpy as np\n",
+ "import math\n",
+ "import mediapy\n",
+ "from tqdm import tqdm\n",
+ "import dataclasses\n",
+ "import os\n",
+ "import shutil\n",
+ "from pathlib import Path\n",
+ "import pickle\n",
+ "import pandas as pd\n",
+ "import tensorflow as tf\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "import torch\n",
+ "from waymax import config as _config\n",
+ "from waymax import dataloader, datatypes, visualization, dynamics\n",
+ "from waymax.datatypes.simulator_state import SimulatorState\n",
+ "from waymax.config import EnvironmentConfig, ObjectType\n",
+ "\n",
+ "# Set working directory to the base directory 'gpudrive'.\n",
+ "# Walk upward from the current directory. The search stops at the home directory\n",
+ "# (as before) and also at the filesystem root, whose parent is itself, so the loop\n",
+ "# always terminates even when the notebook is started outside the home directory.\n",
+ "working_dir = Path.cwd()\n",
+ "while working_dir.name != 'gpudrive':\n",
+ "    parent_dir = working_dir.parent\n",
+ "    if parent_dir == working_dir or parent_dir == Path.home():\n",
+ "        raise FileNotFoundError(\"Base directory 'gpudrive' not found\")\n",
+ "    working_dir = parent_dir\n",
+ "os.chdir(working_dir)\n",
+ "\n",
+ "# VBD dependencies\n",
+ "from integrations.models.vbd.sim_agent.waymax_env import WaymaxEnvironment\n",
+ "from integrations.models.vbd.data.dataset import WaymaxTestDataset\n",
+ "from integrations.models.vbd.waymax_visualization.plotting import plot_state\n",
+ "from integrations.models.vbd.sim_agent.sim_actor import VBDTest, sample_to_action\n",
+ "from integrations.models.vbd.model.utils import set_seed\n",
+ "\n",
+ "# GPUDrive dependencies\n",
+ "from pygpudrive.env.config import EnvConfig, RenderConfig, SceneConfig, SelectionDiscipline\n",
+ "from pygpudrive.env.env_torch import GPUDriveTorchEnv\n",
+ "\n",
+ "# Plotting\n",
+ "sns.set(\"notebook\")\n",
+ "sns.set_style(\"ticks\", rc={\"figure.facecolor\": \"none\", \"axes.facecolor\": \"none\"})\n",
+ "#%config InlineBackend.figure_format = 'svg'\n",
+ "\n",
+ "# Ignore all warnings\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATA_DIR = 'data/processed/debug' # Base data path\n",
+ "CKPT_DIR = 'data/checkpoints' # Base checkpoint path\n",
+ "CKPT_PATH = 'integrations/models/vbd/weights/epoch=18.ckpt'\n",
+ "\n",
+ "# Collect scenario ids from the GPUDrive json files, whose names end in '_<id>.json'.\n",
+ "# os.listdir() returns entries in arbitrary order, so the listing is sorted to make\n",
+ "# SCENARIO_IDS (and any slice taken from it) identical on every run and machine.\n",
+ "SCENARIO_IDS = []\n",
+ "for filename in sorted(os.listdir(os.path.join(DATA_DIR, 'gpudrive'))):\n",
+ "    if filename.endswith('.json') and '_' in filename:\n",
+ "        # Keep the part after the last underscore and drop only the trailing '.json'\n",
+ "        scenario_id = filename.rsplit('_', 1)[-1][:-len('.json')]\n",
+ "        SCENARIO_IDS.append(scenario_id)\n",
+ "\n",
+ "FPS = 20\n",
+ "INIT_STEPS = 11 # Warmup period\n",
+ "MAX_CONTROLLED_OBJECTS = 32"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load pre-trained VBD model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load model\n",
+ "model = VBDTest.load_from_checkpoint(CKPT_PATH, torch.device('cpu'))\n",
+ "_ = model.cpu()\n",
+ "_ = model.eval();\n",
+ "\n",
+ "# Model settings\n",
+ "replan_freq=80 # Roll out every X steps 80 means openloop\n",
+ "model.early_stop=0 # Stop Diffusion Early From 100 to X\n",
+ "model.skip = 1 # Skip Alpha \n",
+ "model.reward_func = None\n",
+ "\n",
+ "# Ensure reproducability\n",
+ "#set_seed(5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+ "\u001b[1;31mClick here for more info. \n",
+ "\u001b[1;31mView Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "# Helper to save two frame sequences side by side.\n",
+ "def save_side_by_side_gif(frames1, frames2, output_path, fps=FPS):\n",
+ "    \"\"\"\n",
+ "    Saves two arrays of frames as a single side-by-side GIF.\n",
+ "\n",
+ "    Parameters:\n",
+ "        frames1 (list of np.ndarray): First array of frames (images), placed on the left.\n",
+ "        frames2 (list of np.ndarray): Second array of frames (images), placed on the right.\n",
+ "        output_path (str): Path to save the output GIF. Missing parent directories are created.\n",
+ "        fps (int): Frames per second for the GIF.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: If the two arrays differ in length, or a pair of frames differs in height.\n",
+ "    \"\"\"\n",
+ "    # Ensure both arrays have the same number of frames\n",
+ "    if len(frames1) != len(frames2):\n",
+ "        raise ValueError(\"The two frame arrays must have the same number of frames.\")\n",
+ "\n",
+ "    # Combine frames side by side\n",
+ "    combined_frames = []\n",
+ "    for frame1, frame2 in zip(frames1, frames2):\n",
+ "        # Ensure frames have the same height\n",
+ "        if frame1.shape[0] != frame2.shape[0]:\n",
+ "            raise ValueError(\"Frames must have the same height to combine them side by side.\")\n",
+ "\n",
+ "        # Concatenate frames horizontally\n",
+ "        combined_frames.append(np.hstack((frame1, frame2)))\n",
+ "\n",
+ "    # Create the output directory if needed so the write does not fail on a fresh checkout\n",
+ "    output_dir = os.path.dirname(output_path)\n",
+ "    if output_dir:\n",
+ "        os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ "    # Save the combined frames as a GIF\n",
+ "    mediapy.write_video(output_path, combined_frames, fps=fps, codec=\"gif\")\n",
+ "    print(f\"GIF saved at {output_path}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Make Videos"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found 55 scenarios\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Diffusion: 100%|██████████| 50/50 [00:18<00:00, 2.67it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create the Waymax test dataset object (needed for its utils)\n",
+ "dataset = WaymaxTestDataset(\n",
+ "    data_dir=os.path.join(DATA_DIR, 'waymax'),\n",
+ "    anchor_path='integrations/models/vbd/data/cluster_64_center_dict.pkl',\n",
+ "    max_object=MAX_CONTROLLED_OBJECTS,\n",
+ ")\n",
+ "\n",
+ "# Folders used below. One scene json at a time is moved from gpudrive_dir into\n",
+ "# active_dir while the GPUDrive env is built (presumably so the env loads exactly\n",
+ "# that scene -- confirm), then moved back.\n",
+ "gpudrive_dir = os.path.join(DATA_DIR, 'gpudrive')\n",
+ "active_dir = os.path.join(DATA_DIR, 'active')\n",
+ "video_dir = 'integrations/models/videos'\n",
+ "os.makedirs(active_dir, exist_ok=True)\n",
+ "os.makedirs(video_dir, exist_ok=True)\n",
+ "\n",
+ "# NOTE(review): the [9:] slice skips the first nine scenarios; it looks like a\n",
+ "# debugging leftover -- confirm whether all scenarios should be processed.\n",
+ "for scenario_id in SCENARIO_IDS[9:]:\n",
+ "    # Init Waymax env\n",
+ "    waymax_config = EnvironmentConfig(\n",
+ "        controlled_object=ObjectType.VALID,\n",
+ "        allow_new_objects_after_warmup=False,\n",
+ "        init_steps=INIT_STEPS+1,\n",
+ "        max_num_objects=MAX_CONTROLLED_OBJECTS,\n",
+ "    )\n",
+ "    waymax_env = WaymaxEnvironment(\n",
+ "        dynamics_model=dynamics.StateDynamics(),\n",
+ "        config=waymax_config,\n",
+ "        log_replay=True,\n",
+ "    )\n",
+ "\n",
+ "    # NOTE: pickle can execute arbitrary code on load; only open scenario files you generated yourself\n",
+ "    with open(f'{DATA_DIR}/waymax/waymax_scenario_{scenario_id}.pkl', 'rb') as f:\n",
+ "        scenario = pickle.load(f)\n",
+ "\n",
+ "    # Generate video with VBD trajectories (Waymax)\n",
+ "    init_state = waymax_env.reset(scenario)\n",
+ "    current_state = init_state\n",
+ "    sample = dataset.process_scenario(\n",
+ "        init_state,\n",
+ "        current_index=init_state.timestep,\n",
+ "        use_log=False\n",
+ "    )\n",
+ "    is_controlled = sample['agents_interested'] > 0\n",
+ "    selected_agents = sample['agents_id'][is_controlled]\n",
+ "    state_logs = [current_state]\n",
+ "\n",
+ "    for i in range(current_state.remaining_timesteps):\n",
+ "        # Position inside the current plan; a new plan is sampled whenever it wraps to 0\n",
+ "        t = i % replan_freq\n",
+ "        if t == 0:\n",
+ "            sample = dataset.process_scenario(\n",
+ "                current_state,\n",
+ "                current_index=current_state.timestep,\n",
+ "                use_log=False,\n",
+ "                selected_agents=selected_agents, # override the agent selection by distance to the ego\n",
+ "            )\n",
+ "            batch = dataset.__collate_fn__([sample])\n",
+ "            pred = model.sample_denoiser(batch)\n",
+ "            traj_pred = pred['denoised_trajs'].cpu().numpy()[0]\n",
+ "\n",
+ "        # Get action\n",
+ "        action_sample = traj_pred[:, t, :]\n",
+ "        action = sample_to_action(\n",
+ "            action_sample,\n",
+ "            is_controlled,\n",
+ "            agents_id=selected_agents,\n",
+ "            max_num_objects=MAX_CONTROLLED_OBJECTS\n",
+ "        )\n",
+ "        # Step the simulator\n",
+ "        current_state = waymax_env.step_sim_agent(current_state, [action])\n",
+ "        state_logs.append(current_state)\n",
+ "\n",
+ "    waymax_frames = [plot_state(state) for state in state_logs]\n",
+ "\n",
+ "    # Init GPUDrive env\n",
+ "    env_config = EnvConfig(\n",
+ "        init_steps=INIT_STEPS, # Warmup period\n",
+ "        remove_non_vehicles=False, # Control vehicles, pedestrians, and cyclists\n",
+ "        return_vbd_data=True, # Use VBD\n",
+ "        dynamics_model='state', # Use state-based dynamics model\n",
+ "        dist_to_goal_threshold=1e-5, # Trick to make sure the agents don't disappear when they reach the goal\n",
+ "        collision_behavior='ignore', # Ignore collisions\n",
+ "    )\n",
+ "\n",
+ "    # Locate the scene file for this scenario; fail loudly instead of reusing stale paths\n",
+ "    file_target = f'_{scenario_id}.json'\n",
+ "    scene_file = next((name for name in os.listdir(gpudrive_dir) if name.endswith(file_target)), None)\n",
+ "    if scene_file is None:\n",
+ "        raise FileNotFoundError(f'No GPUDrive scene file ending in {file_target!r} found in {gpudrive_dir!r}')\n",
+ "    source_path = os.path.join(gpudrive_dir, scene_file)\n",
+ "    dest_path = os.path.join(active_dir, scene_file)\n",
+ "\n",
+ "    # Make env\n",
+ "    shutil.move(source_path, dest_path)\n",
+ "    try:\n",
+ "        gpudrive_env = GPUDriveTorchEnv(\n",
+ "            config=env_config,\n",
+ "            scene_config=SceneConfig(path=active_dir, num_scenes=1),\n",
+ "            render_config=RenderConfig(draw_obj_idx=True, render_init=True, resolution=(400, 400)),\n",
+ "            max_cont_agents=MAX_CONTROLLED_OBJECTS, # Maximum number of agents to control per scene\n",
+ "            device='cpu',\n",
+ "        )\n",
+ "    finally:\n",
+ "        # Always put the scene file back, even if building the env raised\n",
+ "        shutil.move(dest_path, source_path)\n",
+ "\n",
+ "    # Generate video with VBD trajectories (GPUDrive)\n",
+ "    init_state = gpudrive_env.reset()\n",
+ "    selected_agents = torch.nonzero(gpudrive_env.cont_agent_mask[0, :]).flatten().tolist()\n",
+ "    # Obtain all info for diffusion model (warmup)\n",
+ "    gpudrive_sample_batch = gpudrive_env.sample_batch\n",
+ "    # Obtain predicted trajectories\n",
+ "    pred = model.sample_denoiser(gpudrive_sample_batch)\n",
+ "    vbd_traj_pred = pred['denoised_trajs'].cpu().numpy()[0]\n",
+ "    is_controlled = gpudrive_sample_batch['agents_interested'] > 0\n",
+ "\n",
+ "    # Copy the predicted values into the 10-wide per-step tensor passed to step_dynamics\n",
+ "    num_steps = env_config.episode_len - INIT_STEPS\n",
+ "    pred_trajs = torch.zeros((MAX_CONTROLLED_OBJECTS, num_steps, 10))\n",
+ "    pred_trajs[:, :, :2] = torch.Tensor(vbd_traj_pred[:, :, 0:2]) # pos x, y\n",
+ "    pred_trajs[:, :, :2] -= gpudrive_env.sim.world_means_tensor().to_torch()[0, :2] # re-mean the predicted trajectory positions\n",
+ "    pred_trajs[:, :, 3] = torch.Tensor(vbd_traj_pred[:, :, 2]) # yaw\n",
+ "    pred_trajs[:, :, 4:6] = torch.Tensor(vbd_traj_pred[:, :, 3:5]) # vel x, y\n",
+ "    pred_trajs = pred_trajs.unsqueeze(0)\n",
+ "\n",
+ "    gpudrive_frames = []\n",
+ "\n",
+ "    # Step\n",
+ "    for t in range(num_steps):\n",
+ "        gpudrive_env.step_dynamics(pred_trajs[:, :, t, :])\n",
+ "        gpudrive_frames.append(np.rot90(gpudrive_env.render(), k=3))\n",
+ "\n",
+ "    save_side_by_side_gif(waymax_frames, gpudrive_frames, os.path.join(video_dir, f'{scenario_id}.gif'))\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "gpudrive",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}