add multitask decoder taxonomy #8

Merged
merged 15 commits on Oct 23, 2024
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -11,7 +11,7 @@ permissions:

jobs:
  linting:
-     runs-on: self-hosted
+     runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
9 changes: 8 additions & 1 deletion .github/workflows/testing.yml
@@ -11,7 +11,7 @@ permissions:

jobs:
  testing:
-     runs-on: self-hosted
+     runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.9
@@ -22,6 +22,13 @@ jobs:
        run: |
          python -m venv venv
          source venv/bin/activate
+     - name: Install brainsets
+       env:
+         GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
+       run: |
+         git config --global url."https://${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/"
+         python -m pip install --no-user --upgrade pip
+         pip install --no-user "brainsets @ git+https://github.com/neuro-galaxy/brainsets.git@main#egg=brainsets-0.1.0[all]"
      - name: Install dependencies
        run: |
          source venv/bin/activate
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,13 @@
# Changelog

All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added multitask decoder taxonomy. ([#8](https://github.com/neuro-galaxy/torch_brain/pull/8))

### Changed
- Updated workflows to use ubuntu-latest runners from GitHub Actions. ([#8](https://github.com/neuro-galaxy/torch_brain/pull/8))

### Fixed
4 changes: 2 additions & 2 deletions setup.py
@@ -11,8 +11,8 @@
    packages=find_packages() + find_namespace_packages(include=["hydra_plugins.*"]),
    include_package_data=True,
    install_requires=[
-       "temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata",
-       "brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets",
+       "temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata-0.1.1",
+       "brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets-0.1.0",
        "torch==2.2.0",
        "einops~=0.6.0",
        # "setuptools~=60.2.0",
3 changes: 2 additions & 1 deletion torch_brain/nn/__init__.py
@@ -8,8 +8,9 @@
from .perceiver_rotary import PerceiverRotary

# readout layers
- from .loss import compute_loss_or_metric
+ from .loss import compute_loss_or_metric, OutputType
from .multitask_readout import (
    MultitaskReadout,
    prepare_for_multitask_readout,
+   Decoder,
)
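With this change the new symbols are re-exported from torch_brain.nn. A quick import check, shown as a sketch rather than as part of the diff:

# Sketch: the names re-exported by torch_brain/nn/__init__.py after this PR.
from torch_brain.nn import MultitaskReadout, Decoder, OutputType, compute_loss_or_metric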
9 changes: 8 additions & 1 deletion torch_brain/nn/loss.py
@@ -3,7 +3,14 @@

from torchmetrics import R2Score

- from brainsets.taxonomy import OutputType
+ from brainsets.taxonomy import StringIntEnum


+ class OutputType(StringIntEnum):
+     CONTINUOUS = 0
+     BINARY = 1
+     MULTILABEL = 2
+     MULTINOMIAL = 3


def compute_loss_or_metric(
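For context, OutputType appears to be the switch that loss computation keys off. Below is a minimal sketch of that kind of dispatch; the helper name pick_loss and its signature are hypothetical and are not the actual compute_loss_or_metric API.

import torch
import torch.nn.functional as F

from torch_brain.nn import OutputType

def pick_loss(output_type: OutputType, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the "mse"/"bce" loss_fn strings used in the
    # decoder registry: continuous targets use MSE, categorical targets use cross-entropy.
    if output_type == OutputType.CONTINUOUS:
        return F.mse_loss(prediction, target)
    if output_type in (OutputType.BINARY, OutputType.MULTILABEL):
        return F.binary_cross_entropy_with_logits(prediction, target)
    if output_type == OutputType.MULTINOMIAL:
        return F.cross_entropy(prediction, target)
    raise ValueError(f"unsupported output type: {output_type}")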
264 changes: 262 additions & 2 deletions torch_brain/nn/multitask_readout.py
@@ -5,9 +5,269 @@
import torch.nn as nn
from torchtyping import TensorType

- from brainsets.taxonomy import DecoderSpec, Decoder, Task
+ from brainsets.taxonomy import StringIntEnum, Task
from torch_brain.data.collate import collate, chain, track_batch
- from torch_brain.nn import compute_loss_or_metric
+ from torch_brain.nn import compute_loss_or_metric, OutputType

from typing import Dict, List, Tuple, Optional, Union, Any

from pydantic.dataclasses import dataclass


class Decoder(StringIntEnum):
    NA = 0
    # Classic BCI outputs.
    ARMVELOCITY2D = 1
    CURSORPOSITION2D = 2
    EYE2D = 3
    FINGER3D = 4

    # Shenoy handwriting style outputs.
    WRITING_CHARACTER = 5
    WRITING_LINE = 6

    DISCRETE_TRIAL_ONSET_OFFSET = 7
    CONTINUOUS_TRIAL_ONSET_OFFSET = 8

    CURSORVELOCITY2D = 9

    # Allen data
    DRIFTING_GRATINGS_ORIENTATION = 13
    DRIFTING_GRATINGS_TEMPORAL_FREQUENCY = 23
    STATIC_GRATINGS_ORIENTATION = 17
    STATIC_GRATINGS_SPATIAL_FREQUENCY = 18
    STATIC_GRATINGS_PHASE = 19

    RUNNING_SPEED = 24
    PUPIL_SIZE_2D = 25
    GAZE_POS_2D = 26
    GABOR_ORIENTATION = 21
    GABOR_POS_2D = 27
    NATURAL_SCENES = 28
    NATURAL_MOVIE_ONE_FRAME = 30
    NATURAL_MOVIE_TWO_FRAME = 31
    NATURAL_MOVIE_THREE_FRAME = 32
    LOCALLY_SPARSE_NOISE_FRAME = 33

    # Openscope calcium
    UNEXPECTED_OR_NOT = 20
    PUPIL_MOVEMENT_REGRESSION = 22
    PUPIL_LOCATION = 34

    # speech
    SPEAKING_CVSYLLABLE = 14
    SPEAKING_CONSONANT = 15
    SPEAKING_VOWEL = 16


@dataclass
class DecoderSpec:
    dim: int
    target_dim: int
    type: OutputType
    loss_fn: str
    timestamp_key: str
    value_key: str
    # Optional fields
    task_key: Optional[str] = None
    subtask_key: Optional[str] = None
    target_dtype: str = "float32"  # stored as a string because torch.dtype is not serializable

decoder_registry = {
str(Decoder.ARMVELOCITY2D): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="behavior.timestamps",
value_key="behavior.hand_vel",
subtask_key="behavior.subtask_index",
loss_fn="mse",
),
str(Decoder.CURSORVELOCITY2D): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="cursor.timestamps",
value_key="cursor.vel",
subtask_key="cursor.subtask_index",
loss_fn="mse",
),
str(Decoder.CURSORPOSITION2D): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="cursor.timestamps",
value_key="cursor.pos",
subtask_key="cursor.subtask_index",
loss_fn="mse",
),
# str(Decoder.WRITING_CHARACTER): DecoderSpec(
# dim=len(Character),
# target_dim=1,
# target_dtype="long",
# type=OutputType.MULTINOMIAL,
# timestamp_key="stimuli_segments.timestamps",
# value_key="stimuli_segments.letters",
# loss_fn="bce",
# ),
# str(Decoder.WRITING_LINE): DecoderSpec(
# dim=len(Line),
# target_dim=1,
# target_dtype="long",
# type=OutputType.MULTINOMIAL,
# timestamp_key="stimuli_segments.timestamps",
# value_key="stimuli_segments.letters",
# loss_fn="bce",
# ),
str(Decoder.DRIFTING_GRATINGS_ORIENTATION): DecoderSpec(
dim=8,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="drifting_gratings.timestamps",
value_key="drifting_gratings.orientation_id",
loss_fn="bce",
),
str(Decoder.DRIFTING_GRATINGS_TEMPORAL_FREQUENCY): DecoderSpec(
dim=5, # [1,2,4,8,15]
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="drifting_gratings.timestamps",
value_key="drifting_gratings.temporal_frequency_id",
loss_fn="bce",
),
str(Decoder.NATURAL_MOVIE_ONE_FRAME): DecoderSpec(
dim=900,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="natural_movie_one.timestamps",
value_key="natural_movie_one.frame",
loss_fn="bce",
),
str(Decoder.NATURAL_MOVIE_TWO_FRAME): DecoderSpec(
dim=900,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="natural_movie_two.timestamps",
value_key="natural_movie_two.frame",
loss_fn="bce",
),
str(Decoder.NATURAL_MOVIE_THREE_FRAME): DecoderSpec(
dim=3600,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="natural_movie_three.timestamps",
value_key="natural_movie_three.frame",
loss_fn="bce",
),
str(Decoder.LOCALLY_SPARSE_NOISE_FRAME): DecoderSpec(
dim=8000,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="locally_sparse_noise.timestamps",
value_key="locally_sparse_noise.frame",
loss_fn="bce",
),
str(Decoder.STATIC_GRATINGS_ORIENTATION): DecoderSpec(
dim=6,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="static_gratings.timestamps",
value_key="static_gratings.orientation_id",
loss_fn="bce",
),
str(Decoder.STATIC_GRATINGS_SPATIAL_FREQUENCY): DecoderSpec(
dim=5,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="static_gratings.timestamps",
value_key="static_gratings.spatial_frequency_id",
loss_fn="bce",
),
str(Decoder.STATIC_GRATINGS_PHASE): DecoderSpec(
dim=5,
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="static_gratings.timestamps",
value_key="static_gratings.phase_id",
loss_fn="bce",
),
# str(Decoder.SPEAKING_CVSYLLABLE): DecoderSpec(
# dim=len(CVSyllable), # empty label is included
# target_dim=1,
# target_dtype="long",
# type=OutputType.MULTINOMIAL,
# timestamp_key="speech.timestamps",
# value_key="speech.consonant_vowel_syllables",
# loss_fn="bce",
# ),
str(Decoder.NATURAL_SCENES): DecoderSpec(
dim=119, # image classes [0,...,118]
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="natural_scenes.timestamps",
value_key="natural_scenes.frame",
loss_fn="bce",
),
str(Decoder.GABOR_ORIENTATION): DecoderSpec(
dim=4, # [0, 1, 2, 3]
target_dim=1,
target_dtype="long",
type=OutputType.MULTINOMIAL,
timestamp_key="gabors.timestamps",
value_key="gabors.gabors_orientation",
loss_fn="bce",
),
str(Decoder.GABOR_POS_2D): DecoderSpec( # 9x9 grid modeled as (x, y) coordinates
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="gabors.timestamps",
value_key="gabors.pos_2d",
loss_fn="mse",
),
str(Decoder.RUNNING_SPEED): DecoderSpec(
dim=1,
target_dim=1,
type=OutputType.CONTINUOUS,
timestamp_key="running.timestamps",
value_key="running.running_speed",
loss_fn="mse",
),
str(Decoder.GAZE_POS_2D): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="gaze.timestamps",
value_key="gaze.pos_2d",
loss_fn="mse",
),
str(Decoder.PUPIL_LOCATION): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="pupil.timestamps",
value_key="pupil.location",
loss_fn="mse",
),
str(Decoder.PUPIL_SIZE_2D): DecoderSpec(
dim=2,
target_dim=2,
type=OutputType.CONTINUOUS,
timestamp_key="pupil.timestamps",
value_key="pupil.size_2d",
loss_fn="mse",
),
}
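A brief usage sketch for the registry above (illustrative only, not part of the diff): decoder_registry, Decoder, DecoderSpec, and OutputType are the names defined in this file, and the field values below match the CURSORVELOCITY2D entry.

# Sketch: resolve a decoder spec from the registry and read its fields.
spec = decoder_registry[str(Decoder.CURSORVELOCITY2D)]
assert spec.dim == 2 and spec.loss_fn == "mse"
assert spec.type == OutputType.CONTINUOUS
assert spec.timestamp_key == "cursor.timestamps" and spec.value_key == "cursor.vel"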


class MultitaskReadout(nn.Module):