Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
daphne-cornelisse committed Dec 28, 2024
2 parents 25ca980 + 1c77f0a commit 84c7453
Show file tree
Hide file tree
Showing 24 changed files with 394 additions and 188 deletions.
13 changes: 13 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# .env template

# Path for logs
LOG_FOLDER=

# Your HPC account code
NYU_HPC_ACCOUNT=

# NYU ID
USERNAME=

SINGULARITY_IMAGE=
OVERLAY_FILE=
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ celerybeat.pid
*.sage.py

# Environments
.env
.venv
venv/
ENV/
Expand Down Expand Up @@ -239,4 +238,11 @@ pyrightconfig.json

*~

# Environment variables
# To be manually created using .env.template
.env

# Logs
examples/experiments/scripts/logs/*

# End of https://www.toptal.com/developers/gitignore/api/python,c++
6 changes: 4 additions & 2 deletions baselines/ippo/config/ippo_ff_sb3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ collision_weight: 0.0
goal_achieved_weight: 1.0
off_road_weight: 0.0
remove_non_vehicles: false
polyline_reduction_threshold: 0.4
observation_radius: 60.0

resample_scenarios: false
resample_criterion: "global_step" # Options: "global_step"
Expand All @@ -19,8 +21,8 @@ resample_mode: "random" # Options: "random"

render: true
render_mode: "rgb_array"
render_freq: 50 # Render every k rollouts
render_n_worlds: 10 # Number of scenarios to render
render_freq: 100 # Render every k rollouts
render_n_worlds: 1 # Number of scenarios to render

track_time_to_solve: false

Expand Down
97 changes: 88 additions & 9 deletions baselines/ippo/ippo_pufferlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,60 @@
"""

import os
from typing import Optional
from typing_extensions import Annotated
import yaml
from datetime import datetime
import torch
import wandb
from box import Box
from integrations.rl.puffer import ppo
from integrations.rl.puffer.puffer_env import env_creator
from integrations.rl.puffer.utils import Policy, LiDARPolicy
from integrations.rl.puffer.utils import Policy

import pufferlib
import pufferlib.vector
import pufferlib.frameworks.cleanrl
from rich.console import Console

import typer
from typer import Typer

app = Typer()


def load_config(config_path):
    """Load a YAML experiment configuration and derive run-time settings.

    Reads the file at ``config_path`` into a ``Box``, stamps a unique
    experiment id (dataset size + timestamp), records the dataset size on
    the environment config, and selects the torch device.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        A ``pufferlib.namespace`` built from the (mutated) config dict.
    """
    # fmt: off
    with open(config_path, "r") as f:
        config = Box(yaml.safe_load(f))

    datetime_ = datetime.now().strftime("%m_%d_%H_%M_%S")

    # NOTE(review): the diff showed a duplicated, empty-bodied
    # `if config["train"]["resample_scenes"]:` here (a merge artifact that
    # is a SyntaxError); only the single, populated branch is kept.
    if config["train"]["resample_scenes"]:
        # Resampling run: tag the experiment id with the resample dataset size.
        dataset_size = config["train"]["resample_dataset_size"]
        config["train"]["exp_id"] = f'{config["train"]["exp_id"]}__R_{dataset_size}__{datetime_}'
    else:
        # Fixed-scene run: tag with the number of unique scenes instead.
        dataset_size = str(config["environment"]["k_unique_scenes"])
        config["train"]["exp_id"] = f'{config["train"]["exp_id"]}__S_{dataset_size}__{datetime_}'

    config["environment"]["dataset_size"] = dataset_size
    config["train"]["device"] = config["train"].get("device", "cpu")  # Default to 'cpu' if not set
    if torch.cuda.is_available():
        config["train"]["device"] = "cuda"  # Set to 'cuda' if available
    # fmt: on
    return pufferlib.namespace(**config)


def make_policy(env):
    """Build a ``Policy`` for *env* and wrap it in pufferlib's CleanRL adapter."""
    base_policy = Policy(env)
    return pufferlib.frameworks.cleanrl.Policy(base_policy)

def train(args):

def train(args, make_env):
"""Main training loop for the PPO agent."""
args.wandb = init_wandb(args, args.train.exp_id, id=args.train.exp_id)
args.train.__dict__.update(dict(args.wandb.config.train))
Expand Down Expand Up @@ -139,9 +149,74 @@ def sweep(args, project="PPO", sweep_name="my_sweep"):
wandb.agent(sweep_id, lambda: train(args), count=100)


if __name__ == "__main__":
@app.command()
def run(
config_path: Annotated[
str, typer.Argument(help="The path to the default configuration file")
] = "baselines/ippo/config/ippo_ff_puffer.yaml",
*,
# fmt: off
# Environment options
num_worlds: Annotated[Optional[int], typer.Option(help="Number of parallel envs")] = None,
k_unique_scenes: Annotated[Optional[int], typer.Option(help="The number of unique scenes to sample")] = None,
collision_weight: Annotated[Optional[float], typer.Option(help="The weight for collision penalty")] = None,
off_road_weight: Annotated[Optional[float], typer.Option(help="The weight for off-road penalty")] = None,
goal_achieved_weight: Annotated[Optional[float], typer.Option(help="The weight for goal-achieved reward")] = None,
dist_to_goal_threshold: Annotated[Optional[float], typer.Option(help="The distance threshold for goal-achieved")] = None,
sampling_seed: Annotated[Optional[int], typer.Option(help="The seed for sampling scenes")] = None,
obs_radius: Annotated[Optional[float], typer.Option(help="The radius for the observation")] = None,
# Train options
learning_rate: Annotated[Optional[float], typer.Option(help="The learning rate for training")] = None,
resample_scenes: Annotated[Optional[int], typer.Option(help="Whether to resample scenes during training; 0 or 1")] = None,
resample_interval: Annotated[Optional[int], typer.Option(help="The interval for resampling scenes")] = None,
total_timesteps: Annotated[Optional[int], typer.Option(help="The total number of training steps")] = None,
ent_coef: Annotated[Optional[float], typer.Option(help="Entropy coefficient")] = None,
# Wandb logging options
project: Annotated[Optional[str], typer.Option(help="WandB project name")] = None,
entity: Annotated[Optional[str], typer.Option(help="WandB entity name")] = None,
group: Annotated[Optional[str], typer.Option(help="WandB group name")] = None,
):
"""Run PPO training with the given configuration."""
# fmt: on

config = load_config("baselines/ippo/config/ippo_ff_puffer.yaml")
# Load default configs
config = load_config(config_path)

# Override configs with command-line arguments
env_config = {
"num_worlds": num_worlds,
"k_unique_scenes": k_unique_scenes,
"collision_weight": collision_weight,
"off_road_weight": off_road_weight,
"goal_achieved_weight": goal_achieved_weight,
"dist_to_goal_threshold": dist_to_goal_threshold,
"sampling_seed": sampling_seed,
"obs_radius": obs_radius,
}
config.environment.update(
{k: v for k, v in env_config.items() if v is not None}
)
train_config = {
"learning_rate": learning_rate,
"resample_scenes": None
if resample_scenes is None
else bool(resample_scenes),
"resample_interval": resample_interval,
"total_timesteps": total_timesteps,
"ent_coef": ent_coef,
}
config.train.update(
{k: v for k, v in train_config.items() if v is not None}
)

wandb_config = {
"project": project,
"entity": entity,
"group": group,
}
config.wandb.update(
{k: v for k, v in wandb_config.items() if v is not None}
)

make_env = env_creator(
data_dir=config.data_dir,
Expand All @@ -151,4 +226,8 @@ def sweep(args, project="PPO", sweep_name="my_sweep"):
)

if config.mode == "train":
train(config)
train(config, make_env)


if __name__ == "__main__":
app()
2 changes: 2 additions & 0 deletions baselines/ippo/ippo_sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def train(exp_config: Box, scene_config: SceneConfig):
off_road_weight=exp_config.off_road_weight,
episode_len=exp_config.episode_len,
remove_non_vehicles=exp_config.remove_non_vehicles,
polyline_reduction_threshold=exp_config.polyline_reduction_threshold,
obs_radius=exp_config.observation_radius,
)

# Select model
Expand Down
20 changes: 0 additions & 20 deletions baselines/scripts/bash_exec_paper_fig.sh

This file was deleted.

4 changes: 0 additions & 4 deletions baselines/scripts/bash_exec_solve_n_scenes.sh

This file was deleted.

14 changes: 0 additions & 14 deletions baselines/scripts/sbatch_ippo.sh

This file was deleted.

17 changes: 0 additions & 17 deletions baselines/scripts/sbatch_paper_fig.sh

This file was deleted.

14 changes: 0 additions & 14 deletions baselines/scripts/sbatch_solve_n_scenes.sh

This file was deleted.

This file was deleted.

Binary file removed data/processed/waymax/scenario_ab2a72c63f8fd589.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
7 changes: 6 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ dependencies:
- urllib3==2.2.1
- virtualenv==20.25.1
- zipp==3.18.1
- huggingface_hub==0.26.5
- huggingface_hub==0.26.5
- wandb==0.19.1
- python-box==7.3.0
- python-dotenv==1.0.1
- jax==0.4.0
- typer==0.9.0
Loading

0 comments on commit 84c7453

Please sign in to comment.