Skip to content

Commit

Permalink
fix dataset test
Browse files (browse the repository at this point in the history)
  • Loading branch information
jgrss committed Apr 25, 2024
1 parent 89162c5 commit ca38957
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 181 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Setup Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -45,7 +45,7 @@ jobs:
- name: Install PyTorch
run: |
TORCH_CPU="https://download.pytorch.org/whl/cpu"
TORCH_VERSION="2.1.0"
TORCH_VERSION="2.2.2"
pip install --upgrade --no-cache-dir setuptools>=0.59.5
pip install torch==${TORCH_VERSION} torchvision==0.16.0 torchaudio==${TORCH_VERSION} --extra-index-url $TORCH_CPU
- name: Install cultionet
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![python](https://img.shields.io/badge/Python-3.8%20%7C%203.9-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org)
[![python](https://img.shields.io/badge/Python-3.9%20%7C%203.10-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org)
[![](https://img.shields.io/github/v/release/jgrss/cultionet?display_name=release)](https://github.com/jgrss/cultionet/releases)
[![](https://github.com/jgrss/cultionet/actions/workflows/ci.yml/badge.svg)](https://github.com/jgrss/cultionet/actions?query=workflow%3ACI)

Expand Down Expand Up @@ -265,9 +265,9 @@ pyenv virtualenv 3.8.12 venv.cnet
pyenv activate venv.cnet
(venv.cnet) pip install -U pip setuptools wheel numpy cython
(venv.cnet) pip install gdal==$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') --no-binary=gdal
(venv.cnet) pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
(venv.cnet) TORCH_VERSION=2.2.2
(venv.cnet) pip install torch==${TORCH_VERSION} torchvision==0.16.0 torchaudio==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cpu
(venv.cnet) TORCH_VERSION=$(python -c "import torch;print(torch.__version__)")
(venv.cnet) pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
(venv.cnet) pip install cultionet@git+https://github.com/jgrss/cultionet.git
```

Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers =
Topic :: Scientific :: Time series
Topic :: Scientific :: Segmentation
Programming Language :: Cython
Programming Language :: Python :: 3.8 :: 3.9 :: 3.10
Programming Language :: Python :: 3.9 :: 3.10

[options]
package_dir=
Expand All @@ -29,7 +29,7 @@ setup_requires =
Cython>=0.29.0,<3.0.0
numpy>=1.22.0
python_requires =
>=3.8.0,<3.11.0
>=3.9.0,<3.11.0
install_requires =
attrs>=21.0
frozendict>=2.2.0
Expand Down
23 changes: 13 additions & 10 deletions src/cultionet/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.random_seed = random_seed
self.augment_prob = augment_prob

seed_everything(self.random_seed, workers=True)
seed_everything(self.random_seed)
self.rng = np.random.default_rng(self.random_seed)

self.augmentations_ = [
Expand All @@ -88,7 +88,7 @@ def __init__(

def get_data_list(self):
"""Gets the list of data files."""
data_list_ = list(Path(self.processed_dir).glob(self.pattern))
data_list_ = sorted(list(Path(self.processed_dir).glob(self.pattern)))

if not data_list_:
logger.exception(
Expand All @@ -110,18 +110,20 @@ def cleanup(self):
for fn in self.data_list_:
fn.unlink()

def shuffle_items(self, data: T.Optional[list] = None):
self.data_list_ = []

def shuffle(self, data: T.Optional[list] = None):
"""Applies a random in-place shuffle to the data list."""
if data is not None:
self.rng.shuffle(data)
else:
self.rng.shuffle(self.data_list_)

@property
def num_time_features(self):
def num_time(self) -> int:
"""Get the number of time features."""
data = self[0]
return int(data.ntime)
return int(data.num_time)

def to_frame(self) -> gpd.GeoDataFrame:
"""Converts the Dataset to a GeoDataFrame."""
Expand Down Expand Up @@ -322,7 +324,7 @@ def split_train_val_by_partition(
self.get_spatial_partitions(spatial_partitions=spatial_partitions)
train_indices = []
val_indices = []
self.shuffle_items()
self.shuffle()
# self.spatial_partitions is a GeoDataFrame with Point geometry
for row in tqdm(
self.spatial_partitions.itertuples(),
Expand Down Expand Up @@ -364,10 +366,9 @@ def split_train_val(
Returns:
train dataset, validation dataset
"""
id_column = "common_id"
self.shuffle_items()

if spatial_overlap_allowed:
self.shuffle()
n_train = int(len(self) * (1.0 - val_frac))
train_ds = self[:n_train]
val_ds = self[n_train:]
Expand All @@ -394,7 +395,9 @@ def split_train_val(
# `qt.sample` random samples from the quad-tree in a
# spatially balanced manner. Thus, `df_val_sample` is
# a GeoDataFrame with `n_val` sites spatially balanced.
df_val_sample = qt.sample(n=n_val)
df_val_sample = qt.sample(
n=n_val, random_state=self.random_seed
)

# Since we only took one sample from each coordinate,
# we need to find all of the .pt files that share
Expand All @@ -406,7 +409,7 @@ def split_train_val(
# Randomly sample a percentage for validation
df_val_ids = self.dataset_df.sample(
frac=val_frac, random_state=self.random_seed
).to_frame(name=id_column)
).to_frame(name=self.grid_id_column)
# Get all ids for validation samples
val_mask = self.dataset_df[self.grid_id_column].isin(
df_val_ids[self.grid_id_column]
Expand Down
13 changes: 10 additions & 3 deletions src/cultionet/data/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as T

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler

Expand All @@ -22,6 +23,7 @@ def __init__(
sampler: T.Optional[Sampler] = None,
pin_memory: bool = False,
persistent_workers: bool = False,
generator: T.Optional[torch.Generator] = None,
):
super().__init__()

Expand All @@ -35,6 +37,7 @@ def __init__(
self.sampler = sampler
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.generator = generator

def train_dataloader(self):
"""Returns a data loader for train data."""
Expand All @@ -47,34 +50,38 @@ def train_dataloader(self):
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
generator=self.generator,
)

def val_dataloader(self):
"""Returns a data loader for validation data."""
return DataLoader(
self.val_ds,
batch_size=self.batch_size,
shuffle=self.shuffle,
shuffle=False,
num_workers=self.num_workers,
collate_fn=collate_fn,
generator=self.generator,
)

def test_dataloader(self):
"""Returns a data loader for test data."""
return DataLoader(
self.test_ds,
batch_size=self.batch_size,
shuffle=self.shuffle,
shuffle=False,
num_workers=self.num_workers,
collate_fn=collate_fn,
generator=self.generator,
)

def predict_dataloader(self):
"""Returns a data loader for predict data."""
return DataLoader(
self.predict_ds,
batch_size=self.batch_size,
shuffle=self.shuffle,
shuffle=False,
num_workers=self.num_workers,
collate_fn=collate_fn,
generator=self.generator,
)
46 changes: 0 additions & 46 deletions tests/_test_dataset.py

This file was deleted.

71 changes: 0 additions & 71 deletions tests/_test_reshape.py

This file was deleted.

42 changes: 0 additions & 42 deletions tests/_test_temporal_attention.py

This file was deleted.

2 changes: 0 additions & 2 deletions tests/test_cultionet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import tempfile

import torch

from cultionet.data.modules import EdgeDataModule
from cultionet.enums import ModelTypes, ResBlockTypes
from cultionet.models.cultio import CultioNet
Expand Down
Loading

0 comments on commit ca38957

Please sign in to comment.