Merge pull request #75 from opengeokube/dev
Update
jamesWalczak authored May 22, 2024
2 parents 528900f + 34ac0fb commit c97dc65
Showing 5 changed files with 109 additions and 22 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/build_mkdocs.yml
@@ -0,0 +1,34 @@
name: Deploy MkDocs site to GitHub Pages

on:
  workflow_dispatch:
  push:
    branches:
      - main

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout the repository
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install mkdocs mkdocs-material
      - name: Build MkDocs site
        run: mkdocs build

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@v3
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: .
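The workflow installs MkDocs with the Material theme, builds the site, and hands the result to peaceiris/actions-gh-pages with `publish_dir: .`. A quick way to sanity-check the docs before pushing to main is to mirror the install and build steps locally. The sketch below is a hypothetical helper, not part of this commit; it assumes mkdocs and mkdocs-material are already installed and that an mkdocs.yml sits in the working directory.

```python
# Hedged sketch: reproduce the workflow's build step locally.
# Assumes `pip install mkdocs mkdocs-material` has been run and that
# mkdocs.yml is present in the current directory.
import pathlib
import subprocess

# Same command as the "Build MkDocs site" step above.
subprocess.run(["mkdocs", "build"], check=True)

# mkdocs writes the rendered site to ./site by default.
site = pathlib.Path("site")
print("Built pages:", sorted(p.name for p in site.glob("*.html")))
```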
21 changes: 13 additions & 8 deletions kit4dl/cli/app.py
@@ -69,8 +69,8 @@ def _is_test_allowed(trainer: Trainer) -> bool:
 @_app.command()
 def init(
     name: Annotated[
-        str, typer.Option(help="The name of your new project")
-    ] = "new_kit4dl_project"
+        str, typer.Option(default=..., help="The name of your new project")
+    ] = "new_kit4dl_project",
 ) -> None:
     """Create a new Kit4DL project.
@@ -93,8 +93,8 @@ def init(
 @_app.command()
 def resume(
     checkpoint: Annotated[
-        str, typer.Option(help="Path to the checkpoint file")
-    ]
+        str, typer.Option(default=..., help="Path to the checkpoint file")
+    ] = "./checkpoint.ckpt",
 ):
     """Resume learning from the checkpoint.
@@ -109,7 +109,8 @@ def resume(
 @_app.command()
 def test(
     conf: Annotated[
-        str, typer.Option(help="Path to the configuration TOML file")
+        str,
+        typer.Option(default=..., help="Path to the configuration TOML file"),
     ] = get_default_conf_path(),
 ):
     """Test using the configuration file.
@@ -135,16 +136,20 @@ def test(
 @_app.command()
 def train(
     conf: Annotated[
-        str, typer.Option(help="Path to the configuration TOML file")
+        str,
+        typer.Option(default=..., help="Path to the configuration TOML file"),
     ] = get_default_conf_path(),
     skiptest: Annotated[
         bool,
-        typer.Option(help="If testing (using best weights) should be skipped"),
+        typer.Option(
+            default=...,
+            help="If testing (using best weights) should be skipped",
+        ),
     ] = False,
     overwrite: Annotated[
         Optional[str],
         typer.Option(
-            ...,
+            default=...,
             callback=parse_overwriting_options,
             help="Comma-separated key-value pairs (KEY=VALUE)",
         ),
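The CLI changes keep the Typer/Annotated style but pass the default sentinel explicitly as `typer.Option(default=...)` while the fallback value stays on the parameter itself, and `resume` gains a `./checkpoint.ckpt` default. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of an Annotated Typer option; the command `greet` and its `--name` option are invented for illustration and are not part of kit4dl.

```python
# Hypothetical example of the Annotated + typer.Option pattern used above;
# `greet` and `--name` are illustration-only names, not kit4dl commands.
from typing import Annotated

import typer

app = typer.Typer()


@app.command()
def greet(
    name: Annotated[
        str, typer.Option(help="Name to greet")
    ] = "world",
) -> None:
    """Print a greeting; --name falls back to its default when omitted."""
    typer.echo(f"Hello, {name}!")


if __name__ == "__main__":
    app()
```

Invoking the script with `--name Ada` overrides the default; omitting the option uses `"world"`.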
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -48,6 +48,12 @@ dependencies = [
     "toml; python_version<'3.11'"
 ]
 
+[project.urls]
+"Homepage" = "https://opengeokube.github.io/kit4dl/"
+"Documentation" = "https://opengeokube.github.io/kit4dl/"
+"Source Code" = "https://github.com/opengeokube/kit4dl"
+"Bug Tracker" = "https://github.com/opengeokube/kit4dl/issues"
+
 [tool.setuptools.dynamic]
 version = {attr = "kit4dl._version.__version__"}
 
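The new `[project.urls]` table only affects package metadata. As a hedged sketch (assuming a kit4dl build that includes this change is installed in the environment), the links can be read back at runtime through importlib.metadata:

```python
# Sketch: confirm the [project.urls] entries survive packaging.
# Assumes a kit4dl distribution built from this commit is installed.
from importlib.metadata import metadata

meta = metadata("kit4dl")
for entry in meta.get_all("Project-URL") or []:
    # Each entry looks like "Homepage, https://opengeokube.github.io/kit4dl/"
    print(entry)
```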
5 changes: 4 additions & 1 deletion tests/fixtures.py
@@ -118,10 +118,13 @@ def base_conf_txt():
 
 @pytest.fixture
 def base_conf_txt_full(base_conf_txt):
-    return base_conf_txt + """
+    return (
+        base_conf_txt
+        + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
 """
+    )
 
 
 @pytest.fixture
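The fixture change is formatting only: wrapping the concatenation in parentheses (black style) does not alter the TOML that gets parsed. A standalone sketch, using invented table names rather than the real kit4dl configuration, shows the two spellings are equivalent:

```python
# Standalone sketch (not the kit4dl fixtures): the parenthesised form is a
# pure formatting change - both spellings concatenate the same strings
# before TOML parsing.
import toml  # third-party "toml" package, as used by the tests

base = """
[model]
target = "dummy::Model"
"""

single_line = base + """
[metrics]
Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
"""

parenthesised = (
    base
    + """
[metrics]
Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
"""
)

assert toml.loads(single_line) == toml.loads(parenthesised)
```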
65 changes: 52 additions & 13 deletions tests/nn/test_confmodels.py
@@ -376,41 +376,53 @@ def test_define_val_train_test_predict_none_on_trainval_defined(self):
 class TestConf:
     @pytest.fixture
     def full_conf_dict(self, base_conf_txt):
-        entries = base_conf_txt + """
+        entries = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes=10}
 """
+        )
         yield toml.loads(entries)
 
     def test_conf_parse(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
 FBetaScore = {target = "torchmetrics::Recall", task = "multiclass", num_classes = 10, beta = 0.1}
 """
+        )
         _ = Conf(**toml.loads(load))
 
     @pytest.mark.skipif(
         sys.version_info < (3, 11), reason="test for Python < 3.11"
     )
     def test_fail_on_duplicated_key_name(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
 """
+        )
         with pytest.raises(ValueError, match="Duplicate keys!"):
             _ = Conf(**toml.loads(load))
 
     @pytest.mark.skipif(
         sys.version_info > (3, 10), reason="test for Python > 3.10"
     )
     def test_fail_on_duplicated_key_name(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes=10}
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes=10}
 """
+        )
         with pytest.raises(ValueError, match="Cannot overwrite a value.*"):
             _ = Conf(**toml.loads(load))
 
@@ -421,10 +433,13 @@ def test_fail_on_duplicated_key_name(self, base_conf_txt):
         )
     )
     def test_conf_custom_metric(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 test.dummy_module.CustomMetric = {}
 """
+        )
         conf = Conf(**toml.loads(load))
 
     @pytest.mark.skip(
@@ -434,18 +449,24 @@ def test_conf_custom_metric(self, base_conf_txt):
         )
     )
     def test_conf_custom_metric_fail_on_wrong_parentclass(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 teset.dummy_module.CustomMetricWrong = {}
 """
+        )
         with pytest.raises(ValidationError, match="duplicate"):
             _ = Conf(**toml.loads(load))
 
     def test_conf_fail_on_nonexisting_metric(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 MyMetric = {target = "torchmetrics::NonExistingMetric"}
 """
+        )
         with pytest.raises(
             AttributeError,
             match=(
@@ -455,10 +476,13 @@ def test_conf_fail_on_nonexisting_metric(self, base_conf_txt):
             _ = Conf(**toml.loads(load))
 
     def test_conf_fail_on_monitoring_undefined_metric(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision"}
 """
+        )
         load_dict = toml.loads(load)
         load_dict["training"]["checkpoint"]["monitor"] = {
             "metric": "Recall",
@@ -468,21 +492,27 @@ def test_conf_fail_on_monitoring_undefined_metric(self, base_conf_txt):
             _ = Conf(**load_dict)
 
     def test_conf_get_metric_obj_failed_on_missing_target(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {}
 """
+        )
         with pytest.raises(
             ValidationError,
             match=r".*`target` is not defined for some metric.*",
         ):
             conf = Conf(**toml.loads(load))
 
     def test_conf_get_metric_obj_failed_on_missing_task(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision"}
 """
+        )
         conf = Conf(**toml.loads(load))
         with pytest.raises(
             TypeError,
@@ -494,10 +524,13 @@ def test_conf_get_metric_obj_failed_on_missing_task(self, base_conf_txt):
             conf.metrics_obj
 
     def test_conf_get_metric_obj(self, base_conf_txt):
-        load = base_conf_txt + """
+        load = (
+            base_conf_txt
+            + """
 [metrics]
 Precision = {target = "torchmetrics::Precision", task = "multiclass", num_classes = 10}
 """
+        )
         conf = Conf(**toml.loads(load))
         metrics = conf.metrics_obj
         assert "precision" in metrics
@@ -547,10 +580,13 @@ def test_conf_schedulers_double_preconfigured_schedulers_classes(
         )
 
     def test_use_base_exp_name_for_metric_logging(self, base_conf_txt_full):
-        load = base_conf_txt_full + """
+        load = (
+            base_conf_txt_full
+            + """
 [logging]
 type = "csv"
 """
+        )
         conf = Conf(**toml.loads(load))
         assert "name" in conf.logging.arguments
         assert (
@@ -564,11 +600,14 @@ def test_use_base_exp_name_for_metric_logging(self, base_conf_txt_full):
     def test_dont_override_exp_name_with_base_if_provided(
         self, base_conf_txt_full
     ):
-        load = base_conf_txt_full + """
+        load = (
+            base_conf_txt_full
+            + """
 [logging]
 type = "csv"
 name = "logging_exp_name"
 """
+        )
         conf = Conf(**toml.loads(load))
         assert "name" in conf.logging.arguments
         assert conf.logging.arguments["name"] == "logging_exp_name"
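Most of the new test lines are the same parenthesised-concatenation reformatting, while the duplicate-key tests remain split by interpreter version because the error text differs between TOML parsers. The sketch below is standalone (not one of the kit4dl tests) and simply shows that duplicate keys are rejected at parse time on either parser; the exact message depends on whether the stdlib tomllib (Python 3.11+) or the third-party toml package does the parsing.

```python
# Standalone sketch: duplicate table keys are rejected at parse time,
# which is what the version-gated tests above rely on.
import sys

DOC = """
[metrics]
Precision = {target = "torchmetrics::Precision"}
Precision = {target = "torchmetrics::Precision"}
"""

if sys.version_info >= (3, 11):
    import tomllib

    try:
        tomllib.loads(DOC)
    except tomllib.TOMLDecodeError as exc:
        print(f"tomllib rejected the document: {exc}")
else:
    import toml

    try:
        toml.loads(DOC)
    except Exception as exc:  # toml.TomlDecodeError
        print(f"toml rejected the document: {exc}")
```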
