Skip to content

Commit

Permalink
Fixes bugs in stitcher.py that show up in distributed setting
Browse files Browse the repository at this point in the history
  • Loading branch information
vinamarora8 committed Nov 11, 2024
1 parent 1bc9c73 commit bf14232
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"wandb~=0.15",
"torchtyping~=0.1",
"pydantic~=2.0",
"tabulate",
],
extras_require={
"dev": [
Expand Down
9 changes: 7 additions & 2 deletions torch_brain/utils/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,21 @@ def on_validation_epoch_end(self, trainer, pl_module, prefix="val"):
for task_name in self.metrics[recording_id].keys():
for metric_name in self.metrics[recording_id][task_name].keys():
metrics[f"{recording_id}/{task_name}/{metric_name}/{prefix}"] = (
self.metrics[recording_id][task_name][metric_name].compute()
self.metrics[recording_id][task_name][metric_name]
.to(pl_module.device)
.compute()
)
self.metrics[recording_id][task_name][metric_name].reset()
self.metrics[recording_id][task_name][metric_name].to("cpu")

# log the metrics
self.log_dict(metrics)
logging.info(f"Logged {len(metrics)} {prefix} metrics.")

# compute the average metric
metrics[f"average_{prefix}_metric"] = np.array(list(metrics.values())).mean()
metrics[f"average_{prefix}_metric"] = torch.tensor(
list(metrics.values())
).mean()

metrics_data = []
for metric_name, metric_value in metrics.items():
Expand Down

0 comments on commit bf14232

Please sign in to comment.