[Refactor] Use default device instead of CPU in losses
ghstack-source-id: 52a013a04a763bdb8c1c77a43a0984babe32bd77
Pull Request resolved: #2687
vmoens committed Jan 10, 2025
1 parent 11305a7 commit b7db5a5
Showing 7 changed files with 8 additions and 8 deletions.
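
The change in every file is the same one-liner: instead of hard-coding CPU when no parameter device can be found, the loss falls back to the global default device where torch.get_default_device is available; the getattr guard keeps PyTorch versions that lack this function on the old CPU path. A minimal sketch of the pattern, using a hypothetical _resolve_device helper name:

    import torch
    from torch import nn

    def _resolve_device(module: nn.Module) -> torch.device:
        # Prefer the device of the module's own parameters.
        try:
            return next(module.parameters()).device
        except (AttributeError, StopIteration):
            # Fall back to torch.get_default_device() when it exists;
            # otherwise keep the previous CPU default.
            return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

With a user-set default device (e.g. torch.set_default_device("cuda")), buffers such as alpha_init are then registered on that device instead of silently landing on CPU.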
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
@@ -323,7 +323,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
2 changes: 1 addition & 1 deletion torchrl/objectives/crossq.py
@@ -306,7 +306,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
2 changes: 1 addition & 1 deletion torchrl/objectives/decision_transformer.py
@@ -103,7 +103,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
2 changes: 1 addition & 1 deletion torchrl/objectives/deprecated.py
@@ -195,7 +195,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

         self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
         self.register_buffer(
2 changes: 1 addition & 1 deletion torchrl/objectives/ppo.py
@@ -375,7 +375,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except (AttributeError, StopIteration):
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
         if critic_coef is not None:
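
Unlike the other losses, ppo.py also catches StopIteration, which next() raises when the module exposes no parameters at all, so a parameter-free loss still reaches the default-device fallback. A small illustration of that branch (the EmptyLoss class below is hypothetical):

    import torch
    from torch import nn

    class EmptyLoss(nn.Module):
        # Hypothetical loss module that registers no parameters.
        pass

    loss = EmptyLoss()
    try:
        device = next(loss.parameters()).device
    except (AttributeError, StopIteration):
        # parameters() yields nothing, so next() raises StopIteration
        # and the default-device fallback is used instead.
        device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
    print(device)  # cpu unless a default device was set via torch.set_default_device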
2 changes: 1 addition & 1 deletion torchrl/objectives/redq.py
@@ -309,7 +309,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         self.register_buffer(
4 changes: 2 additions & 2 deletions torchrl/objectives/sac.py
@@ -383,7 +383,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
@@ -1102,7 +1102,7 @@ def __init__(
         try:
             device = next(self.parameters()).device
         except AttributeError:
-            device = torch.device("cpu")
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
