Skip to content

Commit

Permalink
Merge branch 'main' into keyboard_interrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
bqth29 authored Dec 15, 2024
2 parents 681bad5 + 68c2541 commit a06076f
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sympy==1.12
tomli==2.0.1
torch==2.0.1
torch==2.2.0
tqdm==4.66.3
typing_extensions==4.7.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
dependencies = [
"numpy<2",
"sympy",
"torch>=2.0.1",
"torch>=2.2.0",
"tqdm",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from numpy import minimum
from tqdm import tqdm
from tqdm.auto import tqdm

from .environment import ENVIRONMENT
from .simulated_bifurcation_engine import SimulatedBifurcationEngine
Expand Down
6 changes: 4 additions & 2 deletions src/simulated_bifurcation/optimizer/stop_window.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Tuple, Union

import torch
from tqdm import tqdm
from tqdm.auto import tqdm


class StopWindow:
Expand Down Expand Up @@ -66,7 +66,9 @@ def __init_tensor(self, dtype: torch.dtype) -> torch.Tensor:
return torch.zeros(self.n_agents, device=self.device, dtype=dtype)

def __init_energies(self) -> None:
self.energies = torch.tensor([float("inf") for _ in range(self.n_agents)])
self.energies = torch.tensor(
[float("inf") for _ in range(self.n_agents)], device=self.device
)

def __init_tensors(self) -> None:
self.stability = self.__init_tensor(torch.int16)
Expand Down

0 comments on commit a06076f

Please sign in to comment.