Skip to content

Commit

Permalink
Move d3rlpy.dataset.types to d3rlpy.types
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 21, 2023
1 parent d7ad45f commit 3543576
Show file tree
Hide file tree
Showing 63 changed files with 73 additions and 99 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/awac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_normal_policy,
Expand All @@ -13,6 +12,7 @@
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...models.torch import Parameter
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.awac_impl import AWACImpl
from .torch.sac_impl import SACModules
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import (
Observation,
ReplayBuffer,
TransitionMiniBatch,
check_non_1d_array,
Expand All @@ -46,7 +45,7 @@
sync_optimizer_state,
train_api,
)
from ...types import NDArray
from ...types import NDArray, Observation
from ..utility import (
assert_action_space_with_dataset,
assert_action_space_with_env,
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_categorical_policy,
create_deterministic_policy,
create_normal_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.bc_impl import (
BCBaseImpl,
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_categorical_policy,
create_conditional_vae,
Expand All @@ -14,6 +13,7 @@
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...models.torch import CategoricalPolicy, compute_output_size
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.bcq_impl import (
BCQImpl,
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/bear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_conditional_vae,
create_continuous_q_function,
Expand All @@ -13,6 +12,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.bear_impl import BEARImpl, BEARModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_discrete_q_function,
Expand All @@ -13,6 +12,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl
from .torch.dqn_impl import DQNModules
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/crr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_normal_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.crr_impl import CRRImpl, CRRModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGImpl, DDPGModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import create_discrete_q_function
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules

Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/explorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import numpy as np

from ...dataset import Observation
from ...interface import QLearningAlgoProtocol
from ...preprocessing.action_scalers import MinMaxActionScaler
from ...types import NDArray
from ...types import NDArray, Observation

__all__ = [
"Explorer",
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_normal_policy,
Expand All @@ -11,6 +10,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import MeanQFunctionFactory
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.iql_impl import IQLImpl, IQLModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/nfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import create_discrete_q_function
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.dqn_impl import DQNImpl, DQNModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/plas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_conditional_vae,
create_continuous_q_function,
Expand All @@ -12,6 +11,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.plas_impl import (
PLASImpl,
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/random_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Observation, Shape
from ...torch_utility import TorchMiniBatch
from ...types import NDArray
from ...types import NDArray, Observation, Shape
from .base import QLearningAlgoBase

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_categorical_policy,
create_continuous_q_function,
Expand All @@ -14,6 +13,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.sac_impl import (
DiscreteSACImpl,
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.td3_impl import TD3Impl
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/td3_plus_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.td3_plus_bc_impl import TD3PlusBCImpl
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/awac_impl.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch
import torch.nn.functional as F

from ....dataset import Shape
from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
build_gaussian_distribution,
)
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .sac_impl import SACImpl, SACModules

__all__ = ["AWACImpl"]
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/bc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
CategoricalPolicy,
DeterministicPolicy,
Expand All @@ -16,6 +15,7 @@
compute_stochastic_imitation_loss,
)
from ....torch_utility import Modules, TorchMiniBatch
from ....types import Shape
from ..base import QLearningAlgoImplBase

__all__ = ["BCImpl", "DiscreteBCImpl", "BCModules", "DiscreteBCModules"]
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/bcq_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch.nn.functional as F
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
CategoricalPolicy,
ConditionalVAE,
Expand All @@ -19,6 +18,7 @@
forward_vae_decode,
)
from ....torch_utility import TorchMiniBatch, soft_sync
from ....types import Shape
from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules
from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/bear_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
ConditionalVAE,
ContinuousEnsembleQFunctionForwarder,
Expand All @@ -15,6 +14,7 @@
forward_vae_sample_n,
)
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .sac_impl import SACImpl, SACModules

__all__ = ["BEARImpl", "BEARModules"]
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/cql_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import torch.nn.functional as F
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
Parameter,
build_squashed_gaussian_distribution,
)
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
from .sac_impl import SACImpl, SACModules

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/crr_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import torch
import torch.nn.functional as F

from ....dataset import Shape
from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
NormalPolicy,
build_gaussian_distribution,
)
from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync
from ....types import Shape
from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules

__all__ = ["CRRImpl", "CRRModules"]
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from torch import nn
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy
from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync
from ....types import Shape
from ..base import QLearningAlgoImplBase
from .utility import ContinuousQFunctionMixin

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/dqn_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from torch.optim import Optimizer

from ....dataclass_utils import asdict_as_float
from ....dataset import Shape
from ....models.torch import DiscreteEnsembleQFunctionForwarder
from ....torch_utility import Modules, TorchMiniBatch, hard_sync
from ....types import Shape
from ..base import QLearningAlgoImplBase
from .utility import DiscreteQFunctionMixin

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/iql_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import torch

from ....dataset import Shape
from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
NormalPolicy,
ValueFunction,
build_gaussian_distribution,
)
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules

__all__ = ["IQLImpl", "IQLModules"]
Expand Down
Loading

0 comments on commit 3543576

Please sign in to comment.