diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index d65546c..f8b779b 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -11,7 +11,7 @@ permissions: jobs: linting: - runs-on: self-hosted + runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2b8cfe9..a71e3ea 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -11,7 +11,7 @@ permissions: jobs: testing: - runs-on: self-hosted + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python 3.9 @@ -22,6 +22,13 @@ jobs: run: | python -m venv venv source venv/bin/activate + - name: Install brainsets + env: + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} + run: | + git config --global url."https://${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" + python -m pip install --no-user --upgrade pip + pip install --no-user "brainsets @ git+https://github.com/neuro-galaxy/brainsets.git@main#egg=brainsets-0.1.0[all]" - name: Install dependencies run: | source venv/bin/activate diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..7ca035a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,13 @@ +# Changelog + +All notable changes to this project will be documented in this file. +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] +### Added +- Added multitask decoder taxonomy. ([#8](https://github.com/neuro-galaxy/torch_brain/pull/8)) + +### Changed +- Update workflow to use ubuntu-latest instances from github actions. ([#8](httpps://github.com/neuro-galaxy/torch_brain/pull/8)) + +### Fixed diff --git a/setup.py b/setup.py index 5dd1d6a..490c2b9 100644 --- a/setup.py +++ b/setup.py @@ -11,8 +11,8 @@ packages=find_packages() + find_namespace_packages(include=["hydra_plugins.*"]), include_package_data=True, install_requires=[ - "temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata", - "brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets", + "temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata-0.1.1", + "brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets-0.1.0", "torch==2.2.0", "einops~=0.6.0", # "setuptools~=60.2.0", diff --git a/torch_brain/nn/__init__.py b/torch_brain/nn/__init__.py index f8741b1..878b99a 100644 --- a/torch_brain/nn/__init__.py +++ b/torch_brain/nn/__init__.py @@ -8,8 +8,9 @@ from .perceiver_rotary import PerceiverRotary # readout layers -from .loss import compute_loss_or_metric +from .loss import compute_loss_or_metric, OutputType from .multitask_readout import ( MultitaskReadout, prepare_for_multitask_readout, + Decoder, ) diff --git a/torch_brain/nn/loss.py b/torch_brain/nn/loss.py index 6ce02dc..17695a3 100644 --- a/torch_brain/nn/loss.py +++ b/torch_brain/nn/loss.py @@ -3,7 +3,14 @@ from torchmetrics import R2Score -from brainsets.taxonomy import OutputType +from brainsets.taxonomy import StringIntEnum + + +class OutputType(StringIntEnum): + CONTINUOUS = 0 + BINARY = 1 + MULTILABEL = 2 + MULTINOMIAL = 3 def compute_loss_or_metric( diff --git a/torch_brain/nn/multitask_readout.py b/torch_brain/nn/multitask_readout.py index 2325a32..6a0b03e 100644 --- a/torch_brain/nn/multitask_readout.py +++ b/torch_brain/nn/multitask_readout.py @@ -5,9 +5,269 @@ import torch.nn as nn from torchtyping import TensorType -from brainsets.taxonomy import DecoderSpec, Decoder, Task +from brainsets.taxonomy import StringIntEnum, Task from torch_brain.data.collate import collate, chain, track_batch -from torch_brain.nn import compute_loss_or_metric +from torch_brain.nn import compute_loss_or_metric, OutputType + +from typing import Dict, List, Tuple, Optional, Union, Any + +from pydantic.dataclasses import dataclass + + +class Decoder(StringIntEnum): + NA = 0 + # Classic BCI outputs. + ARMVELOCITY2D = 1 + CURSORPOSITION2D = 2 + EYE2D = 3 + FINGER3D = 4 + + # Shenoy handwriting style outputs. + WRITING_CHARACTER = 5 + WRITING_LINE = 6 + + DISCRETE_TRIAL_ONSET_OFFSET = 7 + CONTINUOUS_TRIAL_ONSET_OFFSET = 8 + + CURSORVELOCITY2D = 9 + + # Allen data + DRIFTING_GRATINGS_ORIENTATION = 13 + DRIFTING_GRATINGS_TEMPORAL_FREQUENCY = 23 + STATIC_GRATINGS_ORIENTATION = 17 + STATIC_GRATINGS_SPATIAL_FREQUENCY = 18 + STATIC_GRATINGS_PHASE = 19 + + RUNNING_SPEED = 24 + PUPIL_SIZE_2D = 25 + GAZE_POS_2D = 26 + GABOR_ORIENTATION = 21 # + GABOR_POS_2D = 27 + NATURAL_SCENES = 28 + NATURAL_MOVIE_ONE_FRAME = 30 + NATURAL_MOVIE_TWO_FRAME = 31 + NATURAL_MOVIE_THREE_FRAME = 32 + LOCALLY_SPARSE_NOISE_FRAME = 33 + + # Openscope calcium + UNEXPECTED_OR_NOT = 20 # + PUPIL_MOVEMENT_REGRESSION = 22 + PUPIL_LOCATION = 34 + + # speech + SPEAKING_CVSYLLABLE = 14 + SPEAKING_CONSONANT = 15 + SPEAKING_VOWEL = 16 + + +@dataclass +class DecoderSpec: + dim: int + type: OutputType + loss_fn: str + timestamp_key: str + value_key: str + # Optional fields + task_key: Optional[str] = None + subtask_key: Optional[str] = None + # target_dtype: str = "float32" # torch.dtype is not serializable. + + +decoder_registry = { + str(Decoder.ARMVELOCITY2D): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="behavior.timestamps", + value_key="behavior.hand_vel", + subtask_key="behavior.subtask_index", + loss_fn="mse", + ), + str(Decoder.CURSORVELOCITY2D): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="cursor.timestamps", + value_key="cursor.vel", + subtask_key="cursor.subtask_index", + loss_fn="mse", + ), + str(Decoder.CURSORPOSITION2D): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="cursor.timestamps", + value_key="cursor.pos", + subtask_key="cursor.subtask_index", + loss_fn="mse", + ), + # str(Decoder.WRITING_CHARACTER): DecoderSpec( + # dim=len(Character), + # target_dim=1, + # target_dtype="long", + # type=OutputType.MULTINOMIAL, + # timestamp_key="stimuli_segments.timestamps", + # value_key="stimuli_segments.letters", + # loss_fn="bce", + # ), + # str(Decoder.WRITING_LINE): DecoderSpec( + # dim=len(Line), + # target_dim=1, + # target_dtype="long", + # type=OutputType.MULTINOMIAL, + # timestamp_key="stimuli_segments.timestamps", + # value_key="stimuli_segments.letters", + # loss_fn="bce", + # ), + str(Decoder.DRIFTING_GRATINGS_ORIENTATION): DecoderSpec( + dim=8, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="drifting_gratings.timestamps", + value_key="drifting_gratings.orientation_id", + loss_fn="bce", + ), + str(Decoder.DRIFTING_GRATINGS_TEMPORAL_FREQUENCY): DecoderSpec( + dim=5, # [1,2,4,8,15] + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="drifting_gratings.timestamps", + value_key="drifting_gratings.temporal_frequency_id", + loss_fn="bce", + ), + str(Decoder.NATURAL_MOVIE_ONE_FRAME): DecoderSpec( + dim=900, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="natural_movie_one.timestamps", + value_key="natural_movie_one.frame", + loss_fn="bce", + ), + str(Decoder.NATURAL_MOVIE_TWO_FRAME): DecoderSpec( + dim=900, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="natural_movie_two.timestamps", + value_key="natural_movie_two.frame", + loss_fn="bce", + ), + str(Decoder.NATURAL_MOVIE_THREE_FRAME): DecoderSpec( + dim=3600, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="natural_movie_three.timestamps", + value_key="natural_movie_three.frame", + loss_fn="bce", + ), + str(Decoder.LOCALLY_SPARSE_NOISE_FRAME): DecoderSpec( + dim=8000, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="locally_sparse_noise.timestamps", + value_key="locally_sparse_noise.frame", + loss_fn="bce", + ), + str(Decoder.STATIC_GRATINGS_ORIENTATION): DecoderSpec( + dim=6, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="static_gratings.timestamps", + value_key="static_gratings.orientation_id", + loss_fn="bce", + ), + str(Decoder.STATIC_GRATINGS_SPATIAL_FREQUENCY): DecoderSpec( + dim=5, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="static_gratings.timestamps", + value_key="static_gratings.spatial_frequency_id", + loss_fn="bce", + ), + str(Decoder.STATIC_GRATINGS_PHASE): DecoderSpec( + dim=5, + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="static_gratings.timestamps", + value_key="static_gratings.phase_id", + loss_fn="bce", + ), + # str(Decoder.SPEAKING_CVSYLLABLE): DecoderSpec( + # dim=len(CVSyllable), # empty label is included + # target_dim=1, + # target_dtype="long", + # type=OutputType.MULTINOMIAL, + # timestamp_key="speech.timestamps", + # value_key="speech.consonant_vowel_syllables", + # loss_fn="bce", + # ), + str(Decoder.NATURAL_SCENES): DecoderSpec( + dim=119, # image classes [0,...,118] + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="natural_scenes.timestamps", + value_key="natural_scenes.frame", + loss_fn="bce", + ), + str(Decoder.GABOR_ORIENTATION): DecoderSpec( + dim=4, # [0, 1, 2, 3] + target_dim=1, + target_dtype="long", + type=OutputType.MULTINOMIAL, + timestamp_key="gabors.timestamps", + value_key="gabors.gabors_orientation", + loss_fn="bce", + ), + str(Decoder.GABOR_POS_2D): DecoderSpec( # 9x9 grid modeled as (x, y) coordinates + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="gabors.timestamps", + value_key="gabors.pos_2d", + loss_fn="mse", + ), + str(Decoder.RUNNING_SPEED): DecoderSpec( + dim=1, + target_dim=1, + type=OutputType.CONTINUOUS, + timestamp_key="running.timestamps", + value_key="running.running_speed", + loss_fn="mse", + ), + str(Decoder.GAZE_POS_2D): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="gaze.timestamps", + value_key="gaze.pos_2d", + loss_fn="mse", + ), + str(Decoder.PUPIL_LOCATION): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="pupil.timestamps", + value_key="pupil.location", + loss_fn="mse", + ), + str(Decoder.PUPIL_SIZE_2D): DecoderSpec( + dim=2, + target_dim=2, + type=OutputType.CONTINUOUS, + timestamp_key="pupil.timestamps", + value_key="pupil.size_2d", + loss_fn="mse", + ), +} class MultitaskReadout(nn.Module):