Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Critical bugfixes to make examples/poyo/train.py work #21

Merged
merged 9 commits into from
Nov 13, 2024
7 changes: 6 additions & 1 deletion examples/poyo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def training_step(self, batch, batch_idx):
spec.loss_fn, spec.type, output, target, weights
)

loss = loss + taskwise_loss[readout_id] * len(target)
# count the number of sequences in the batch that have the current task
num_sequences_with_current_task = torch.any(
batch["output_decoder_index"] == MODALITIY_REGISTRY[readout_id].id,
dim=1,
).sum()
loss = loss + taskwise_loss[readout_id] * num_sequences_with_current_task

batch_size = batch["input_unit_index"].shape[0]
# TODO change batch_size when POYOPlusEfficient is used
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"wandb~=0.15",
"torchtyping~=0.1",
"pydantic~=2.0",
"tabulate",
mazabou marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require={
"dev": [
Expand Down
22 changes: 11 additions & 11 deletions tests/test_stitcher_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from temporaldata import Interval

from torch_brain.data.sampler import DistributedStitchingFixedWindowBatchSampler
from torch_brain.data.sampler import DistributedStitchingFixedWindowSampler


def test_distributed_stitching_sampler():
Expand All @@ -18,39 +18,39 @@ def test_distributed_stitching_sampler():
num_replicas = 2

# Test rank 0
sampler0 = DistributedStitchingFixedWindowBatchSampler(
sampler0 = DistributedStitchingFixedWindowSampler(
interval_dict=interval_dict,
window_length=window_length,
step=step,
batch_size=batch_size,
num_replicas=num_replicas,
rank=0,
)
samples0 = list(sampler0)

# Test rank 1
sampler1 = DistributedStitchingFixedWindowBatchSampler(
sampler1 = DistributedStitchingFixedWindowSampler(
interval_dict=interval_dict,
window_length=window_length,
step=step,
batch_size=batch_size,
num_replicas=num_replicas,
rank=1,
)
samples1 = list(sampler1)

# Get all batches from both samplers
batches0 = list(sampler0)
batches1 = list(sampler1)
batches0 = [
samples0[i : i + batch_size] for i in range(0, len(samples0), batch_size)
]
batches1 = [
samples1[i : i + batch_size] for i in range(0, len(samples1), batch_size)
]

# Basic checks
assert len(batches0) > 0
assert len(batches1) > 0

# Check batch sizes
for batch in batches0[:-1]: # All except last batch
assert len(batch) == batch_size
for batch in batches1[:-1]:
assert len(batch) == batch_size

# Check window properties
for batch in batches0:
for window in batch:
Expand Down
13 changes: 4 additions & 9 deletions torch_brain/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ def __iter__(self):
return iter(indices)


class DistributedStitchingFixedWindowBatchSampler(torch.utils.data.BatchSampler):
r"""A batch sampler designed specifically for evaluation that enables sliding window
class DistributedStitchingFixedWindowSampler(torch.utils.data.DistributedSampler):
r"""A sampler designed specifically for evaluation that enables sliding window
inference with prediction stitching across distributed processes.

This sampler divides sequences into overlapping windows and distributes them across
Expand Down Expand Up @@ -463,15 +463,10 @@ def _generate_indices(self) -> List[DatasetIndex]:
return indices, sequence_index

def __iter__(self):
# Create batches from our pre-computed indices
batches = [
self.indices[i : i + self.batch_size]
for i in range(0, len(self.indices), self.batch_size)
]
return iter(batches)
return iter(self.indices)

def __len__(self) -> int:
return (self.num_samples + self.batch_size - 1) // self.batch_size
return self.num_samples

def set_epoch(self, epoch: int) -> None:
"""Set the epoch number. Not strictly necessary for sequential sampler
Expand Down
16 changes: 8 additions & 8 deletions torch_brain/models/poyo_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
from torchtyping import TensorType

from torch_brain.data import chain, pad, track_mask
from torch_brain.data import chain, pad8, track_mask8
from torch_brain.nn import (
Embedding,
FeedForward,
Expand Down Expand Up @@ -335,17 +335,17 @@ def __call__(self, data):

batch = {
# input sequence
"input_unit_index": pad(spike_unit_index),
"input_timestamps": pad(spike_timestamps),
"input_token_type": pad(spike_token_type_index),
"input_mask": track_mask(spike_unit_index),
"input_unit_index": pad8(spike_unit_index),
"input_timestamps": pad8(spike_timestamps),
"input_token_type": pad8(spike_token_type_index),
"input_mask": track_mask8(spike_unit_index),
# latent sequence
"latent_index": latent_index,
"latent_timestamps": latent_timestamps,
# output sequence
"output_session_index": pad(session_index),
"output_timestamps": pad(output_timestamps),
"output_decoder_index": pad(output_task_index),
"output_session_index": pad8(session_index),
"output_timestamps": pad8(output_timestamps),
"output_decoder_index": pad8(output_task_index),
# ground truth targets
"target_values": chain(output_values, allow_missing_keys=True),
"target_weights": chain(output_weights, allow_missing_keys=True),
Expand Down
22 changes: 15 additions & 7 deletions torch_brain/utils/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from torch_brain.data import Dataset, collate
from torch_brain.data.sampler import (
DistributedStitchingFixedWindowBatchSampler,
DistributedStitchingFixedWindowSampler,
RandomFixedWindowSampler,
)
from torch_brain.models import POYOPlusTokenizer
Expand Down Expand Up @@ -130,18 +130,22 @@ def train_dataloader(self):
return train_loader

def val_dataloader(self):
val_sampler = DistributedStitchingFixedWindowBatchSampler(
batch_size = self.cfg.eval_batch_size or self.cfg.batch_size

val_sampler = DistributedStitchingFixedWindowSampler(
interval_dict=self.val_dataset.get_sampling_intervals(),
window_length=self.sequence_length,
step=self.sequence_length / 2,
batch_size=self.cfg.eval_batch_size or self.cfg.batch_size,
mazabou marked this conversation as resolved.
Show resolved Hide resolved
batch_size=batch_size,
num_replicas=self.trainer.world_size,
rank=self.trainer.global_rank,
)

val_loader = DataLoader(
self.val_dataset,
batch_sampler=val_sampler,
sampler=val_sampler,
shuffle=False,
batch_size=batch_size,
collate_fn=collate,
num_workers=0,
drop_last=False,
Expand All @@ -153,18 +157,22 @@ def val_dataloader(self):
return val_loader

def test_dataloader(self):
test_sampler = DistributedStitchingFixedWindowBatchSampler(
batch_size = self.cfg.eval_batch_size or self.cfg.batch_size

test_sampler = DistributedStitchingFixedWindowSampler(
interval_dict=self.test_dataset.get_sampling_intervals(),
window_length=self.sequence_length,
step=self.sequence_length / 2,
batch_size=self.cfg.eval_batch_size or self.cfg.batch_size,
batch_size=batch_size,
num_replicas=self.trainer.world_size,
rank=self.trainer.global_rank,
)

test_loader = DataLoader(
self.test_dataset,
batch_sampler=test_sampler,
sampler=test_sampler,
shuffle=False,
batch_size=batch_size,
collate_fn=collate,
num_workers=0,
)
Expand Down
13 changes: 11 additions & 2 deletions torch_brain/utils/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)

# update the cache with the predictions and targets
for readout_index in torch.unique(batch["output_decoder_index"]):
if readout_index.item() == 0:
# skip the padding token
continue

mask = batch["output_decoder_index"] == readout_index
readout_id = torch_brain.get_modality_by_id(readout_index.item())

Expand Down Expand Up @@ -189,16 +193,21 @@ def on_validation_epoch_end(self, trainer, pl_module, prefix="val"):
for task_name in self.metrics[recording_id].keys():
for metric_name in self.metrics[recording_id][task_name].keys():
metrics[f"{recording_id}/{task_name}/{metric_name}/{prefix}"] = (
self.metrics[recording_id][task_name][metric_name].compute()
self.metrics[recording_id][task_name][metric_name]
.to(pl_module.device)
.compute()
)
self.metrics[recording_id][task_name][metric_name].reset()
self.metrics[recording_id][task_name][metric_name].to("cpu")

# log the metrics
self.log_dict(metrics)
logging.info(f"Logged {len(metrics)} {prefix} metrics.")

# compute the average metric
metrics[f"average_{prefix}_metric"] = np.array(list(metrics.values())).mean()
metrics[f"average_{prefix}_metric"] = torch.tensor(
list(metrics.values())
).mean()

metrics_data = []
for metric_name, metric_value in metrics.items():
Expand Down
Loading