Skip to content

Commit

Permalink
pin minari version
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamannett committed Oct 9, 2023
1 parent 0fb2f8d commit 88bc8c6
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
34 changes: 18 additions & 16 deletions d3rlpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,21 +351,23 @@ def play(
@cli.command(short_help="Install additional packages.")
@click.argument("name")
def install(name: str) -> None:
match name:
case "atari":
_install_module(["gym[atari,accept-rom-license]"], upgrade=True)
case "d4rl_atari":
install("atari")
_install_module(["git+https://github.com/takuseno/d4rl-atari"])
case "d4rl":
_install_module(["git+https://github.com/Farama-Foundation/D4RL"])
_install_module(["gym"], upgrade=True)
_install_module(["-y", "pybullet"], upgrade=True)
case "minari":
_install_module(["minari"], upgrade=True)
case _:
raise ValueError(f"Unsupported command: {name}")

def _install_module(name: list[str], upgrade: bool = False, check: bool = True) -> None:
if name == "atari":
_install_module(["gym[atari,accept-rom-license]"], upgrade=True)
elif name == "d4rl_atari":
install("atari")
_install_module(["git+https://github.com/takuseno/d4rl-atari"])
elif name == "d4rl":
_install_module(["git+https://github.com/Farama-Foundation/D4RL"])
_install_module(["gym"], upgrade=True)
_install_module(["-y", "pybullet"], upgrade=True)
elif name == "minari":
_install_module(["minari==0.4.2"], upgrade=True)
else:
raise ValueError(f"Unsupported command: {name}")


def _install_module(
name: list[str], upgrade: bool = False, check: bool = True
) -> None:
name = ["-U", *name] if upgrade else name
subprocess.run(["pip3", "install", *name], check=check)
1 change: 1 addition & 0 deletions docs/references/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ learning algorithms.
d3rlpy.datasets.get_atari
d3rlpy.datasets.get_atari_transitions
d3rlpy.datasets.get_d4rl
d3rlpy.datasets.get_minari
d3rlpy.datasets.get_dataset
14 changes: 13 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum
from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum, get_minari


@pytest.mark.parametrize("dataset_type", ["replay", "random"])
Expand All @@ -23,3 +23,15 @@ def test_get_dataset(env_name: str) -> None:
assert env.unwrapped.spec.id == "CartPole-v1"
elif env_name == "pendulum-random":
assert env.unwrapped.spec.id == "Pendulum-v1"


@pytest.mark.parametrize(
"dataset_name, env_name",
[
("door-cloned-v1", "AdroitHandDoor-v1"),
("relocate-expert-v1", "AdroitHandRelocate-v1"),
],
)
def test_get_minari(dataset_name: str, env_name: str) -> None:
_, env = get_minari(dataset_name)
assert env.unwrapped.spec.id == env_name

0 comments on commit 88bc8c6

Please sign in to comment.