Skip to content

Commit

Permalink
Minor + increase obs norm constants
Browse files Browse the repository at this point in the history
  • Loading branch information
daphne-cornelisse committed Jan 14, 2025
1 parent e4e0f19 commit 19ce04f
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 22 deletions.
1 change: 0 additions & 1 deletion baselines/ippo/ippo_pufferlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,4 @@ def run(


if __name__ == "__main__":

app()
2 changes: 1 addition & 1 deletion integrations/rl/puffer/puffer_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def step(self, action):
# (6) Get the next observations. Note that we do this after resetting
# the worlds so that we always return a fresh observation
next_obs = self.env.get_obs(self.controlled_agent_mask)

self.observations = next_obs
self.rewards = reward_controlled
self.terminals = terminal
Expand Down
8 changes: 5 additions & 3 deletions pygpudrive/env/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
MAX_VEH_LEN = 30
MAX_VEH_WIDTH = 15
MAX_VEH_HEIGHT = 5
MIN_REL_GOAL_COORD = -1000
MAX_REL_GOAL_COORD = 1000
MIN_REL_GOAL_COORD = (
-10_000
) # Note: (1000 or 100 should be enough, may indicate a bug or these are just from dead controlled agents)
MAX_REL_GOAL_COORD = 10_000
MIN_REL_AGENT_POS = -1000
MAX_REL_AGENT_POS = 1000
MAX_ORIENTATION_RAD = 2 * np.pi
Expand All @@ -21,5 +23,5 @@

# Feature shape constants
EGO_FEAT_DIM = 6
PARTNER_FEAT_DIM = 6
PARTNER_FEAT_DIM = 6
ROAD_GRAPH_FEAT_DIM = 13
2 changes: 0 additions & 2 deletions pygpudrive/env/env_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,6 @@ def _get_ego_state(self, mask) -> torch.Tensor:
if self.config.norm_obs:
ego_state.normalize()

# TODO: I deleted this permute. Was it needed?
# .permute(1, 2, 0)
return [ego_state.data]
"""
[
Expand Down
36 changes: 21 additions & 15 deletions pygpudrive/visualize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
AGENT_COLOR_BY_STATE,
)

import pdb

OUT_OF_BOUNDS = 1000


Expand Down Expand Up @@ -209,12 +207,16 @@ def plot_simulator_state(

if time_step is not None and not eval_mode:
# Plot rollout statistics
num_controlled = controlled.sum().item()
num_controlled = controlled.sum().item()
num_off_road = is_offroad.sum().item()
num_collided = is_collided.sum().item()
off_road_rate = num_off_road / num_controlled if num_controlled > 0 else 0
collision_rate = num_collided / num_controlled if num_controlled > 0 else 0

off_road_rate = (
num_off_road / num_controlled if num_controlled > 0 else 0
)
collision_rate = (
num_collided / num_controlled if num_controlled > 0 else 0
)

ax.text(
0.5, # Horizontal center
0.95, # Vertical location near the top
Expand All @@ -229,19 +231,21 @@ def plot_simulator_state(
bbox=dict(facecolor="white", edgecolor="none", alpha=0.9),
)
else:
if eval_mode and results_df is not None:

num_controlled = results_df.iloc[env_idx].controlled_agents_in_scene
if eval_mode and results_df is not None:

num_controlled = results_df.iloc[
env_idx
].controlled_agents_in_scene
off_road_rate = results_df.iloc[env_idx].off_road * 100
collision_rate = results_df.iloc[env_idx].collided * 100
goal_rate = results_df.iloc[env_idx].goal_achieved * 100
other = results_df.iloc[env_idx].not_goal_nor_crashed * 100
other = results_df.iloc[env_idx].not_goal_nor_crashed * 100

ax.text(
0.5, # Horizontal center
0.95, # Vertical location near the top
f"$N_c$ = {num_controlled}; "
f"OR: {off_road_rate:.1f}; "
f"t = {time_step} | $N_c$ = {num_controlled}; "
f"OR: {off_road_rate:.1f}; "
f"CR: {collision_rate:.1f}; "
f"GR: {goal_rate:.1f}; "
f"Other: {other:.1f}",
Expand All @@ -250,7 +254,9 @@ def plot_simulator_state(
transform=ax.transAxes,
fontsize=20 * marker_scale,
color="black",
bbox=dict(facecolor="white", edgecolor="none", alpha=0.9),
bbox=dict(
facecolor="white", edgecolor="none", alpha=0.9
),
)

# Determine center point for zooming
Expand All @@ -268,7 +274,7 @@ def plot_simulator_state(
# Set zoom window around the center
ax.set_xlim(center_x - zoom_radius, center_x + zoom_radius)
ax.set_ylim(center_y - zoom_radius, center_y + zoom_radius)

# Remove ticks
ax.set_xticks([])
ax.set_yticks([])
Expand Down

0 comments on commit 19ce04f

Please sign in to comment.