RFC: Alternate Stitcher framework #33

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions examples/poyo/configs/base.yaml
@@ -6,6 +6,7 @@ log_dir: ./logs
sequence_length: 1.0 # in seconds
latent_step: 0.125 # in seconds
readout_modality_name: cursor_velocity_2d
readout_metric_name: r2

epochs: 1000
eval_epochs: 1 # frequency for doing validation
64 changes: 50 additions & 14 deletions examples/poyo/train.py
@@ -1,9 +1,10 @@
import logging
from typing import Callable, Dict
from typing import Callable, Dict, List
import copy

import hydra
import lightning as L
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
@@ -20,7 +21,7 @@
from torch_brain.models.poyo import POYOTokenizer, poyo_mp
from torch_brain.utils import callbacks as tbrain_callbacks
from torch_brain.utils import seed_everything
from torch_brain.utils.stitcher import DecodingStitchEvaluator
from torch_brain.utils.stitcher import Stitcher
from torch_brain.data import Dataset, collate
from torch_brain.nn import compute_loss_or_metric
from torch_brain.data.sampler import (
@@ -39,12 +40,14 @@ def __init__(
cfg: DictConfig,
model: nn.Module,
modality_spec: ModalitySpec,
session_ids: List[str],
):
super().__init__()

self.cfg = cfg
self.model = model
self.modality_spec = modality_spec
self.stitchers = {k: Stitcher() for k in session_ids}
self.save_hyperparameters(OmegaConf.to_container(cfg))

def configure_optimizers(self):
@@ -123,18 +126,56 @@ def validation_step(self, batch, batch_idx):
# forward pass
output_values = self.model(**batch)

# add removed elements back to batch
batch["target_values"] = target_values
batch["absolute_start"] = absolute_starts
batch["session_id"] = session_ids
batch["output_subtask_index"] = output_subtask_index
batch["output_mask"] = output_mask
for i in range(len(output_values)):
mask = output_mask[i]
self.stitchers[session_ids[i]].update(
timestamps=batch["output_timestamps"][i][mask] + absolute_starts[i],
preds=output_values[i][mask],
target=target_values[i][mask],
)

def on_validation_epoch_end(self, prefix="val"):
metrics = {}
for sess_id, stitcher in self.stitchers.items():
stitched_preds, stitched_target = stitcher.compute()
stitcher.reset()
metrics[sess_id] = compute_loss_or_metric(
loss_or_metric=self.cfg.readout_metric_name,
output_type=self.modality_spec.type,
output=stitched_preds,
target=stitched_target,
)

metrics[f"avg_{prefix}_metric"] = torch.tensor(list(metrics.values())).mean()

# logging
self.log_dict(metrics)
metrics_df = pd.DataFrame(
[{"metric": k, "value": v.item()} for k, v in metrics.items()]
)
if self.trainer.is_global_zero:
from rich import print as rprint

rprint(metrics_df)

return output_values
for logger in self.trainer.loggers:
if isinstance(logger, L.pytorch.loggers.TensorBoardLogger):
logger.experiment.add_text(
f"{prefix}_metrics", metrics_df.to_markdown()
)
if isinstance(logger, L.pytorch.loggers.WandbLogger):
import wandb

logger.experiment.log(
{f"{prefix}_metrics": wandb.Table(dataframe=metrics_df)}
)

def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)

def on_test_epoch_end(self):
return self.on_validation_epoch_end(prefix="test")


class DataModule(L.LightningDataModule):
def __init__(self, cfg: DictConfig, tokenizer: Callable[[Data], Dict]):
@@ -311,15 +352,10 @@ def main(cfg: DictConfig):
cfg=cfg,
model=model,
modality_spec=readout_spec,
)

stitch_evaluator = DecodingStitchEvaluator(
session_ids=data_module.get_session_ids(),
modality_spec=readout_spec,
)

callbacks = [
stitch_evaluator,
ModelSummary(max_depth=2), # Displays the number of parameters in the model.
ModelCheckpoint(
save_last=True,
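To make the wiring change concrete, here is a hedged sketch (not part of the diff) of the relevant fragment of main(): the session ids that previously parameterized DecodingStitchEvaluator are now handed to the LightningModule itself, and no stitch evaluator callback is registered. TrainWrapper is a placeholder name, since the wrapper class name is not visible in the hunks above.

# Fragment of main(); `TrainWrapper` is a hypothetical name for the LightningModule
# wrapper whose __init__ is shown above.
wrapper = TrainWrapper(
    cfg=cfg,
    model=model,
    modality_spec=readout_spec,
    session_ids=data_module.get_session_ids(),  # previously passed to DecodingStitchEvaluator
)

# The callback list keeps ModelSummary / ModelCheckpoint but drops the stitch
# evaluator: stitching and metric computation now happen inside the module.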
3 changes: 2 additions & 1 deletion torch_brain/nn/loss.py
@@ -1,3 +1,4 @@
from typing import Optional
import torch
import torch.nn.functional as F
from torchmetrics import R2Score
@@ -10,7 +11,7 @@ def compute_loss_or_metric(
output_type: DataType,
output: torch.Tensor,
target: torch.Tensor,
weights: torch.Tensor,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Helper function to compute various losses or metrics for a given output type.

30 changes: 30 additions & 0 deletions torch_brain/utils/stitcher.py
@@ -7,8 +7,10 @@
import pandas as pd
from rich import print as rprint
import torch
from torch import Tensor
import lightning as L
import torchmetrics
from torchmetrics.utilities import dim_zero_cat
import wandb

import torch_brain
@@ -387,3 +389,31 @@ def on_test_batch_end(self, *args, **kwargs):

def on_test_epoch_end(self, *args, **kwargs):
self.on_validation_epoch_end(*args, **kwargs, prefix="test")


class Stitcher(torchmetrics.Metric):
r"""A simple prediction stitcher. Use this if your model output has associated
timestamps and your sampling strategy involves overlapping time windows, requiring
    stitching to coalesce the predictions and targets before computing the evaluation
metric.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("timestamps", default=[], dist_reduce_fx="cat")
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, timestamps: Tensor, preds: Tensor, target: Tensor) -> None:
self.timestamps.append(timestamps)
self.preds.append(preds)
self.target.append(target)

def compute(self):
timestamps = dim_zero_cat(self.timestamps)
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)

stitched_preds = stitch(timestamps, preds)
stitched_target = stitch(timestamps, target)
return stitched_preds, stitched_target
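
As a usage illustration (not part of the diff), the intended lifecycle mirrors validation_step and on_validation_epoch_end above: keep only the masked outputs of each window, shift their timestamps to absolute time, accumulate per session, then stitch once at epoch end. This is a hedged sketch with toy tensors; the real inputs come from the model and tokenizer, and the stitched tensors would then be scored with compute_loss_or_metric as shown in train.py.

import torch
from torch_brain.utils.stitcher import Stitcher

# Toy stand-ins for the outputs of a single decoded window
output_timestamps = torch.tensor([0.000, 0.125, 0.250])   # window-relative, in seconds
output_values = torch.randn(3, 2)                         # e.g. cursor_velocity_2d predictions
target_values = torch.randn(3, 2)
mask = torch.tensor([True, True, False])                  # output_mask for this sample
absolute_start = 12.0                                      # window start within the session

stitcher = Stitcher()

# called once per sample during validation
stitcher.update(
    timestamps=output_timestamps[mask] + absolute_start,  # shift to absolute session time
    preds=output_values[mask],
    target=target_values[mask],
)

# called at epoch end: overlapping windows are coalesced by the stitch() helper,
# and the stitched tensors are then scored (see on_validation_epoch_end above)
stitched_preds, stitched_target = stitcher.compute()
stitcher.reset()                                           # clear state for the next epoch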