
Commit

fix device allocation for the batch
limiteinductive committed Jan 12, 2024
1 parent 07f4c97 commit 9e36fdb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/training_utils/test_trainer.py
@@ -29,10 +29,7 @@ def __len__(self):
         return 20
 
     def __getitem__(self, _: int) -> MockBatch:
-        return MockBatch(
-            inputs=torch.randn(1, 10, device=torch.device("cuda:0")),
-            targets=torch.randn(1, 10, device=torch.device("cuda:0")),
-        )
+        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
 
     def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
         return MockBatch(
@@ -66,13 +63,15 @@ def load_models(self) -> dict[str, fl.Module]:
         return {"mock_model": self.mock_model}
 
     def compute_loss(self, batch: MockBatch) -> Tensor:
-        inputs, targets = batch.inputs, batch.targets
+        inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
         outputs = self.mock_model(inputs)
         return norm(outputs - targets)
 
 
 @pytest.fixture
 def mock_config(test_device: torch.device) -> MockConfig:
+    if not test_device.type == "cuda":
+        pytest.skip("Skipping test because test_device is not CUDA")
     config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
     config.training.gpu_index = test_device.index
     return config
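The first hunk is the core of the fix: the dataset stops hard-coding "cuda:0" and compute_loss moves the batch onto the trainer's device instead. Below is a minimal, self-contained sketch of that pattern; SketchTrainer and its constructor are hypothetical stand-ins for the test's MockTrainer, while MockBatch, __getitem__, collate_fn and compute_loss mirror what the diff shows.

```python
# Minimal sketch of the device-allocation pattern this commit moves to:
# the dataset builds CPU tensors, and the trainer moves each batch onto
# its own device inside compute_loss. SketchTrainer is a hypothetical
# stand-in used only for illustration.
from dataclasses import dataclass

import torch
from torch import Tensor, norm


@dataclass
class MockBatch:
    inputs: Tensor
    targets: Tensor


class MockDataset:
    def __len__(self) -> int:
        return 20

    def __getitem__(self, _: int) -> MockBatch:
        # No device argument: items stay on the CPU instead of hard-coding "cuda:0".
        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))

    def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
        return MockBatch(
            inputs=torch.cat([b.inputs for b in batch]),
            targets=torch.cat([b.targets for b in batch]),
        )


class SketchTrainer:
    def __init__(self, device: torch.device) -> None:
        self.device = device
        self.mock_model = torch.nn.Linear(10, 10).to(device)

    def compute_loss(self, batch: MockBatch) -> Tensor:
        # Device allocation happens here, per batch, on the trainer's
        # configured device rather than a device baked into the dataset.
        inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
        outputs = self.mock_model(inputs)
        return norm(outputs - targets)
```

With this split, the same dataset works for both CPU and GPU runs; only the trainer's configured device decides where the forward pass happens.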

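The second addition is the fixture-level skip. The sketch below shows how that behaves, assuming a hypothetical test_device fixture (the project's real fixture and MockConfig loading are not reproduced here): pytest.skip called inside a fixture skips every test that requests that fixture.

```python
# Sketch of the fixture-level skip added by this commit, assuming a
# hypothetical test_device fixture. pytest.skip() inside a fixture skips
# every test that requests it, so CUDA-only tests are skipped cleanly on
# CPU-only machines.
import pytest
import torch


@pytest.fixture
def test_device() -> torch.device:
    # Hypothetical stand-in for the project's test_device fixture, which is
    # normally selected via a CLI option or environment variable.
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def mock_config(test_device: torch.device) -> torch.device:
    if test_device.type != "cuda":
        pytest.skip("Skipping test because test_device is not CUDA")
    # The real fixture loads mock_config.toml and sets
    # config.training.gpu_index = test_device.index; no MockConfig object is
    # reproduced here, so the device itself stands in for it.
    return test_device


def test_trainer_runs_on_cuda(mock_config: torch.device) -> None:
    # Skipped automatically unless a CUDA device is available.
    assert mock_config.type == "cuda"
```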