Merge pull request #46 from Farama-Foundation/chore/improve-eval-script
Seed better, fix evaluation
ffelten authored Mar 11, 2024
2 parents c5ffc33 + 2b959d0 commit 9a287c6
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions momaland/learning/cooperative_momappo/continuous_momappo.py
@@ -541,8 +541,8 @@ def _env_step(runner_state):
eval_env = normalize_obs_v0(eval_env, env_min=-1.0, env_max=1.0)
eval_env = agent_indicator_v0(eval_env)

-env.reset()
-eval_env.reset()
+env.reset(seed=args.seed)
+eval_env.reset(seed=args.seed)
current_timestep = 0
reward_dim = env.unwrapped.reward_space(env.possible_agents[0]).shape[0]

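The change above seeds both the training and the evaluation environment at reset time, so rollouts are reproducible across runs. A minimal sketch of the pattern, assuming PettingZoo-style parallel environments whose reset() accepts a seed keyword (as the diff itself relies on); reset_with_seed is a hypothetical helper, not code from the repository:

def reset_with_seed(env, eval_env, seed: int):
    """Reset the training and evaluation environments from the same base seed."""
    env.reset(seed=seed)       # seeds the training environment's RNG
    eval_env.reset(seed=seed)  # evaluation starts from the same seed for comparability
    return env, eval_env
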
2 changes: 2 additions & 0 deletions momaland/learning/cooperative_momappo/exec_policy.py
@@ -48,6 +48,7 @@ def main():
single_obs_space = env.observation_space(env.possible_agents[0])
single_action_space = env.action_space(env.possible_agents[0])
dummy_local_obs_and_id = jnp.zeros(single_obs_space.shape)
+env.reset(seed=args.seed)
key, actor_key = jax.random.split(key, 2)
if args.continuous:
    from momaland.learning.cooperative_momappo.continuous_momappo import Actor
@@ -76,6 +77,7 @@ def main():

# Load the model
actor_state = load_actor_state(args.model_dir, actor_state)
# actor_module.apply = jax.jit(actor_module.apply)
# Perform the replay
vec_ret, disc_vec_return = eval_mo(actor_module=actor_module, actor_state=actor_state, env=env, num_obj=reward_dim)
print("Done!!")
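Taken together, the exec_policy.py changes make the replay script deterministic: the environment is reset with the run's seed before the trained actor is restored and replayed. A minimal sketch of that flow, assuming the names from the snippet above (load_actor_state, eval_mo, actor_module, reward_dim); this is illustrative, not the repository's exact code:

def replay(env, actor_module, actor_state, model_dir, seed, reward_dim):
    """Seed the environment, restore the trained actor, and evaluate it once."""
    env.reset(seed=seed)  # seed before replay so the evaluation episode is reproducible
    actor_state = load_actor_state(model_dir, actor_state)  # restore trained parameters
    vec_ret, disc_vec_return = eval_mo(
        actor_module=actor_module, actor_state=actor_state, env=env, num_obj=reward_dim
    )
    return vec_ret, disc_vec_return
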
6 changes: 4 additions & 2 deletions momaland/learning/cooperative_momappo/utils.py
@@ -75,8 +75,9 @@ def eval_mo(actor_module, actor_state, env, num_obj, gamma_decay=0.99) -> Tuple[
key, subkey = jax.random.split(key)
action_keys = jax.random.split(subkey, len(env.possible_agents))

-vec_return += np.array(list(rew.values())).sum(axis=0)
-disc_vec_return += gamma * vec_return
+rewards = np.array(list(rew.values())).sum(axis=0)
+vec_return += rewards
+disc_vec_return += gamma * rewards
gamma *= gamma_decay

return (
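The fix above changes what gets discounted: previously the running (undiscounted) vector return was added to the discounted return each step, compounding earlier rewards, whereas now only the current step's team reward is discounted, giving the usual sum over t of gamma^t * r_t. A minimal, self-contained sketch of the corrected accumulation, assuming each per-step reward is a dict mapping agent ids to per-objective reward vectors and that gamma starts at 1.0 (accumulate_returns is illustrative, not the repository's function):

import numpy as np

def accumulate_returns(rewards_per_step, num_obj, gamma_decay=0.99):
    """Accumulate undiscounted and discounted vector returns over one episode."""
    vec_return = np.zeros(num_obj)
    disc_vec_return = np.zeros(num_obj)
    gamma = 1.0
    for rew in rewards_per_step:
        team_reward = np.array(list(rew.values())).sum(axis=0)  # sum rewards over agents
        vec_return += team_reward               # undiscounted vector return
        disc_vec_return += gamma * team_reward  # discount the step reward, not the running total
        gamma *= gamma_decay
    return vec_return, disc_vec_return
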
@@ -100,6 +101,7 @@ def policy_evaluation_mo(
Returns:
    (float, float, np.ndarray, np.ndarray): Avg scalarized return, Avg scalarized discounted return, Avg vectorized return, Avg vectorized discounted return
"""
+env.reset(seed=42)
evals = [
    eval_mo(actor_module=actor_module, actor_state=actor_state, env=env, num_obj=num_obj, gamma_decay=gamma)
    for _ in range(rep)
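With this change, policy_evaluation_mo reseeds the environment once before running rep independent evaluation episodes and averaging their returns. A minimal sketch of that loop, assuming eval_mo returns a (vec_return, disc_vec_return) pair as in exec_policy.py above; the fixed seed 42 mirrors the diff, and policy_evaluation_sketch is illustrative only:

import numpy as np

def policy_evaluation_sketch(actor_module, actor_state, env, num_obj, rep=10, gamma=0.99):
    """Average vectorized (discounted) returns over `rep` evaluation episodes."""
    env.reset(seed=42)  # fixed seed so repeated evaluations are comparable across runs
    evals = [
        eval_mo(actor_module=actor_module, actor_state=actor_state, env=env, num_obj=num_obj, gamma_decay=gamma)
        for _ in range(rep)
    ]
    avg_vec_return = np.mean([e[0] for e in evals], axis=0)
    avg_disc_vec_return = np.mean([e[1] for e in evals], axis=0)
    return avg_vec_return, avg_disc_vec_return
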
