Skip to content

Commit

Permalink
rebase over main
Browse files Browse the repository at this point in the history
  • Loading branch information
daphne-cornelisse committed Dec 19, 2024
1 parent 49309cd commit f624eb0
Showing 1 changed file with 362 additions and 0 deletions.
362 changes: 362 additions & 0 deletions integrations/models/notebooks/03_vbd_output_comparison.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing VBD Outputs: Waymax vs GPUDrive"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-18 17:53:30.179672: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1731970410.198982 123447 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1731970410.204524 123447 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-11-18 17:53:30.224247: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"%%capture\n",
"import waymax\n",
"import numpy as np\n",
"import math\n",
"import mediapy\n",
"from tqdm import tqdm\n",
"import dataclasses\n",
"import os\n",
"import shutil\n",
"from pathlib import Path\n",
"import pickle\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import torch\n",
"from waymax import config as _config\n",
"from waymax import dataloader, datatypes, visualization, dynamics\n",
"from waymax.datatypes.simulator_state import SimulatorState\n",
"from waymax.config import EnvironmentConfig, ObjectType\n",
"\n",
"# Set working directory to the base directory 'gpudrive'\n",
"working_dir = Path.cwd()\n",
"while working_dir.name != 'gpudrive':\n",
" working_dir = working_dir.parent\n",
" if working_dir == Path.home():\n",
" raise FileNotFoundError(\"Base directory 'gpudrive' not found\")\n",
"os.chdir(working_dir)\n",
"\n",
"# VBD dependencies\n",
"from integrations.models.vbd.sim_agent.waymax_env import WaymaxEnvironment\n",
"from integrations.models.vbd.data.dataset import WaymaxTestDataset\n",
"from integrations.models.vbd.waymax_visualization.plotting import plot_state\n",
"from integrations.models.vbd.sim_agent.sim_actor import VBDTest, sample_to_action\n",
"from integrations.models.vbd.model.utils import set_seed\n",
"\n",
"# GPUDrive dependencies\n",
"from pygpudrive.env.config import EnvConfig, RenderConfig, SceneConfig, SelectionDiscipline\n",
"from pygpudrive.env.env_torch import GPUDriveTorchEnv\n",
"\n",
"# Plotting\n",
"sns.set(\"notebook\")\n",
"sns.set_style(\"ticks\", rc={\"figure.facecolor\": \"none\", \"axes.facecolor\": \"none\"})\n",
"#%config InlineBackend.figure_format = 'svg'\n",
"\n",
"# Ignore all warnings\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configurations"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = 'data/processed/debug' # Base data path\n",
"CKPT_DIR = 'data/checkpoints' # Base checkpoint path\n",
"CKPT_PATH = 'integrations/models/vbd/weights/epoch=18.ckpt'\n",
"\n",
"SCENARIO_IDS = []\n",
"for filename in os.listdir(os.path.join(DATA_DIR, 'gpudrive')):\n",
" if filename.endswith('.json') and '_' in filename:\n",
" # Extract the part after the last underscore and before .json\n",
" scenario_id = filename.rsplit('_', 1)[-1].replace('.json', '')\n",
" SCENARIO_IDS.append(scenario_id)\n",
"\n",
"FPS = 20\n",
"INIT_STEPS = 11 # Warmup period\n",
"MAX_CONTROLLED_OBJECTS = 32"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load pre-trained VBD model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load model\n",
"model = VBDTest.load_from_checkpoint(CKPT_PATH, torch.device('cpu'))\n",
"_ = model.cpu()\n",
"_ = model.eval();\n",
"\n",
"# Model settings\n",
"replan_freq=80 # Roll out every X steps 80 means openloop\n",
"model.early_stop=0 # Stop Diffusion Early From 100 to X\n",
"model.skip = 1 # Skip Alpha \n",
"model.reward_func = None\n",
"\n",
"# Ensure reproducability\n",
"#set_seed(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"#Function to save frames side by side.\n",
"def save_side_by_side_gif(frames1, frames2, output_path, fps=FPS):\n",
" \"\"\"\n",
" Saves two arrays of frames as a single side-by-side GIF.\n",
"\n",
" Parameters:\n",
" frames1 (list of np.ndarray): First array of frames (images).\n",
" frames2 (list of np.ndarray): Second array of frames (images).\n",
" output_path (str): Path to save the output GIF.\n",
" fps (int): Frames per second for the GIF.\n",
" \"\"\"\n",
" # Ensure both arrays have the same number of frames\n",
" if len(frames1) != len(frames2):\n",
" raise ValueError(\"The two frame arrays must have the same number of frames.\")\n",
" \n",
" # Combine frames side by side\n",
" combined_frames = []\n",
" for frame1, frame2 in zip(frames1, frames2):\n",
" # Ensure frames have the same height\n",
" if frame1.shape[0] != frame2.shape[0]:\n",
" raise ValueError(\"Frames must have the same height to combine them side by side.\")\n",
" \n",
" # Concatenate frames horizontally\n",
" combined_frame = np.hstack((frame1, frame2))\n",
" combined_frames.append(combined_frame)\n",
" \n",
" # Save the combined frames as a GIF\n",
" mediapy.write_video(output_path, combined_frames, fps=fps, codec=\"gif\")\n",
" print(f\"GIF saved at {output_path}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Make Videos"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 55 scenarios\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Diffusion: 100%|██████████| 50/50 [00:18<00:00, 2.67it/s]\n"
]
}
],
"source": [
"# Create waymax test \"dataset\" obj (need for utils)\n",
"dataset = WaymaxTestDataset(\n",
" data_dir = 'data/processed/debug/waymax', \n",
" anchor_path = 'integrations/models/vbd/data/cluster_64_center_dict.pkl',\n",
" max_object=MAX_CONTROLLED_OBJECTS,\n",
")\n",
"\n",
"for id in SCENARIO_IDS[9:]:\n",
" #Init Waymax env\n",
" env_config = EnvironmentConfig(\n",
" controlled_object=ObjectType.VALID,\n",
" allow_new_objects_after_warmup=False,\n",
" init_steps=INIT_STEPS+1,\n",
" max_num_objects=MAX_CONTROLLED_OBJECTS,\n",
" )\n",
" waymax_env = WaymaxEnvironment(\n",
" dynamics_model=dynamics.StateDynamics(),\n",
" config=env_config,\n",
" log_replay = True,\n",
" )\n",
"\n",
" scenario_path = os.path.join(DATA_DIR, scenario_id + '.pkl')\n",
" with open(f'{DATA_DIR}/waymax/waymax_scenario_{id}.pkl', 'rb') as f:\n",
" scenario = pickle.load(f)\n",
"\n",
" #Generate video with vbd trajectories\n",
" init_state = waymax_env.reset(scenario)\n",
" current_state = init_state\n",
" sample = dataset.process_scenario(\n",
" init_state,\n",
" current_index=init_state.timestep,\n",
" use_log=False\n",
" )\n",
" is_controlled = sample['agents_interested'] > 0\n",
" selected_agents = sample['agents_id'][is_controlled]\n",
" state_logs = [current_state]\n",
"\n",
" for i in range(current_state.remaining_timesteps):\n",
" t = i % replan_freq\n",
" if t == 0:\n",
" sample = dataset.process_scenario(\n",
" current_state, \n",
" current_index = current_state.timestep,\n",
" use_log=False,\n",
" selected_agents=selected_agents, # override the agent selection by distance to the ego\n",
" )\n",
" batch = dataset.__collate_fn__([sample])\n",
" pred = model.sample_denoiser(batch)\n",
" traj_pred = pred['denoised_trajs'].cpu().numpy()[0]\n",
"\n",
" # Get action\n",
" action_sample = traj_pred[:, t, :]\n",
" action = sample_to_action(\n",
" action_sample, \n",
" is_controlled, \n",
" agents_id=selected_agents, \n",
" max_num_objects=MAX_CONTROLLED_OBJECTS\n",
" )\n",
" # Step the simulator\n",
" current_state = waymax_env.step_sim_agent(current_state, [action])\n",
" state_logs.append(current_state)\n",
"\n",
" waymax_frames = [plot_state(state) for state in state_logs]\n",
"\n",
" #Init GPUDrive env\n",
" env_config = EnvConfig(\n",
" init_steps=INIT_STEPS, # Warmup period\n",
" remove_non_vehicles=False, # Control vehicles, pedestrians, and cyclists\n",
" return_vbd_data=True, # Use VBD\n",
" dynamics_model=\"state\", # Use state-based dynamics model\n",
" dist_to_goal_threshold=1e-5, # Trick to make sure the agents don't disappear when they reach the goal\n",
" collision_behavior=\"ignore\", # Ignore collisions|\n",
" )\n",
"\n",
" source = os.path.join(DATA_DIR, 'gpudrive')\n",
" file_target = f\"_{id}.json\"\n",
" for filename in os.listdir(source):\n",
" if filename.endswith(file_target):\n",
" source_path = os.path.join(source, filename)\n",
" dest_path = os.path.join(DATA_DIR, 'active', filename)\n",
" shutil.move(source_path, dest_path)\n",
" break\n",
" \n",
" # Make env\n",
" gpudrive_env = GPUDriveTorchEnv(\n",
" config=env_config,\n",
" scene_config=SceneConfig(path=\"data/processed/debug/active\", num_scenes=1),\n",
" render_config=RenderConfig(draw_obj_idx=True, render_init=True, resolution=(400, 400)),\n",
" max_cont_agents=MAX_CONTROLLED_OBJECTS, # Maximum number of agents to control per scene\n",
" device=\"cpu\",\n",
" )\n",
" shutil.move(dest_path, source_path)\n",
" \n",
" #Generate video with VBD trajectories \n",
" init_state = gpudrive_env.reset()\n",
" selected_agents = torch.nonzero(gpudrive_env.cont_agent_mask[0, :]).flatten().tolist()\n",
" # Obtain all info for diffusion model (warmup)\n",
" gpudrive_sample_batch = gpudrive_env.sample_batch\n",
" # Obtain predicted trajectories\n",
" pred = model.sample_denoiser(gpudrive_sample_batch)#, x_t=x_t)\n",
" vbd_traj_pred = pred['denoised_trajs'].cpu().numpy()[0]\n",
" is_controlled = gpudrive_sample_batch['agents_interested'] > 0\n",
"\n",
" pred_trajs = torch.zeros((MAX_CONTROLLED_OBJECTS, env_config.episode_len-INIT_STEPS, 10))\n",
" pred_trajs[:, :, :2] = torch.Tensor(vbd_traj_pred[:, :, 0:2]) # pos x, y\n",
" pred_trajs[:, :, :2] -= gpudrive_env.sim.world_means_tensor().to_torch()[0, :2] #re-mean the predicted trajectory positions\n",
" pred_trajs[:, :, 3] = torch.Tensor(vbd_traj_pred[:, :, 2]) # yaw \n",
" pred_trajs[:, :, 4:6] = torch.Tensor(vbd_traj_pred[:, :, 3:5]) # vel x, y\n",
" pred_trajs = pred_trajs.unsqueeze(0)\n",
"\n",
" gpudrive_frames = []\n",
"\n",
" # Step\n",
" for t in range(env_config.episode_len-INIT_STEPS):\n",
" gpudrive_env.step_dynamics(pred_trajs[:, :, t, :])\n",
" gpudrive_frames.append(np.rot90(gpudrive_env.render(), k=3))\n",
"\n",
" save_side_by_side_gif(waymax_frames, gpudrive_frames, f'integrations/models/videos/{id}.gif')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "gpudrive",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit f624eb0

Please sign in to comment.