net.py
import torch
import torch.nn as nn
from torch.distributions import Normal


def soft_clamp(
    x: torch.Tensor, bound: tuple
) -> torch.Tensor:
    # Squash x into the open interval (low, high) via tanh rescaling.
    low, high = bound
    x = torch.tanh(x)
    x = low + 0.5 * (high - low) * (x + 1)
    return x


def MLP(
    input_dim: int,
    hidden_dim: int,
    depth: int,
    output_dim: int,
    final_activation: str
) -> nn.Sequential:
    # Build a fully-connected network with `depth` hidden ReLU layers and an
    # optional final activation ('relu', 'tanh', or none).
    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
    for _ in range(depth - 1):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(hidden_dim, output_dim))
    if final_activation == 'relu':
        layers.append(nn.ReLU())
    elif final_activation == 'tanh':
        layers.append(nn.Tanh())
    return nn.Sequential(*layers)


class ValueMLP(nn.Module):
    _net: nn.Sequential

    def __init__(
        self, state_dim: int, hidden_dim: int, depth: int
    ) -> None:
        super().__init__()
        self._net = MLP(state_dim, hidden_dim, depth, 1, 'relu')

    def forward(
        self, s: torch.Tensor
    ) -> torch.Tensor:
        return self._net(s)


class QMLP(nn.Module):
    # Action-value network Q(s, a); operates on the concatenated state-action pair.
    _net: nn.Sequential

    def __init__(
        self,
        state_dim: int, action_dim: int, hidden_dim: int, depth: int
    ) -> None:
        super().__init__()
        self._net = MLP((state_dim + action_dim), hidden_dim, depth, 1, 'relu')

    def forward(
        self, s: torch.Tensor, a: torch.Tensor
    ) -> torch.Tensor:
        sa = torch.cat([s, a], dim=1)
        return self._net(sa)


class GaussPolicyMLP(nn.Module):
    # Gaussian policy: the network outputs the mean and log-std of a Normal
    # distribution over actions; the log-std is soft-clamped to a fixed range.
    _net: nn.Sequential
    _log_std_bound: tuple

    def __init__(
        self,
        state_dim: int, hidden_dim: int, depth: int, action_dim: int,
    ) -> None:
        super().__init__()
        self._net = MLP(state_dim, hidden_dim, depth, (2 * action_dim), 'tanh')
        self._log_std_bound = (-5., 0.)

    def forward(
        self, s: torch.Tensor
    ) -> Normal:
        mu, log_std = self._net(s).chunk(2, dim=-1)
        log_std = soft_clamp(log_std, self._log_std_bound)
        std = log_std.exp()
        dist = Normal(mu, std)
        return dist
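

# A minimal usage sketch (not part of the original file): it assumes arbitrary
# example dimensions (state_dim=17, action_dim=6, hidden_dim=256, depth=2) and
# only shows how the three modules above are constructed and called.
if __name__ == "__main__":
    state_dim, action_dim, hidden_dim, depth = 17, 6, 256, 2
    s = torch.randn(8, state_dim)   # batch of 8 example states
    a = torch.randn(8, action_dim)  # batch of 8 example actions

    value = ValueMLP(state_dim, hidden_dim, depth)
    q = QMLP(state_dim, action_dim, hidden_dim, depth)
    policy = GaussPolicyMLP(state_dim, hidden_dim, depth, action_dim)

    print(value(s).shape)   # torch.Size([8, 1])
    print(q(s, a).shape)    # torch.Size([8, 1])

    dist = policy(s)        # Normal distribution over actions
    action = dist.rsample() # reparameterized sample, shape (8, action_dim)
    print(action.shape)
    print(dist.log_prob(action).sum(-1).shape)  # per-sample log-probability, shape (8,)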