diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index c69a7327..d8414e6e 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -4,7 +4,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_normal_policy, @@ -13,6 +12,7 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import Parameter +from ...types import Shape from .base import QLearningAlgoBase from .torch.awac_impl import AWACImpl from .torch.sac_impl import SACModules diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 06134cd8..1e03cbfd 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -21,7 +21,6 @@ from ...base import ImplBase, LearnableBase, LearnableConfig, save_config from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace from ...dataset import ( - Observation, ReplayBuffer, TransitionMiniBatch, check_non_1d_array, @@ -46,7 +45,7 @@ sync_optimizer_state, train_api, ) -from ...types import NDArray +from ...types import NDArray, Observation from ..utility import ( assert_action_space_with_dataset, assert_action_space_with_env, diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index 0d8eb6e2..507cf9e2 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_categorical_policy, create_deterministic_policy, @@ -10,6 +9,7 @@ ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.bc_impl import ( BCBaseImpl, diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 3a40c89c..7cc50255 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_categorical_policy, create_conditional_vae, @@ -14,6 +13,7 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import CategoricalPolicy, compute_output_size +from ...types import Shape from .base import QLearningAlgoBase from .torch.bcq_impl import ( BCQImpl, diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 2c016201..edc17910 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -3,7 +3,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_conditional_vae, create_continuous_q_function, @@ -13,6 +12,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.bear_impl import BEARImpl, BEARModules diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 42ba612a..3650f8f8 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -3,7 +3,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_discrete_q_function, @@ -13,6 +12,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl from .torch.dqn_impl import DQNModules diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index c7aeb90b..d690a994 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_normal_policy, @@ -10,6 +9,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.crr_impl import CRRImpl, CRRModules diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index e8742b33..5e83f78e 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_deterministic_policy, @@ -10,6 +9,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGImpl, DDPGModules diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 8b9b7752..ed21b1d0 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -2,11 +2,11 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import create_discrete_q_function from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules diff --git a/d3rlpy/algos/qlearning/explorers.py b/d3rlpy/algos/qlearning/explorers.py index 951b9ed8..bd0b4b6c 100644 --- a/d3rlpy/algos/qlearning/explorers.py +++ b/d3rlpy/algos/qlearning/explorers.py @@ -3,10 +3,9 @@ import numpy as np -from ...dataset import Observation from ...interface import QLearningAlgoProtocol from ...preprocessing.action_scalers import MinMaxActionScaler -from ...types import NDArray +from ...types import NDArray, Observation __all__ = [ "Explorer", diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 85959519..a95863cb 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_normal_policy, @@ -11,6 +10,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import MeanQFunctionFactory +from ...types import Shape from .base import QLearningAlgoBase from .torch.iql_impl import IQLImpl, IQLModules diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 663b6b0e..46e889ac 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -2,11 +2,11 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import create_discrete_q_function from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.dqn_impl import DQNImpl, DQNModules diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index fdaf119e..dfad71d1 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_conditional_vae, create_continuous_q_function, @@ -12,6 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.plas_impl import ( PLASImpl, diff --git a/d3rlpy/algos/qlearning/random_policy.py b/d3rlpy/algos/qlearning/random_policy.py index 6f9e77d5..2f3a3c4d 100644 --- a/d3rlpy/algos/qlearning/random_policy.py +++ b/d3rlpy/algos/qlearning/random_policy.py @@ -5,9 +5,8 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Observation, Shape from ...torch_utility import TorchMiniBatch -from ...types import NDArray +from ...types import NDArray, Observation, Shape from .base import QLearningAlgoBase __all__ = [ diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 4b2afe02..f5d286b6 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -3,7 +3,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_categorical_policy, create_continuous_q_function, @@ -14,6 +13,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.sac_impl import ( DiscreteSACImpl, diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index a59ef941..2633bce0 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_deterministic_policy, @@ -10,6 +9,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGModules from .torch.td3_impl import TD3Impl diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 8656a688..ae1940fd 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -2,7 +2,6 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace -from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, create_deterministic_policy, @@ -10,6 +9,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...types import Shape from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGModules from .torch.td3_plus_bc_impl import TD3PlusBCImpl diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index 81646adc..af73d020 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -1,12 +1,12 @@ import torch import torch.nn.functional as F -from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, build_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from ....types import Shape from .sac_impl import SACImpl, SACModules __all__ = ["AWACImpl"] diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index dbefeb55..06842f35 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -5,7 +5,6 @@ import torch from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( CategoricalPolicy, DeterministicPolicy, @@ -16,6 +15,7 @@ compute_stochastic_imitation_loss, ) from ....torch_utility import Modules, TorchMiniBatch +from ....types import Shape from ..base import QLearningAlgoImplBase __all__ = ["BCImpl", "DiscreteBCImpl", "BCModules", "DiscreteBCModules"] diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index fcedd24c..51f9bf14 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -6,7 +6,6 @@ import torch.nn.functional as F from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( CategoricalPolicy, ConditionalVAE, @@ -19,6 +18,7 @@ forward_vae_decode, ) from ....torch_utility import TorchMiniBatch, soft_sync +from ....types import Shape from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index aa2cc173..c16f3501 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -4,7 +4,6 @@ import torch from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( ConditionalVAE, ContinuousEnsembleQFunctionForwarder, @@ -15,6 +14,7 @@ forward_vae_sample_n, ) from ....torch_utility import TorchMiniBatch +from ....types import Shape from .sac_impl import SACImpl, SACModules __all__ = ["BEARImpl", "BEARModules"] diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index a531abd6..a63586e3 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -6,7 +6,6 @@ import torch.nn.functional as F from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, @@ -14,6 +13,7 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from ....types import Shape from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 38f2168d..33d00507 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -4,13 +4,13 @@ import torch import torch.nn.functional as F -from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, NormalPolicy, build_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync +from ....types import Shape from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules __all__ = ["CRRImpl", "CRRModules"] diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index a1535589..58396a65 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -6,9 +6,9 @@ from torch import nn from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync +from ....types import Shape from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index e5caee13..a04e9c88 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -6,9 +6,9 @@ from torch.optim import Optimizer from ....dataclass_utils import asdict_as_float -from ....dataset import Shape from ....models.torch import DiscreteEnsembleQFunctionForwarder from ....torch_utility import Modules, TorchMiniBatch, hard_sync +from ....types import Shape from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 880a4c09..f2acf2c4 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -3,7 +3,6 @@ import torch -from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, NormalPolicy, @@ -11,6 +10,7 @@ build_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from ....types import Shape from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules __all__ = ["IQLImpl", "IQLModules"] diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index df3a081b..7930540b 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -4,7 +4,6 @@ import torch from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( ConditionalVAE, ContinuousEnsembleQFunctionForwarder, @@ -14,6 +13,7 @@ forward_vae_decode, ) from ....torch_utility import TorchMiniBatch, soft_sync +from ....types import Shape from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules __all__ = [ diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index b6612295..309e0d91 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -7,7 +7,6 @@ from torch import nn from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( CategoricalPolicy, ContinuousEnsembleQFunctionForwarder, @@ -18,6 +17,7 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import Modules, TorchMiniBatch, hard_sync +from ....types import Shape from ..base import QLearningAlgoImplBase from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules from .utility import DiscreteQFunctionMixin diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 34cb127c..36e74816 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -2,9 +2,9 @@ import torch -from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder from ....torch_utility import TorchMiniBatch +from ....types import Shape from .ddpg_impl import DDPGImpl, DDPGModules __all__ = ["TD3Impl"] diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 1edd9149..85bb92fd 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -2,9 +2,9 @@ import torch -from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder from ....torch_utility import TorchMiniBatch +from ....types import Shape from .ddpg_impl import DDPGModules from .td3_impl import TD3Impl diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index 6cefdcb9..a42f3b6a 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -10,7 +10,7 @@ from ...base import ImplBase, LearnableBase, LearnableConfig, save_config from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace -from ...dataset import Observation, ReplayBuffer, TrajectoryMiniBatch +from ...dataset import ReplayBuffer, TrajectoryMiniBatch from ...envs import GymEnv from ...logging import ( LOG, @@ -20,7 +20,7 @@ ) from ...metrics import evaluate_transformer_with_environment from ...torch_utility import TorchTrajectoryMiniBatch, train_api -from ...types import NDArray +from ...types import NDArray, Observation from ..utility import ( assert_action_space_with_dataset, build_scalers_with_trajectory_slicer, diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index c5988fe9..a5dd2298 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -4,7 +4,6 @@ from ...base import DeviceArg, register_learnable from ...constants import ActionSpace, PositionEncodingType -from ...dataset import Shape from ...models import ( EncoderFactory, OptimizerFactory, @@ -15,6 +14,7 @@ create_continuous_decision_transformer, create_discrete_decision_transformer, ) +from ...types import Shape from .base import TransformerAlgoBase, TransformerConfig from .torch.decision_transformer_impl import ( DecisionTransformerImpl, diff --git a/d3rlpy/algos/transformer/inputs.py b/d3rlpy/algos/transformer/inputs.py index c9bc0b0a..2a7e175f 100644 --- a/d3rlpy/algos/transformer/inputs.py +++ b/d3rlpy/algos/transformer/inputs.py @@ -5,7 +5,6 @@ import torch from ...dataset import ( - ObservationSequence, batch_pad_array, batch_pad_observations, get_axis_size, @@ -13,7 +12,7 @@ ) from ...preprocessing import ActionScaler, ObservationScaler, RewardScaler from ...torch_utility import convert_to_torch, convert_to_torch_recursively -from ...types import NDArray +from ...types import NDArray, ObservationSequence __all__ = ["TransformerInput", "TorchTransformerInput"] diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index f847747c..30f7028e 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -6,12 +6,12 @@ import torch.nn.functional as F from torch.optim import Optimizer -from ....dataset import Shape from ....models.torch import ( ContinuousDecisionTransformer, DiscreteDecisionTransformer, ) from ....torch_utility import Modules, TorchTrajectoryMiniBatch, eval_api +from ....types import Shape from ..base import TransformerAlgoImplBase from ..inputs import TorchTransformerInput diff --git a/d3rlpy/base.py b/d3rlpy/base.py index 7e52b9eb..732949ff 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -10,7 +10,7 @@ from ._version import __version__ from .constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace -from .dataset import ReplayBuffer, Shape, detect_action_size_from_env +from .dataset import ReplayBuffer, detect_action_size_from_env from .envs import GymEnv from .logging import LOG, D3RLPyLogger from .preprocessing import ( @@ -23,6 +23,7 @@ ) from .serializable_config import DynamicConfig, generate_config_registration from .torch_utility import Checkpointer, Modules +from .types import Shape __all__ = [ "DeviceArg", diff --git a/d3rlpy/dataset/__init__.py b/d3rlpy/dataset/__init__.py index 37e8a816..2d5a48ec 100644 --- a/d3rlpy/dataset/__init__.py +++ b/d3rlpy/dataset/__init__.py @@ -7,6 +7,5 @@ from .replay_buffer import * from .trajectory_slicers import * from .transition_pickers import * -from .types import * from .utils import * from .writers import * diff --git a/d3rlpy/dataset/compat.py b/d3rlpy/dataset/compat.py index f5568d14..e1a2cb79 100644 --- a/d3rlpy/dataset/compat.py +++ b/d3rlpy/dataset/compat.py @@ -1,13 +1,12 @@ from typing import Optional from ..constants import ActionSpace -from ..types import NDArray +from ..types import NDArray, ObservationSequence from .buffers import InfiniteBuffer from .episode_generator import EpisodeGenerator from .replay_buffer import ReplayBuffer from .trajectory_slicers import TrajectorySlicerProtocol from .transition_pickers import TransitionPickerProtocol -from .types import ObservationSequence __all__ = ["MDPDataset"] diff --git a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index 600921a0..03b0190c 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -5,8 +5,7 @@ from typing_extensions import Protocol from ..constants import ActionSpace -from ..types import DType, NDArray -from .types import Observation, ObservationSequence +from ..types import DType, NDArray, Observation, ObservationSequence from .utils import ( get_dtype_from_observation, get_dtype_from_observation_sequence, diff --git a/d3rlpy/dataset/episode_generator.py b/d3rlpy/dataset/episode_generator.py index de15c82a..a41c0870 100644 --- a/d3rlpy/dataset/episode_generator.py +++ b/d3rlpy/dataset/episode_generator.py @@ -3,9 +3,8 @@ import numpy as np from typing_extensions import Protocol -from ..types import NDArray +from ..types import NDArray, ObservationSequence from .components import Episode, EpisodeBase -from .types import ObservationSequence from .utils import slice_observations __all__ = ["EpisodeGeneratorProtocol", "EpisodeGenerator"] diff --git a/d3rlpy/dataset/mini_batch.py b/d3rlpy/dataset/mini_batch.py index a2d5af6e..b65850a9 100644 --- a/d3rlpy/dataset/mini_batch.py +++ b/d3rlpy/dataset/mini_batch.py @@ -3,9 +3,8 @@ import numpy as np -from ..types import NDArray +from ..types import NDArray, Shape from .components import PartialTrajectory, Transition -from .types import Shape from .utils import ( cast_recursively, check_dtype, diff --git a/d3rlpy/dataset/replay_buffer.py b/d3rlpy/dataset/replay_buffer.py index 9c0389d9..70b05818 100644 --- a/d3rlpy/dataset/replay_buffer.py +++ b/d3rlpy/dataset/replay_buffer.py @@ -5,7 +5,7 @@ from ..constants import ActionSpace from ..envs import GymEnv from ..logging import LOG -from ..types import NDArray +from ..types import NDArray, Observation from .buffers import BufferProtocol, FIFOBuffer, InfiniteBuffer from .components import ( DatasetInfo, @@ -20,7 +20,6 @@ from .mini_batch import TrajectoryMiniBatch, TransitionMiniBatch from .trajectory_slicers import BasicTrajectorySlicer, TrajectorySlicerProtocol from .transition_pickers import BasicTransitionPicker, TransitionPickerProtocol -from .types import Observation from .utils import ( detect_action_size_from_env, detect_action_space, diff --git a/d3rlpy/dataset/types.py b/d3rlpy/dataset/types.py deleted file mode 100644 index d25c0d3c..00000000 --- a/d3rlpy/dataset/types.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Sequence, Union - -from ..types import NDArray - -__all__ = ["Observation", "ObservationSequence", "Shape"] - - -Observation = Union[NDArray, Sequence[NDArray]] -ObservationSequence = Union[NDArray, Sequence[NDArray]] -Shape = Union[Sequence[int], Sequence[Sequence[int]]] diff --git a/d3rlpy/dataset/utils.py b/d3rlpy/dataset/utils.py index 4a10fd6d..d6db605f 100644 --- a/d3rlpy/dataset/utils.py +++ b/d3rlpy/dataset/utils.py @@ -7,8 +7,7 @@ from ..constants import ActionSpace from ..envs.types import GymEnv -from ..types import DType, NDArray -from .types import Observation, ObservationSequence, Shape +from ..types import DType, NDArray, Observation, ObservationSequence, Shape __all__ = [ "retrieve_observation", diff --git a/d3rlpy/dataset/writers.py b/d3rlpy/dataset/writers.py index 31c1d794..e3ff16e0 100644 --- a/d3rlpy/dataset/writers.py +++ b/d3rlpy/dataset/writers.py @@ -3,10 +3,9 @@ import numpy as np from typing_extensions import Protocol -from ..types import NDArray +from ..types import NDArray, Observation, ObservationSequence from .buffers import BufferProtocol from .components import Episode, EpisodeBase, Signature -from .types import Observation, ObservationSequence from .utils import get_dtype_from_observation, get_shape_from_observation __all__ = [ diff --git a/d3rlpy/interface.py b/d3rlpy/interface.py index 0fa8b071..5525f1d0 100644 --- a/d3rlpy/interface.py +++ b/d3rlpy/interface.py @@ -2,9 +2,8 @@ from typing_extensions import Protocol -from .dataset import Observation from .preprocessing import ActionScaler, ObservationScaler, RewardScaler -from .types import NDArray +from .types import NDArray, Observation __all__ = ["QLearningAlgoProtocol", "StatefulTransformerAlgoProtocol"] diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 266706bd..e7c9ffec 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -4,7 +4,7 @@ from torch import nn from ..constants import PositionEncodingType -from ..dataset import Shape +from ..types import Shape from .encoders import EncoderFactory from .q_functions import QFunctionFactory from .torch import ( diff --git a/d3rlpy/models/encoders.py b/d3rlpy/models/encoders.py index f516c066..cda5d1c2 100644 --- a/d3rlpy/models/encoders.py +++ b/d3rlpy/models/encoders.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field from typing import List, Optional, Union -from ..dataset import Shape, cast_flat_shape +from ..dataset import cast_flat_shape from ..serializable_config import DynamicConfig, generate_config_registration +from ..types import Shape from .torch import ( Encoder, EncoderWithAction, diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index bdc8cbfc..8eaa3033 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch import nn -from ...dataset import Shape from ...itertools import last_flag +from ...types import Shape __all__ = [ "Encoder", diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index 2ae7291a..b7b8c76f 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -4,7 +4,6 @@ from ..algos.qlearning import QLearningAlgoBase, QLearningAlgoImplBase from ..base import DeviceArg, LearnableConfig, register_learnable from ..constants import ALGO_NOT_GIVEN_ERROR, ActionSpace -from ..dataset import Observation, Shape from ..models.builders import ( create_continuous_q_function, create_discrete_q_function, @@ -12,7 +11,7 @@ from ..models.encoders import EncoderFactory, make_encoder_field from ..models.optimizers import OptimizerFactory, make_optimizer_field from ..models.q_functions import QFunctionFactory, make_q_func_field -from ..types import NDArray +from ..types import NDArray, Observation, Shape from .torch.fqe_impl import ( DiscreteFQEImpl, FQEBaseImpl, diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index c98e3925..9e89d01a 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -10,12 +10,12 @@ ContinuousQFunctionMixin, DiscreteQFunctionMixin, ) -from ...dataset import Shape from ...models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, ) from ...torch_utility import Modules, TorchMiniBatch, hard_sync +from ...types import Shape __all__ = ["FQEBaseImpl", "FQEImpl", "DiscreteFQEImpl", "FQEBaseModules"] diff --git a/d3rlpy/types.py b/d3rlpy/types.py index 87dfb6f3..05622637 100644 --- a/d3rlpy/types.py +++ b/d3rlpy/types.py @@ -1,9 +1,13 @@ -from typing import Any +from typing import Any, Sequence, Union import numpy.typing as npt -__all__ = ["NDArray", "DType"] +__all__ = ["NDArray", "DType", "Observation", "ObservationSequence", "Shape"] NDArray = npt.NDArray[Any] DType = npt.DTypeLike + +Observation = Union[NDArray, Sequence[NDArray]] +ObservationSequence = Union[NDArray, Sequence[NDArray]] +Shape = Union[Sequence[int], Sequence[Sequence[int]]] diff --git a/tests/algos/qlearning/test_explorers.py b/tests/algos/qlearning/test_explorers.py index bc13bd53..f79a334f 100644 --- a/tests/algos/qlearning/test_explorers.py +++ b/tests/algos/qlearning/test_explorers.py @@ -8,14 +8,13 @@ LinearDecayEpsilonGreedy, NormalNoise, ) -from d3rlpy.dataset import Observation from d3rlpy.preprocessing import ( ActionScaler, MinMaxActionScaler, ObservationScaler, RewardScaler, ) -from d3rlpy.types import NDArray +from d3rlpy.types import NDArray, Observation class DummyAlgo: diff --git a/tests/dataset/test_episode_generator.py b/tests/dataset/test_episode_generator.py index a85001e4..9f88afb8 100644 --- a/tests/dataset/test_episode_generator.py +++ b/tests/dataset/test_episode_generator.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from d3rlpy.dataset import EpisodeGenerator, Shape +from d3rlpy.dataset import EpisodeGenerator +from d3rlpy.types import Shape from ..testing_utils import create_observations diff --git a/tests/dataset/test_io.py b/tests/dataset/test_io.py index cef85538..875a5ccc 100644 --- a/tests/dataset/test_io.py +++ b/tests/dataset/test_io.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from d3rlpy.dataset import Episode, Shape, dump, load +from d3rlpy.dataset import Episode, dump, load +from d3rlpy.types import Shape from ..testing_utils import create_episode diff --git a/tests/dataset/test_mini_batch.py b/tests/dataset/test_mini_batch.py index c19ee053..43c0dd66 100644 --- a/tests/dataset/test_mini_batch.py +++ b/tests/dataset/test_mini_batch.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from d3rlpy.dataset import Shape, TrajectoryMiniBatch, TransitionMiniBatch +from d3rlpy.dataset import TrajectoryMiniBatch, TransitionMiniBatch +from d3rlpy.types import Shape from ..testing_utils import create_partial_trajectory, create_transition diff --git a/tests/dataset/test_replay_buffer.py b/tests/dataset/test_replay_buffer.py index bf4cdd31..9e14ab89 100644 --- a/tests/dataset/test_replay_buffer.py +++ b/tests/dataset/test_replay_buffer.py @@ -11,10 +11,10 @@ FIFOBuffer, InfiniteBuffer, ReplayBuffer, - Shape, create_fifo_replay_buffer, create_infinite_replay_buffer, ) +from d3rlpy.types import Shape from ..testing_utils import create_episode, create_observation diff --git a/tests/dataset/test_trajectory_slicer.py b/tests/dataset/test_trajectory_slicer.py index 9faf3da7..69ffb066 100644 --- a/tests/dataset/test_trajectory_slicer.py +++ b/tests/dataset/test_trajectory_slicer.py @@ -7,8 +7,8 @@ BasicTrajectorySlicer, FrameStackTrajectorySlicer, FrameStackTransitionPicker, - Shape, ) +from d3rlpy.types import Shape from ..testing_utils import create_episode diff --git a/tests/dataset/test_transition_pickers.py b/tests/dataset/test_transition_pickers.py index 634e3f76..1d63f39b 100644 --- a/tests/dataset/test_transition_pickers.py +++ b/tests/dataset/test_transition_pickers.py @@ -5,8 +5,8 @@ BasicTransitionPicker, FrameStackTransitionPicker, MultiStepTransitionPicker, - Shape, ) +from d3rlpy.types import Shape from ..testing_utils import create_episode diff --git a/tests/dataset/test_utils.py b/tests/dataset/test_utils.py index 69b4d1bf..6e6aa4a6 100644 --- a/tests/dataset/test_utils.py +++ b/tests/dataset/test_utils.py @@ -7,7 +7,6 @@ from d3rlpy.constants import ActionSpace from d3rlpy.dataset import ( - Shape, batch_pad_array, batch_pad_observations, cast_recursively, @@ -26,7 +25,7 @@ stack_observations, stack_recent_observations, ) -from d3rlpy.types import DType +from d3rlpy.types import DType, Shape from ..testing_utils import create_observation, create_observations diff --git a/tests/dataset/test_writers.py b/tests/dataset/test_writers.py index 42048571..7f217abc 100644 --- a/tests/dataset/test_writers.py +++ b/tests/dataset/test_writers.py @@ -9,8 +9,8 @@ ExperienceWriter, InfiniteBuffer, LastFrameWriterPreprocess, - Shape, ) +from d3rlpy.types import Shape from ..testing_utils import create_episode, create_observation diff --git a/tests/metrics/test_evaluators.py b/tests/metrics/test_evaluators.py index 700e6ce8..cb47b4b8 100644 --- a/tests/metrics/test_evaluators.py +++ b/tests/metrics/test_evaluators.py @@ -8,7 +8,6 @@ BasicTransitionPicker, Episode, InfiniteBuffer, - Observation, ReplayBuffer, TransitionMiniBatch, ) @@ -29,7 +28,7 @@ ObservationScaler, RewardScaler, ) -from d3rlpy.types import NDArray +from d3rlpy.types import NDArray, Observation from ..testing_utils import create_episode diff --git a/tests/metrics/test_utility.py b/tests/metrics/test_utility.py index 47d39500..e32e48d6 100644 --- a/tests/metrics/test_utility.py +++ b/tests/metrics/test_utility.py @@ -7,10 +7,9 @@ import pytest from gym import spaces -from d3rlpy.dataset import Observation from d3rlpy.metrics.utility import evaluate_qlearning_with_environment from d3rlpy.preprocessing import ActionScaler, ObservationScaler, RewardScaler -from d3rlpy.types import NDArray +from d3rlpy.types import NDArray, Observation class DummyAlgo: diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 3ae6db36..32d91477 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -2,14 +2,7 @@ import numpy as np -from d3rlpy.dataset import ( - Episode, - Observation, - ObservationSequence, - PartialTrajectory, - Shape, - Transition, -) +from d3rlpy.dataset import Episode, PartialTrajectory, Transition from d3rlpy.preprocessing import ( ActionScaler, MinMaxActionScaler, @@ -18,7 +11,7 @@ ObservationScaler, RewardScaler, ) -from d3rlpy.types import DType, NDArray +from d3rlpy.types import DType, NDArray, Observation, ObservationSequence, Shape @overload