Skip to content

Commit

Permalink
Refactor formatting of multiple validation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aeisenbarth committed Oct 28, 2024
1 parent 0ef8aed commit 7f9223d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 43 deletions.
15 changes: 12 additions & 3 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@

from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
from spatialdata._core.validation import (
ErrorDetails,
ValidationError,
check_all_keys_case_insensitively_unique,
check_target_region_column_symmetry,
check_valid_name,
collect_error_details,
validate_table_attr_keys,
)
from spatialdata._logging import logger
Expand Down Expand Up @@ -1121,10 +1124,16 @@ def _validate_can_safely_write_to_path(
)

def _validate_all_elements(self) -> None:
    """Check that all elements have valid names and valid table attribute keys.

    Instead of failing at the first problem, errors for all elements are
    collected and reported together in a single exception.

    Raises
    ------
    ValidationError
        If one or more elements have an invalid name, or a table contains
        invalid attribute keys.
    """
    details: list[ErrorDetails] = []
    for element_type, element_name, element in self.gen_elements():
        element_path = (element_type, element_name)
        # Collect instead of raising so every invalid element is reported.
        with collect_error_details(collection=details, location=element_path):
            check_valid_name(element_name)
        if element_type == "tables":
            with collect_error_details(collection=details, location=element_path):
                validate_table_attr_keys(element, location=element_path)
    if details:
        raise ValidationError(title="SpatialData contains elements with invalid names", errors=details)

