Fix the puffer #327

Draft · wants to merge 12 commits into base: dc/selfplay_safe
3 changes: 3 additions & 0 deletions .gitignore
@@ -242,6 +242,9 @@ pyrightconfig.json
# To be manually created using .env.template
.env

# Sbatch scripts
*.sh

# Logs
examples/experiments/scripts/logs/*

180 changes: 112 additions & 68 deletions baselines/ippo/ippo_pufferlib.py
@@ -15,6 +15,8 @@
import numpy as np
import wandb
from box import Box
import time
import random

from integrations.rl.puffer import ppo
from integrations.rl.puffer.puffer_env import env_creator
@@ -32,6 +34,73 @@

app = Typer()

def log_normal(mean, scale, clip):
'''Samples points that are normally distributed on a log10 scale.
mean: center of the distribution (in linear space)
scale: standard deviation, in base-10 orders of magnitude
clip: maximum deviation from the mean, in orders of magnitude

Example: mean=0.001, scale=1, clip=2 will produce samples from
0.1 to 0.00001, with most of them between 0.01 and 0.0001
'''
return 10**np.clip(
np.random.normal(
np.log10(mean),
scale,
),
a_min = np.log10(mean) - clip,
a_max = np.log10(mean) + clip,
)

def logit_normal(mean, scale, clip):
'''log_normal mirrored around 1, for logit-like parameters such as gamma and gae_lambda.
e.g. mean=0.99 is mapped to 0.01, sampled on a log10 scale, then mapped back, so samples stay below 1 and concentrate near the mean.'''
return 1 - log_normal(1 - mean, scale, clip)

def uniform_pow2(min, max):
'''Uniform distribution over powers of 2 between min and max inclusive'''
min_base = np.log2(min)
max_base = np.log2(max)
return 2**np.random.randint(min_base, max_base+1)
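# e.g. uniform_pow2(512, 4096) returns one of 512, 1024, 2048, or 4096, each with equal probability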

def uniform(min, max):
'''Uniform distribution between min and max inclusive'''
return np.random.uniform(min, max)

def int_uniform(min, max):
'''Uniform distribution between min and max inclusive'''
return np.random.randint(min, max+1)

def sample_hyperparameters(sweep_config):
'''Recursively draws one random sample per entry of a wandb-style sweep config.'''
samples = {}
for name, param in sweep_config.items():
if name in ('method', 'name', 'metric'):
continue

assert isinstance(param, dict)
if any(isinstance(param[k], dict) for k in param):
samples[name] = sample_hyperparameters(param)
elif 'values' in param:
assert 'distribution' not in param
samples[name] = random.choice(param['values'])
elif 'distribution' in param:
if param['distribution'] == 'uniform':
samples[name] = uniform(param['min'], param['max'])
elif param['distribution'] == 'int_uniform':
samples[name] = int_uniform(param['min'], param['max'])
elif param['distribution'] == 'uniform_pow2':
samples[name] = uniform_pow2(param['min'], param['max'])
elif param['distribution'] == 'log_normal':
samples[name] = log_normal(
param['mean'], param['scale'], param['clip'])
elif param['distribution'] == 'logit_normal':
samples[name] = logit_normal(
param['mean'], param['scale'], param['clip'])
else:
raise ValueError(f'Invalid distribution: {param["distribution"]}')
else:
raise ValueError('Must specify either values or distribution')

return samples
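
(Illustrative usage sketch, not part of the diff: the sweep layout and key names below are assumptions that mirror the train/environment sections consumed by the sweep loop further down.)

example_sweep = {
    'method': 'random',
    'train': {
        'learning_rate': {'distribution': 'log_normal', 'mean': 3e-4, 'scale': 0.5, 'clip': 2},
        'batch_size': {'distribution': 'uniform_pow2', 'min': 512, 'max': 4096},
        'gamma': {'distribution': 'logit_normal', 'mean': 0.98, 'scale': 0.5, 'clip': 2},
    },
    'environment': {
        'k_unique_scenes': {'values': [100, 500, 1000]},
    },
}
hypers = sample_hyperparameters(example_sweep)
# hypers is a nested dict, e.g. {'train': {'learning_rate': ..., 'batch_size': ..., 'gamma': ...}, 'environment': {'k_unique_scenes': ...}}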

def get_model_parameters(policy):
"""Helper function to count the number of trainable parameters."""
Expand All @@ -56,42 +125,20 @@ def make_policy(env, config):
dropout=config.train.network.dropout,
).to(config.train.device)


def train(args, make_env):
def train(args, vecenv):
"""Main training loop for the PPO agent."""

backend_mapping = {
# Note: Only native backend is currently supported with GPUDrive
"native": pufferlib.vector.Native,
"serial": pufferlib.vector.Serial,
"multiprocessing": pufferlib.vector.Multiprocessing,
"ray": pufferlib.vector.Ray,
}

backend = backend_mapping.get(args.vec.backend)
if not backend:
raise ValueError("Invalid --vec.backend.")

vecenv = pufferlib.vector.make(
make_env,
num_envs=1, # GPUDrive is already batched
num_workers=args.vec.num_workers,
batch_size=args.vec.env_batch_size,
zero_copy=args.vec.zero_copy,
backend=backend,
)

policy = make_policy(env=vecenv.driver_env, config=args).to(
args.train.device
)

args.train.network.num_parameters = get_model_parameters(policy)
args.train.env = args.environment.name

args.wandb = init_wandb(args, args.train.exp_id, id=args.train.exp_id)
args.train.__dict__.update(dict(args.wandb.config.train))
wandb_run = init_wandb(args, args.train.exp_id, id=args.train.exp_id)
args.train.update(dict(wandb_run.config.train))

data = ppo.create(args.train, vecenv, policy, wandb=args.wandb)
data = ppo.create(args.train, vecenv, policy, wandb=wandb_run)
while data.global_step < args.train.total_timesteps:
try:
ppo.evaluate(data) # Rollout
@@ -107,8 +154,24 @@ def train(args, make_env):
ppo.evaluate(data)
ppo.close(data)

def set_experiment_metadata(config):
datetime_ = datetime.now().strftime("%m_%d_%H_%M_%S_%f")[:-3]
if config["train"]["resample_scenes"]:
dataset_size = config["train"]["resample_dataset_size"]
config["train"]["exp_id"] = f'PPO_R_{dataset_size}__{datetime_}'
else:
dataset_size = str(config["environment"]["k_unique_scenes"])
config["train"]["exp_id"] = f'PPO_S_{dataset_size}__{datetime_}'

config["environment"]["dataset_size"] = dataset_size


def init_wandb(args, name, id=None, resume=True):
def init_wandb(args, name, id=None, resume=True, tag=None):
wandb.init(
id=id or wandb.util.generate_id(),
project=args.wandb.project,
@@ -128,29 +191,6 @@ def init_wandb(args, name, id=None, resume=True):

return wandb


def sweep(args, project="PPO", sweep_name="my_sweep"):
"""Initialize a WandB sweep with hyperparameters."""
sweep_id = wandb.sweep(
sweep=dict(
method="random",
name=sweep_name,
metric={"goal": "maximize", "name": "environment/episode_return"},
parameters={
"learning_rate": {
"distribution": "log_uniform_values",
"min": 1e-4,
"max": 1e-1,
},
"batch_size": {"values": [512, 1024, 2048]},
"minibatch_size": {"values": [128, 256, 512]},
},
),
project=project,
)
wandb.agent(sweep_id, lambda: train(args), count=100)


@app.command()
def run(
config_path: Annotated[
@@ -187,6 +227,7 @@ def run(
project: Annotated[Optional[str], typer.Option(help="WandB project name")] = None,
entity: Annotated[Optional[str], typer.Option(help="WandB entity name")] = None,
group: Annotated[Optional[str], typer.Option(help="WandB group name")] = None,
max_runs: Annotated[Optional[int], typer.Option(help="Maximum number of sweep runs")] = 100,
render: Annotated[Optional[int], typer.Option(help="Whether to render the environment; 0 or 1")] = None,
):
"""Run PPO training with the given configuration."""
@@ -244,21 +285,6 @@ def run(
{k: v for k, v in wandb_config.items() if v is not None}
)

datetime_ = datetime.now().strftime("%m_%d_%H_%M_%S_%f")[:-3]

if config["train"]["resample_scenes"]:
if config["train"]["resample_scenes"]:
dataset_size = config["train"]["resample_dataset_size"]
config["train"][
"exp_id"
] = f'{config["train"]["exp_id"]}__R_{dataset_size}__{datetime_}'
else:
dataset_size = str(config["environment"]["k_unique_scenes"])
config["train"][
"exp_id"
] = f'{config["train"]["exp_id"]}__S_{dataset_size}__{datetime_}'

config["environment"]["dataset_size"] = dataset_size
config["train"]["device"] = config["train"].get(
"device", "cpu"
) # Default to 'cpu' if not set
@@ -283,11 +309,29 @@ def run(
train_config=config.train,
device=config.train.device,
)
vecenv = pufferlib.vector.make(
make_env,
num_envs=1, # GPUDrive is already batched
num_workers=config.vec.num_workers,
batch_size=config.vec.env_batch_size,
zero_copy=config.vec.zero_copy,
backend=pufferlib.vector.Native,
)

if config.mode == "train":
train(config, make_env)

set_experiment_metadata(config)
train(config, vecenv)
elif config.mode == "sweep":
for i in range(max_runs):
np.random.seed(int(time.time()))
random.seed(int(time.time()))
set_experiment_metadata(config)
hypers = sample_hyperparameters(config.sweep)
config.train.update(hypers['train'])
config.environment.update(hypers['environment'])
train(config, vecenv)

if __name__ == "__main__":

import cProfile
#cProfile.run('app()', 'profiled')
app()
8 changes: 4 additions & 4 deletions examples/experiments/eval/config/eval_config.yaml
@@ -2,11 +2,11 @@ res_path: examples/experiments/eval/dataframes/0120 # Store dataframes here
test_dataset_size: 10_000 # Number of test scenarios to evaluate on

# Environment settings
train_dir: "/scratch/kj2676/gpudrive/data/processed/training"
test_dir: "/scratch/kj2676/gpudrive/data/processed/validation"
train_dir: "data/processed/training"
test_dir: "data/processed/validation"

num_worlds: 200 # Number of parallel environments for evaluation
max_controlled_agents: 128 # Maximum number of agents controlled by the model.
max_controlled_agents: 64 # Maximum number of agents controlled by the model.
ego_state: true
road_map_obs: true
partner_obs: true
@@ -24,4 +24,4 @@ polyline_reduction_threshold: 0.1 # Rate at which to sample points from the poly
sampling_seed: 42 # If given, the set of scenes to sample from will be deterministic, if None, the set of scenes will be random
obs_radius: 50.0 # Visibility radius of the agents

device: "cuda" # Options: "cpu", "cuda"
device: "cuda" # Options: "cpu", "cuda"
13 changes: 4 additions & 9 deletions examples/experiments/eval/config/model_config.yaml
@@ -1,15 +1,10 @@
models_path: examples/experiments/eval/models/0120
models_path: wandb/run-20250121_225758-PPO_R_1000__01_21_22_57_53_461/files/runs/PPO_R_1000__01_21_22_57_53_461

models:
- name: random_baseline
train_dataset_size: null
- name: model_PPO_R_1000__01_21_22_57_53_461_001520
train_dataset_size: 1000
wandb: null
trained_on: null

- name: model_PPO__R_1000__01_19_11_15_25_854_002500
train_dataset_size: 1000
wandb: https://wandb.ai/emerge_/self_play_rl_safe/runs/PPO__R_1000__01_19_11_15_25_854?nw=nwuserdaphnecor
trained_on: cluster

# - name: model_PPO__R_1000__01_10_17_06_33_697_003500
# train_dataset_size: 1000
@@ -39,4 +34,4 @@ models:
# # - name: model_PPO__R_100000__01_06_11_29_36_390_012000
# # train_dataset_size: 100_000
# # wandb: https://wandb.ai/emerge_/paper_1_self_play/runs/PPO__R_100000__01_06_11_29_36_390?nw=nwuserdaphnecor
# # trained_on: cluster
# # trained_on: cluster
22 changes: 13 additions & 9 deletions examples/experiments/eval/evaluate.py
@@ -101,7 +101,7 @@ def rollout(
episode_len = env.config.episode_len

# Reset episode
next_obs = env.reset()
next_obs = env.reset(env.cont_agent_mask)

# Storage
goal_achieved = torch.zeros((num_worlds, max_agent_count), device=device)
@@ -123,7 +123,6 @@
)

controlled_agent_mask = env.cont_agent_mask.clone() & ~bugged_agent_mask

live_agent_mask = controlled_agent_mask.clone()

for time_step in range(episode_len):
@@ -132,14 +131,14 @@
# Get actions for active agents
if live_agent_mask.any():
action, _, _, _ = policy(
next_obs[live_agent_mask], deterministic=deterministic
next_obs, deterministic=deterministic
)

# Insert actions into a template
action_template = torch.zeros(
(num_worlds, max_agent_count), dtype=torch.int64, device=device
)
action_template[live_agent_mask] = action.to(device)
action_template[env.cont_agent_mask] = action.to(device)

# Step the environment
env.step_dynamics(action_template)
@@ -166,7 +165,7 @@
)

# Update observations, dones, and infos
next_obs = env.get_obs()
next_obs = env.get_obs(env.cont_agent_mask)
dones = env.get_dones().bool()
infos = env.get_infos()

@@ -191,8 +190,8 @@

for world in done_worlds:
if world in active_worlds:
active_worlds.remove(world)
logging.debug(f"World {world} done at time step {time_step}")
active_worlds.remove(world)

if not active_worlds: # Exit early if all worlds are done
break
@@ -374,9 +373,10 @@ def make_env(config, train_loader):
train_loader = SceneDataLoader(
root=eval_config.train_dir,
batch_size=eval_config.num_worlds,
dataset_size=model.train_dataset_size
if model.name != "random_baseline"
else 1000,
dataset_size=1000,  # Below didn't work; was None.
#model.train_dataset_size
#if model.name != "random_baseline"
#else 1000,
sample_with_replacement=False,
shuffle=False,
)
@@ -405,6 +405,10 @@ def make_env(config, train_loader):
render_sim_state=False,
)

result = df_res_train.groupby('dataset')[['goal_achieved', 'collided', 'off_road', 'not_goal_nor_crashed']].agg(['mean', 'std'])
print('Result: ', result)
breakpoint()
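# Note (sketch): result has a two-level column index; e.g. result[('goal_achieved', 'mean')] gives the mean goal-achieved rate per dataset.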

df_res_test = evaluate_policy(
env=env,
policy=policy,