From 1bc9c738975174e921c79f61c42237334e429d02 Mon Sep 17 00:00:00 2001
From: Vinam Arora
Date: Mon, 11 Nov 2024 15:03:24 -0500
Subject: [PATCH 1/9] Changes BatchSampler to DistributedSampler for val/test dataloaders

Lightning doesn't play well with BatchSamplers in validation/testing dataloaders
---
 torch_brain/data/sampler.py      | 11 +++--------
 torch_brain/utils/datamodules.py | 16 ++++++++++++----
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/torch_brain/data/sampler.py b/torch_brain/data/sampler.py
index f4f0b51..ae6e089 100644
--- a/torch_brain/data/sampler.py
+++ b/torch_brain/data/sampler.py
@@ -331,7 +331,7 @@ def __iter__(self):
         return iter(indices)
 
 
-class DistributedStitchingFixedWindowBatchSampler(torch.utils.data.BatchSampler):
+class DistributedStitchingFixedWindowBatchSampler(torch.utils.data.DistributedSampler):
     r"""A batch sampler designed specifically for evaluation that enables sliding window
     inference with prediction stitching across distributed processes.
 
@@ -463,15 +463,10 @@ def _generate_indices(self) -> List[DatasetIndex]:
         return indices, sequence_index
 
     def __iter__(self):
-        # Create batches from our pre-computed indices
-        batches = [
-            self.indices[i : i + self.batch_size]
-            for i in range(0, len(self.indices), self.batch_size)
-        ]
-        return iter(batches)
+        return iter(self.indices)
 
     def __len__(self) -> int:
-        return (self.num_samples + self.batch_size - 1) // self.batch_size
+        return self.num_samples
 
     def set_epoch(self, epoch: int) -> None:
         """Set the epoch number. Not strictly necessary for sequential sampler
diff --git a/torch_brain/utils/datamodules.py b/torch_brain/utils/datamodules.py
index 02a9afa..c4e690c 100644
--- a/torch_brain/utils/datamodules.py
+++ b/torch_brain/utils/datamodules.py
@@ -130,18 +130,22 @@ def train_dataloader(self):
         return train_loader
 
     def val_dataloader(self):
+        batch_size = self.cfg.eval_batch_size or self.cfg.batch_size
+
         val_sampler = DistributedStitchingFixedWindowBatchSampler(
             interval_dict=self.val_dataset.get_sampling_intervals(),
             window_length=self.sequence_length,
             step=self.sequence_length / 2,
-            batch_size=self.cfg.eval_batch_size or self.cfg.batch_size,
+            batch_size=batch_size,
             num_replicas=self.trainer.world_size,
             rank=self.trainer.global_rank,
         )
 
         val_loader = DataLoader(
             self.val_dataset,
-            batch_sampler=val_sampler,
+            sampler=val_sampler,
+            shuffle=False,
+            batch_size=batch_size,
             collate_fn=collate,
             num_workers=0,
             drop_last=False,
@@ -153,18 +157,22 @@ def val_dataloader(self):
         return val_loader
 
     def test_dataloader(self):
+        batch_size = self.cfg.eval_batch_size or self.cfg.batch_size
+
         test_sampler = DistributedStitchingFixedWindowBatchSampler(
             interval_dict=self.test_dataset.get_sampling_intervals(),
             window_length=self.sequence_length,
             step=self.sequence_length / 2,
-            batch_size=self.cfg.eval_batch_size or self.cfg.batch_size,
+            batch_size=batch_size,
             num_replicas=self.trainer.world_size,
             rank=self.trainer.global_rank,
         )
 
         test_loader = DataLoader(
             self.test_dataset,
-            batch_sampler=test_sampler,
+            sampler=test_sampler,
+            shuffle=False,
+            batch_size=batch_size,
             collate_fn=collate,
             num_workers=0,
         )
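[Note on PATCH 1/9] The switch from batch_sampler to sampler changes who does the batching: a DistributedSampler-style sampler yields one index at a time and the DataLoader groups them into batches of batch_size, which is the layout Lightning's evaluation loops expect. A minimal sketch of the two call patterns in plain PyTorch (toy dataset and index lists, not the torch_brain API):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10).float())

    # batch_sampler: yields lists of indices, so batch_size/shuffle must stay unset
    index_batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    loader_a = DataLoader(dataset, batch_sampler=index_batches)

    # sampler: yields single indices, and the DataLoader builds the batches itself
    flat_indices = range(10)
    loader_b = DataLoader(dataset, sampler=flat_indices, batch_size=3, shuffle=False)

    assert len(list(loader_a)) == len(list(loader_b)) == 4

Both loaders produce the same four batches; the second form is what the updated val/test dataloaders rely on.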
From bf14232e39c682e53b3ff97141bb4958ee6a3b0d Mon Sep 17 00:00:00 2001
From: Vinam Arora
Date: Mon, 11 Nov 2024 15:18:21 -0500
Subject: [PATCH 2/9] Fixes bugs in stitcher.py that show up in distributed setting

---
 setup.py                      | 1 +
 torch_brain/utils/stitcher.py | 9 +++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 625fc8f..4f1adde 100644
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,7 @@
         "wandb~=0.15",
         "torchtyping~=0.1",
         "pydantic~=2.0",
+        "tabulate",
     ],
     extras_require={
         "dev": [
diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py
index e86322c..9ee4481 100644
--- a/torch_brain/utils/stitcher.py
+++ b/torch_brain/utils/stitcher.py
@@ -189,16 +189,21 @@ def on_validation_epoch_end(self, trainer, pl_module, prefix="val"):
             for task_name in self.metrics[recording_id].keys():
                 for metric_name in self.metrics[recording_id][task_name].keys():
                     metrics[f"{recording_id}/{task_name}/{metric_name}/{prefix}"] = (
-                        self.metrics[recording_id][task_name][metric_name].compute()
+                        self.metrics[recording_id][task_name][metric_name]
+                        .to(pl_module.device)
+                        .compute()
                     )
                     self.metrics[recording_id][task_name][metric_name].reset()
+                    self.metrics[recording_id][task_name][metric_name].to("cpu")
 
         # log the metrics
         self.log_dict(metrics)
         logging.info(f"Logged {len(metrics)} {prefix} metrics.")
 
         # compute the average metric
-        metrics[f"average_{prefix}_metric"] = np.array(list(metrics.values())).mean()
+        metrics[f"average_{prefix}_metric"] = torch.tensor(
+            list(metrics.values())
+        ).mean()
 
         metrics_data = []
         for metric_name, metric_value in metrics.items():

From 820ad78e842766706f87f8bc47a793cc15d0ed85 Mon Sep 17 00:00:00 2001
From: Vinam Arora
Date: Mon, 11 Nov 2024 15:45:03 -0500
Subject: [PATCH 3/9] Rename DistributedStitchingFixedWindowBatchSampler ->
 DistributedStitchingFixedWindowSampler

---
 tests/test_stitcher_sampler.py   | 6 +++---
 torch_brain/data/sampler.py      | 4 ++--
 torch_brain/utils/datamodules.py | 6 +++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/test_stitcher_sampler.py b/tests/test_stitcher_sampler.py
index 9e2ccaa..3c3302e 100644
--- a/tests/test_stitcher_sampler.py
+++ b/tests/test_stitcher_sampler.py
@@ -2,7 +2,7 @@
 import numpy as np
 from temporaldata import Interval
 
-from torch_brain.data.sampler import DistributedStitchingFixedWindowBatchSampler
+from torch_brain.data.sampler import DistributedStitchingFixedWindowSampler
 
 
 def test_distributed_stitching_sampler():
@@ -18,7 +18,7 @@ def test_distributed_stitching_sampler():
     num_replicas = 2
 
     # Test rank 0
-    sampler0 = DistributedStitchingFixedWindowBatchSampler(
+    sampler0 = DistributedStitchingFixedWindowSampler(
         interval_dict=interval_dict,
         window_length=window_length,
         step=step,
@@ -28,7 +28,7 @@ def test_distributed_stitching_sampler():
     )
 
     # Test rank 1
-    sampler1 = DistributedStitchingFixedWindowBatchSampler(
+    sampler1 = DistributedStitchingFixedWindowSampler(
         interval_dict=interval_dict,
         window_length=window_length,
         step=step,
diff --git a/torch_brain/data/sampler.py b/torch_brain/data/sampler.py
index ae6e089..7d71b38 100644
--- a/torch_brain/data/sampler.py
+++ b/torch_brain/data/sampler.py
@@ -331,8 +331,8 @@ def __iter__(self):
         return iter(indices)
 
 
-class DistributedStitchingFixedWindowBatchSampler(torch.utils.data.DistributedSampler):
-    r"""A batch sampler designed specifically for evaluation that enables sliding window
+class DistributedStitchingFixedWindowSampler(torch.utils.data.DistributedSampler):
+    r"""A sampler designed specifically for evaluation that enables sliding window
     inference with prediction stitching across distributed processes.
 
     This sampler divides sequences into overlapping windows and distributes them across
diff --git a/torch_brain/utils/datamodules.py b/torch_brain/utils/datamodules.py
index c4e690c..e34a731 100644
--- a/torch_brain/utils/datamodules.py
+++ b/torch_brain/utils/datamodules.py
@@ -10,7 +10,7 @@
 from torch_brain.data import Dataset, collate
 from torch_brain.data.sampler import (
-    DistributedStitchingFixedWindowBatchSampler,
+    DistributedStitchingFixedWindowSampler,
     RandomFixedWindowSampler,
 )
 from torch_brain.models import POYOPlusTokenizer
@@ -132,7 +132,7 @@ def train_dataloader(self):
     def val_dataloader(self):
         batch_size = self.cfg.eval_batch_size or self.cfg.batch_size
 
-        val_sampler = DistributedStitchingFixedWindowBatchSampler(
+        val_sampler = DistributedStitchingFixedWindowSampler(
             interval_dict=self.val_dataset.get_sampling_intervals(),
             window_length=self.sequence_length,
             step=self.sequence_length / 2,
@@ -159,7 +159,7 @@ def test_dataloader(self):
         batch_size = self.cfg.eval_batch_size or self.cfg.batch_size
 
-        test_sampler = DistributedStitchingFixedWindowBatchSampler(
+        test_sampler = DistributedStitchingFixedWindowSampler(
             interval_dict=self.test_dataset.get_sampling_intervals(),
             window_length=self.sequence_length,
             step=self.sequence_length / 2,
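[Note on PATCH 2/9] torchmetrics Metric objects are nn.Modules, so .to(...) moves their state and returns the metric itself, which is why the chained .to(pl_module.device).compute() call works. Moving the states onto the module's device before compute() matters in the distributed setting, where the sync step needs tensors on the backend's device (e.g. CUDA for NCCL), and parking them back on CPU afterwards keeps GPU memory free between epochs. A rough single-process sketch of the pattern, assuming standard torchmetrics behaviour (BinaryAccuracy is only a stand-in metric):

    import torch
    from torchmetrics.classification import BinaryAccuracy

    metric = BinaryAccuracy()  # states accumulate on CPU during the epoch
    metric.update(torch.tensor([0.9, 0.2, 0.8]), torch.tensor([1, 0, 1]))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    value = metric.to(device).compute()  # move the states, then sync and compute

    metric.reset()
    metric.to("cpu")  # park the states until the next epoch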
From 5896cc2782f6e4becc677271563529f4bc7b51d0 Mon Sep 17 00:00:00 2001
From: Vinam Arora
Date: Mon, 11 Nov 2024 15:49:45 -0500
Subject: [PATCH 4/9] Adapt test for Batch->Distributed sampler conversion

---
 tests/test_stitcher_sampler.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/test_stitcher_sampler.py b/tests/test_stitcher_sampler.py
index 3c3302e..1037a03 100644
--- a/tests/test_stitcher_sampler.py
+++ b/tests/test_stitcher_sampler.py
@@ -26,6 +26,7 @@ def test_distributed_stitching_sampler():
         num_replicas=num_replicas,
         rank=0,
     )
+    samples0 = list(sampler0)
 
     # Test rank 1
     sampler1 = DistributedStitchingFixedWindowSampler(
@@ -36,21 +37,20 @@ def test_distributed_stitching_sampler():
         num_replicas=num_replicas,
         rank=1,
     )
+    samples1 = list(sampler1)
 
     # Get all batches from both samplers
-    batches0 = list(sampler0)
-    batches1 = list(sampler1)
+    batches0 = [
+        samples0[i : i + batch_size] for i in range(0, len(samples0), batch_size)
+    ]
+    batches1 = [
+        samples1[i : i + batch_size] for i in range(0, len(samples1), batch_size)
+    ]
 
     # Basic checks
     assert len(batches0) > 0
     assert len(batches1) > 0
 
-    # Check batch sizes
-    for batch in batches0[:-1]:  # All except last batch
-        assert len(batch) == batch_size
-    for batch in batches1[:-1]:
-        assert len(batch) == batch_size
-
     # Check window properties
     for batch in batches0:
         for window in batch:

From 90b5f3b5d016eef1f6e08d8cbf0f84205b9122a0 Mon Sep 17 00:00:00 2001
From: Mehdi Azabou
Date: Tue, 12 Nov 2024 12:06:41 -0500
Subject: [PATCH 5/9] skip padding token in readout

---
 torch_brain/utils/stitcher.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py
index 9ee4481..3532a80 100644
--- a/torch_brain/utils/stitcher.py
+++ b/torch_brain/utils/stitcher.py
@@ -117,6 +117,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
         # update the cache with the predictions and targets
         for readout_index in torch.unique(batch["output_decoder_index"]):
+            if readout_index.item() == 0:
+                # skip the padding token
+                continue
+
             mask = batch["output_decoder_index"] == readout_index
             readout_id = torch_brain.get_modality_by_id(readout_index.item())
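[Note on PATCH 5/9] The guard on readout_index == 0 is needed because the padded positions of output_decoder_index carry index 0 (the "padding token" in the patch comment), so torch.unique(...) returns the padding id alongside the real readout ids. A toy illustration with a made-up index tensor:

    import torch

    # two padded sequences; 0 marks padding, 1 and 2 are real readout ids
    output_decoder_index = torch.tensor([[1, 1, 2, 0, 0],
                                         [2, 2, 2, 2, 0]])

    for readout_index in torch.unique(output_decoder_index):
        if readout_index.item() == 0:
            continue  # padding token, nothing to stitch or score
        mask = output_decoder_index == readout_index
        print(int(readout_index), int(mask.sum()))  # 1 -> 2 tokens, 2 -> 5 tokens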
From 0e248c3c90aaf49d5a47e768710c335e835ebdaa Mon Sep 17 00:00:00 2001
From: Mehdi Azabou
Date: Tue, 12 Nov 2024 12:09:35 -0500
Subject: [PATCH 6/9] replace pad with pad8 to use mem_efficient attn

---
 torch_brain/models/poyo_plus.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/torch_brain/models/poyo_plus.py b/torch_brain/models/poyo_plus.py
index bf35d0c..9c6c84f 100644
--- a/torch_brain/models/poyo_plus.py
+++ b/torch_brain/models/poyo_plus.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 from torchtyping import TensorType
 
-from torch_brain.data import chain, pad, track_mask
+from torch_brain.data import chain, pad8, track_mask
 from torch_brain.nn import (
     Embedding,
     FeedForward,
@@ -335,17 +335,17 @@ def __call__(self, data):
         batch = {
             # input sequence
-            "input_unit_index": pad(spike_unit_index),
-            "input_timestamps": pad(spike_timestamps),
-            "input_token_type": pad(spike_token_type_index),
+            "input_unit_index": pad8(spike_unit_index),
+            "input_timestamps": pad8(spike_timestamps),
+            "input_token_type": pad8(spike_token_type_index),
             "input_mask": track_mask(spike_unit_index),
             # latent sequence
             "latent_index": latent_index,
             "latent_timestamps": latent_timestamps,
             # output sequence
-            "output_session_index": pad(session_index),
-            "output_timestamps": pad(output_timestamps),
-            "output_decoder_index": pad(output_task_index),
+            "output_session_index": pad8(session_index),
+            "output_timestamps": pad8(output_timestamps),
+            "output_decoder_index": pad8(output_task_index),
             # ground truth targets
             "target_values": chain(output_values, allow_missing_keys=True),
             "target_weights": chain(output_weights, allow_missing_keys=True),

From fe3b22df02052e2e798e9c31752433a74d3658c6 Mon Sep 17 00:00:00 2001
From: Mehdi Azabou
Date: Tue, 12 Nov 2024 12:12:05 -0500
Subject: [PATCH 7/9] replace pad with pad8 to use mem_efficient attn

---
 torch_brain/models/poyo_plus.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_brain/models/poyo_plus.py b/torch_brain/models/poyo_plus.py
index 9c6c84f..3a871e6 100644
--- a/torch_brain/models/poyo_plus.py
+++ b/torch_brain/models/poyo_plus.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 from torchtyping import TensorType
 
-from torch_brain.data import chain, pad8, track_mask
+from torch_brain.data import chain, pad8, track_mask8
 from torch_brain.nn import (
     Embedding,
     FeedForward,
@@ -338,7 +338,7 @@ def __call__(self, data):
             "input_unit_index": pad8(spike_unit_index),
             "input_timestamps": pad8(spike_timestamps),
             "input_token_type": pad8(spike_token_type_index),
-            "input_mask": track_mask(spike_unit_index),
+            "input_mask": track_mask8(spike_unit_index),
             # latent sequence
             "latent_index": latent_index,
             "latent_timestamps": latent_timestamps,
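[Note on PATCHes 6-7/9] pad8 and track_mask8 are torch_brain collation helpers; going by the names and the commit message, they pad each variable-length sequence (and its validity mask) up to the next multiple of 8, the alignment that memory-efficient/fused attention kernels prefer. A rough sketch of what padding to a multiple of 8 looks like, as a hypothetical helper rather than the torch_brain implementation:

    import torch
    import torch.nn.functional as F

    def pad_to_multiple_of_8(x: torch.Tensor, value: float = 0.0) -> torch.Tensor:
        # round the last dimension up to the next multiple of 8
        length = x.shape[-1]
        target = -(-length // 8) * 8  # ceil(length / 8) * 8
        return F.pad(x, (0, target - length), value=value)

    seq = torch.arange(10)                     # length 10 -> padded to 16
    padded = pad_to_multiple_of_8(seq)
    mask = torch.zeros(padded.shape[-1], dtype=torch.bool)
    mask[:10] = True                           # marks the real (unpadded) tokens

    assert padded.shape[-1] == 16 and int(mask.sum()) == 10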
From dde60f5bf3392929be101265bc50f938d5771d0c Mon Sep 17 00:00:00 2001
From: Mehdi Azabou
Date: Tue, 12 Nov 2024 12:36:57 -0500
Subject: [PATCH 8/9] fix loss scaling error

---
 examples/poyo/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/poyo/train.py b/examples/poyo/train.py
index a526e06..17fda4c 100644
--- a/examples/poyo/train.py
+++ b/examples/poyo/train.py
@@ -89,7 +89,11 @@ def training_step(self, batch, batch_idx):
                 spec.loss_fn, spec.type, output, target, weights
             )
 
-            loss = loss + taskwise_loss[readout_id] * len(target)
+            # count the number of sequences in the batch that have the current task
+            num_sequences_with_current_task = torch.any(
+                batch["output_decoder_index"] == MODALITIY_REGISTRY[readout_id], dim=1
+            ).sum()
+            loss = loss + taskwise_loss[readout_id] * num_sequences_with_current_task
 
         batch_size = batch["input_unit_index"].shape[0]
         # TODO change batch_size when POYOPlusEfficient is used

From c8828a6474b6d1cc1c943595fd383a3fe901a36f Mon Sep 17 00:00:00 2001
From: Mehdi Azabou
Date: Tue, 12 Nov 2024 12:47:02 -0500
Subject: [PATCH 9/9] fix loss scaling error

---
 examples/poyo/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/poyo/train.py b/examples/poyo/train.py
index 17fda4c..143569c 100644
--- a/examples/poyo/train.py
+++ b/examples/poyo/train.py
@@ -91,7 +91,8 @@ def training_step(self, batch, batch_idx):
             # count the number of sequences in the batch that have the current task
             num_sequences_with_current_task = torch.any(
-                batch["output_decoder_index"] == MODALITIY_REGISTRY[readout_id], dim=1
+                batch["output_decoder_index"] == MODALITIY_REGISTRY[readout_id].id,
+                dim=1,
             ).sum()
             loss = loss + taskwise_loss[readout_id] * num_sequences_with_current_task
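[Note on PATCHes 8-9/9] The loss-scaling fix changes the per-task weight from the number of target tokens (len(target)) to the number of sequences in the batch that actually contain the task, computed with torch.any over the token dimension; PATCH 9 additionally compares against the registry entry's .id field. A worked toy example of the counting step (task id 2 is a made-up stand-in for MODALITIY_REGISTRY[readout_id].id):

    import torch

    # 3 sequences, 4 output tokens each; 0 is padding
    output_decoder_index = torch.tensor([[2, 2, 0, 0],
                                         [3, 3, 3, 0],
                                         [2, 3, 3, 0]])

    task_id = 2
    num_sequences_with_current_task = torch.any(
        output_decoder_index == task_id, dim=1
    ).sum()

    assert int(num_sequences_with_current_task) == 2  # sequences 0 and 2 contain task 2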