From 9879694fb19cdecf9adf46cd977294616ac6ff1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roxana=20R=C4=83dulescu?= <8026679+rradules@users.noreply.github.com> Date: Tue, 7 Nov 2023 17:34:39 +0100 Subject: [PATCH] correct factory functions for BPD --- .../envs/beach_domain/beach_domain.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/momadm_benchmarks/envs/beach_domain/beach_domain.py b/momadm_benchmarks/envs/beach_domain/beach_domain.py index 18e6e725..e73f5abc 100644 --- a/momadm_benchmarks/envs/beach_domain/beach_domain.py +++ b/momadm_benchmarks/envs/beach_domain/beach_domain.py @@ -26,8 +26,8 @@ def parallel_env(**kwargs): - """Env factory function for the beach domain.""" - return MOBeachDomain(**kwargs) + """Parallel env factory function for the beach problem domain.""" + return raw_env(**kwargs) def env(**kwargs): @@ -39,17 +39,17 @@ def env(**kwargs): Returns: A fully wrapped env """ - env = raw_env(**kwargs) + env = parallel_env(**kwargs) + env = mo_parallel_to_aec(env) + # this wrapper helps error handling for discrete action spaces env = wrappers.AssertOutOfBoundsWrapper(env) return env def raw_env(**kwargs): - """To support the AEC API, the raw_env function just uses the from_parallel function to convert from a ParallelEnv to an AEC env.""" - env = parallel_env(**kwargs) - env = mo_parallel_to_aec(env) - return env + """Env factory function for the beach problem domain.""" + return MOBeachDomain(**kwargs) class MOBeachDomain(MOParallelEnv):