Skip to content

Commit

Permalink
fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
mazabou committed Oct 23, 2024
1 parent c83fa46 commit dfd42c2
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Changelog

All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added multitask decoder taxonomy. ([#36](https://github.com/neuro-galaxy/torch_brain/pull/36))

### Changed
- Update workflow to use ubuntu-latest instances from github actions. ([#36](httpps://github.com/neuro-galaxy/torch_brain/pull/36))

### Fixed
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
packages=find_packages() + find_namespace_packages(include=["hydra_plugins.*"]),
include_package_data=True,
install_requires=[
"temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata",
"brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets",
"temporaldata @ git+https://github.com/neuro-galaxy/temporaldata@main#egg=temporaldata-0.1.1",
"brainsets @ git+https://github.com/neuro-galaxy/brainsets@main#egg=brainsets-0.1.0",
"torch==2.2.0",
"einops~=0.6.0",
# "setuptools~=60.2.0",
Expand Down
3 changes: 2 additions & 1 deletion torch_brain/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .perceiver_rotary import PerceiverRotary

# readout layers
from .loss import compute_loss_or_metric
from .loss import compute_loss_or_metric, OutputType
from .multitask_readout import (
MultitaskReadout,
prepare_for_multitask_readout,
Decoder,
)
9 changes: 8 additions & 1 deletion torch_brain/nn/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

from torchmetrics import R2Score

from brainsets.taxonomy import OutputType
from brainsets.taxonomy import StringIntEnum


class OutputType(StringIntEnum):
CONTINUOUS = 0
BINARY = 1
MULTILABEL = 2
MULTINOMIAL = 3


def compute_loss_or_metric(
Expand Down
10 changes: 1 addition & 9 deletions torch_brain/nn/multitask_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,13 @@

from brainsets.taxonomy import StringIntEnum, Task
from torch_brain.data.collate import collate, chain, track_batch
from torch_brain.nn import compute_loss_or_metric

from torch_brain.nn import compute_loss_or_metric, OutputType

from typing import Dict, List, Tuple, Optional, Union, Any

from pydantic.dataclasses import dataclass


class OutputType(StringIntEnum):
CONTINUOUS = 0
BINARY = 1
MULTILABEL = 2
MULTINOMIAL = 3


class Decoder(StringIntEnum):
NA = 0
# Classic BCI outputs.
Expand Down

0 comments on commit dfd42c2

Please sign in to comment.