def write(
self,
Expand Down Expand Up @@ -2001,7 +2010,7 @@ def _validate_element_names_are_unique(self) -> None:
ValueError
If the element names are not unique.
"""
check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()])
check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()], location=())

def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]:
"""
Expand Down
139 changes: 105 additions & 34 deletions src/spatialdata/_core/validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,43 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Collection
from collections.abc import Collection
from types import TracebackType
from typing import NamedTuple, cast

import pandas as pd
from anndata import AnnData


class ErrorDetails(NamedTuple):
    """Details of a single validation failure.

    Attributes
    ----------
    location
        Tuple of strings identifying the element for which the error occurred.
    message
        A human readable error message.
    """

    location: tuple[str, ...]
    message: str


class ValidationError(ValueError):
    """Validation failure carrying one or more detailed errors.

    Parameters
    ----------
    title
        A short summary describing the overall validation failure.
    errors
        Details of the individual errors that occurred.
    """

    def __init__(self, title: str, errors: list[ErrorDetails]):
        # Copy so later mutation of the caller's list does not affect this exception.
        self._errors: list[ErrorDetails] = list(errors)
        super().__init__(title)

    @property
    def title(self) -> str:
        """The summary message passed to the constructor."""
        return str(self.args[0]) if self.args else ""

    @property
    def errors(self) -> list[ErrorDetails]:
        """A defensive copy of the individual error details."""
        return list(self._errors)

    def __str__(self) -> str:
        # Render the title followed by one indented "location/path: message" line per error.
        return f"{self.title}:\n" + "\n".join(
            f" {'/'.join(str(key) for key in details.location)}: {details.message}" for details in self.errors
        )


def check_target_region_column_symmetry(table: AnnData, region_key: str, target: str | pd.Series) -> None:
"""
Check region and region_key column symmetry.
Expand Down Expand Up @@ -84,7 +115,7 @@ def check_valid_name(name: str) -> None:
raise ValueError("Name must contain only alphanumeric characters, underscores, dots and hyphens.")


def check_all_keys_case_insensitively_unique(keys: Collection[str], location: tuple[str, ...] = ()) -> None:
    """
    Check that all keys are unique when ignoring case.

    Parameters
    ----------
    keys
        A collection of string keys
    location
        Tuple of strings identifying the parent element

    Raises
    ------
    ValidationError
        If two or more keys are equal when ignoring case.
    """
    seen: set[str | None] = set()
    # Collect into a local list: ValidationError.errors returns a copy, so
    # appending to the property's return value would silently drop the details.
    details: list[ErrorDetails] = []
    for key in keys:
        normalized_key = key.lower()
        # Collect instead of raising immediately so all conflicts are reported at once.
        with collect_error_details(collection=details, location=location + (key,)):
            check_key_is_case_insensitively_unique(key, seen)
        seen.add(normalized_key)
    if details:
        raise ValidationError(title="Element contains conflicting keys", errors=details)


def check_key_is_case_insensitively_unique(key: str, other_keys: set[str | None]) -> None:
Expand Down Expand Up @@ -178,28 +215,7 @@ def check_valid_dataframe_column_name(name: str) -> None:
raise ValueError("Name cannot be '_index'")


def _iter_anndata_attr_keys_collect_value_errors(
adata: AnnData, attr_visitor: Callable[[str], None], key_visitor: Callable[[str, str], None]
) -> None:
messages_per_attr: dict[str, list[str]] = defaultdict(list)
for attr in ("obs", "obsm", "obsp", "var", "varm", "varp", "uns", "layers"):
try:
attr_visitor(attr)
except ValueError as e:
messages_per_attr[attr].append(f" {e.args[0]}")
for key in getattr(adata, attr):
try:
key_visitor(attr, key)
except ValueError as e:
messages_per_attr[attr].append(f" '{key}': {e.args[0]}")
if messages_per_attr:
raise ValueError(
"Table contains invalid names:\n"
+ "\n".join(f"{attr}:\n" + "\n".join(messages) for attr, messages in messages_per_attr.items())
)


def validate_table_attr_keys(data: AnnData, location: tuple[str, ...] = ()) -> None:
    """
    Check that all keys of all AnnData attributes have valid names.

    Parameters
    ----------
    data
        The AnnData table
    location
        Tuple of strings identifying the parent element

    Raises
    ------
    ValidationError
        If the AnnData contains one or several invalid keys.
    """
    # Collect into a local list: ValidationError.errors returns a copy, so
    # appending to the property's return value would silently drop the details.
    details: list[ErrorDetails] = []
    for attr in ("obs", "obsm", "obsp", "var", "varm", "varp", "uns", "layers"):
        attr_path = location + (attr,)
        with collect_error_details(collection=details, location=attr_path):
            check_all_keys_case_insensitively_unique(getattr(data, attr).keys(), location=attr_path)
        for key in getattr(data, attr):
            key_path = attr_path + (key,)
            with collect_error_details(collection=details, location=key_path):
                # obs/var are dataframes, whose column names have stricter rules.
                if attr in ("obs", "var"):
                    check_valid_dataframe_column_name(key)
                else:
                    check_valid_name(key)
    if details:
        raise ValidationError(title="Table contains invalid names", errors=details)


class collect_error_details:
    """
    Context manager to possibly collect an exception into a list.

    This is syntactic sugar for shortening the try/except construction when the error handling is
    the same.

    Parameters
    ----------
    collection
        The list to which to add the exception
    location
        Tuple of strings identifying the parent element
    expected_exception
        The class of the exception to catch. Other exceptions are raised.
    """

    def __init__(
        self,
        collection: list[ErrorDetails],
        location: tuple[str, ...] = (),
        expected_exception: type[BaseException] = ValueError,
    ) -> None:
        self._collection: list[ErrorDetails] = collection
        self._location: tuple[str, ...] = location
        self._expected_exception: type[BaseException] = expected_exception

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        # No exception raised: nothing to collect.
        if exc_type is None:
            return True
        # Unexpected exception types are not suppressed and propagate.
        if not issubclass(exc_type, self._expected_exception):
            return False
        assert exc_val is not None
        if issubclass(exc_type, ValidationError):
            # Merge details of a nested ValidationError instead of wrapping its message.
            exc_val = cast(ValidationError, exc_val)
            self._collection += exc_val.errors
        else:
            details = ErrorDetails(location=self._location, message=str(exc_val.args[0]))
            self._collection.append(details)
        # Returning True suppresses the collected exception.
        return True
17 changes: 11 additions & 6 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
if attr in ("obs", "var"):
df = pd.DataFrame([[None]], columns=[key], index=["1"])
adata = AnnData(np.array([[0]]), **{attr: df})
with pytest.raises(ValueError, match=f"Table contains invalid names:\n{attr}:\n '{re.escape(key)}'"):
with pytest.raises(ValueError, match=f"Table contains invalid names:\n{attr}:\n {attr}/{re.escape(key)}"):
if parse:
TableModel.parse(adata)
else:
Expand All @@ -439,14 +439,18 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
if attr in ("obsm", "varm", "obsp", "varp", "layers"):
array = np.array([[0]])
adata = AnnData(np.array([[0]]), **{attr: {key: array}})
with pytest.raises(ValueError, match=f"Table contains invalid names:\n{attr}:\n '{re.escape(key)}'"):
with pytest.raises(
ValueError, match=f"Table contains invalid names:\n{attr}:\n {attr}/{re.escape(key)}"
):
if parse:
TableModel.parse(adata)
else:
TableModel().validate(adata)
elif attr == "uns":
adata = AnnData(np.array([[0]]), **{attr: {key: {}}})
with pytest.raises(ValueError, match=f"Table contains invalid names:\n{attr}:\n '{re.escape(key)}'"):
with pytest.raises(
ValueError, match=f"Table contains invalid names:\n{attr}:\n {attr}/{re.escape(key)}"
):
if parse:
TableModel.parse(adata)
else:
Expand All @@ -462,13 +466,14 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
@pytest.mark.parametrize("attr", ["obs", "var"])
@pytest.mark.parametrize("parse", [True, False])
def test_table_model_not_unique_columns(self, keys: list[str], attr: str, parse: bool):
key_regex = re.escape(keys[1])
invalid_key = keys[1]
key_regex = re.escape(invalid_key)
df = pd.DataFrame([[None] * len(keys)], columns=keys, index=["1"])
adata = AnnData(np.array([[0]]), **{attr: df})
with pytest.raises(
ValueError,
match=f"Table contains invalid names:\n{attr}:\n"
+ f" Key `{key_regex}` is not unique, or another case-variant of it exists.",
match=f"Table contains invalid names:\n {attr}/{invalid_key}: "
+ f"Key `{key_regex}` is not unique, or another case-variant of it exists.",
):
if parse:
TableModel.parse(adata)
Expand Down

0 comments on commit 7f9223d

Please sign in to comment.