Updates for pydantic v2 (#619)
dougiesquire authored Jul 7, 2023
1 parent 2b873ef commit 92e771a
Showing 10 changed files with 83 additions and 90 deletions.
4 changes: 2 additions & 2 deletions ci/environment-docs.yml
@@ -14,7 +14,7 @@ dependencies:
- myst-nb
- netcdf4!=1.6.1
- pip
- pydantic>=1.9
- pydantic>=2.0
- python-graphviz
- python=3.11
- s3fs >=2023.05
@@ -27,6 +27,6 @@ dependencies:
- zarr>=2.12
- furo>=2022.09.15
- pip:
- git+https://github.com/ncar-xdev/ecgtools
- sphinxext-opengraph
- autodoc_pydantic
- -e ..
2 changes: 1 addition & 1 deletion ci/environment-upstream-dev.yml
@@ -18,7 +18,7 @@ dependencies:
- pooch
- pre-commit
- psutil
- pydantic>=1.9
- pydantic>=2.0
- pydap
- pyproj
- pytest
2 changes: 1 addition & 1 deletion ci/environment.yml
@@ -16,7 +16,7 @@ dependencies:
- pip
- pooch
- pre-commit
- pydantic>=1.9
- pydantic>=2.0
- pydap
- pytest
- pytest-cov
4 changes: 0 additions & 4 deletions docs/source/conf.py
@@ -16,7 +16,6 @@
'myst_nb',
'sphinxext.opengraph',
'sphinx_copybutton',
'sphinxcontrib.autodoc_pydantic',
'sphinx_design',
]

@@ -29,9 +28,6 @@
copybutton_prompt_text = r'>>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: '
copybutton_prompt_is_regexp = True

autodoc_pydantic_model_show_json = True
autodoc_pydantic_model_show_config = False

nb_execution_mode = 'cache'
nb_execution_timeout = 600
nb_execution_raise_on_error = True
10 changes: 8 additions & 2 deletions docs/source/reference/api.md
@@ -24,13 +24,19 @@ For more details and examples, refer to the relevant chapters in the main part o
## ESM Catalog

```{eval-rst}
.. autopydantic_model:: intake_esm.cat.ESMCatalogModel
.. autoclass:: intake_esm.cat.ESMCatalogModel
:members:
:noindex:
:special-members: __init__
```

## Query Model

```{eval-rst}
.. autopydantic_model:: intake_esm.cat.QueryModel
.. autoclass:: intake_esm.cat.QueryModel
:members:
:noindex:
:special-members: __init__
```

## Derived Variable Registry
95 changes: 41 additions & 54 deletions intake_esm/cat.py
@@ -9,6 +9,7 @@
import pandas as pd
import pydantic
import tlz
from pydantic import ConfigDict

from ._search import search, search_apply_require_all_on

@@ -40,9 +41,7 @@ class AggregationType(str, enum.Enum):
join_existing = 'join_existing'
union = 'union'

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class DataFormat(str, enum.Enum):
@@ -51,57 +50,47 @@ class DataFormat(str, enum.Enum):
reference = 'reference'
opendap = 'opendap'

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class Attribute(pydantic.BaseModel):
column_name: pydantic.StrictStr
vocabulary: pydantic.StrictStr = ''

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class Assets(pydantic.BaseModel):
column_name: pydantic.StrictStr
format: typing.Optional[DataFormat]
format_column_name: typing.Optional[pydantic.StrictStr]
format: typing.Optional[DataFormat] = None
format_column_name: typing.Optional[pydantic.StrictStr] = None

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)

@pydantic.root_validator
def _validate_data_format(cls, values):
data_format, format_column_name = values.get('format'), values.get('format_column_name')
@pydantic.model_validator(mode='after')
def _validate_data_format(cls, model):
data_format, format_column_name = model.format, model.format_column_name
if data_format is not None and format_column_name is not None:
raise ValueError('Cannot set both format and format_column_name')
elif data_format is None and format_column_name is None:
raise ValueError('Must set one of format or format_column_name')
return values
return model


class Aggregation(pydantic.BaseModel):
type: AggregationType
attribute_name: pydantic.StrictStr
options: typing.Optional[dict] = {}
options: dict = {}

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class AggregationControl(pydantic.BaseModel):
variable_column_name: pydantic.StrictStr
groupby_attrs: list[pydantic.StrictStr]
aggregations: list[Aggregation] = []

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class ESMCatalogModel(pydantic.BaseModel):
@@ -113,35 +102,33 @@ class ESMCatalogModel(pydantic.BaseModel):
attributes: list[Attribute]
assets: Assets
aggregation_control: typing.Optional[AggregationControl] = None
id: typing.Optional[str] = ''
id: str = ''
catalog_dict: typing.Optional[list[dict]] = None
catalog_file: pydantic.StrictStr = None
description: pydantic.StrictStr = None
title: pydantic.StrictStr = None
catalog_file: typing.Optional[pydantic.StrictStr] = None
description: typing.Optional[pydantic.StrictStr] = None
title: typing.Optional[pydantic.StrictStr] = None
last_updated: typing.Optional[typing.Union[datetime.datetime, datetime.date]] = None
_df: typing.Optional[pd.DataFrame] = pydantic.PrivateAttr()
_df: pd.DataFrame = pydantic.PrivateAttr()

class Config:
arbitrary_types_allowed = True
underscore_attrs_are_private = True
validate_all = True
validate_assignment = True
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_default=True, validate_assignment=True
)

@pydantic.root_validator
def validate_catalog(cls, values):
catalog_dict, catalog_file = values.get('catalog_dict'), values.get('catalog_file')
@pydantic.model_validator(mode='after')
def validate_catalog(cls, model):
catalog_dict, catalog_file = model.catalog_dict, model.catalog_file
if catalog_dict is not None and catalog_file is not None:
raise ValueError('catalog_dict and catalog_file cannot be set at the same time')

return values
return model

@classmethod
def from_dict(cls, data: dict) -> 'ESMCatalogModel':
esmcat = data['esmcat']
df = data['df']
if 'last_updated' not in esmcat:
esmcat['last_updated'] = None
cat = cls.parse_obj(esmcat)
cat = cls.model_validate(esmcat)
cat._df = df
return cat

@@ -254,7 +241,7 @@ def load(
data = json.loads(fobj.read())
if 'last_updated' not in data:
data['last_updated'] = None
cat = cls.parse_obj(data)
cat = cls.model_validate(data)
if cat.catalog_file:
if _mapper.fs.exists(cat.catalog_file):
csv_path = cat.catalog_file
@@ -417,32 +404,32 @@ class QueryModel(pydantic.BaseModel):

query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]]
columns: list[str]
require_all_on: typing.Union[str, list[typing.Any]] = None
require_all_on: typing.Optional[typing.Union[str, list[typing.Any]]] = None

class Config:
validate_all = True
validate_assignment = True
# TODO: Seem to be unable to modify fields in model_validator with
# validate_assignment=True since it leads to recursion
model_config = ConfigDict(validate_default=True, validate_assignment=False)

@pydantic.root_validator(pre=False)
def validate_query(cls, values):
query = values.get('query', {})
columns = values.get('columns')
require_all_on = values.get('require_all_on', [])
@pydantic.model_validator(mode='after')
def validate_query(cls, model):
query = model.query
columns = model.columns
require_all_on = model.require_all_on

if query:
for key in query:
if key not in columns:
raise ValueError(f'Column {key} not in columns {columns}')
if isinstance(require_all_on, str):
values['require_all_on'] = [require_all_on]
model.require_all_on = [require_all_on]
if require_all_on is not None:
for key in values['require_all_on']:
for key in model.require_all_on:
if key not in columns:
raise ValueError(f'Column {key} not in columns {columns}')
_query = query.copy()
for key, value in _query.items():
if isinstance(value, (str, int, float, bool)) or value is None or value is pd.NA:
_query[key] = [value]

values['query'] = _query
return values
model.query = _query
return model
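
The changes in `intake_esm/cat.py` follow the standard pydantic v1 → v2 migration recipe: the inner `class Config` becomes `model_config = ConfigDict(...)` (with `validate_all` renamed to `validate_default`), `@root_validator` becomes `@model_validator(mode='after')` operating on the constructed model rather than a values dict, and `parse_obj` becomes `model_validate`. A minimal sketch of the pattern, using a hypothetical `Thing` model rather than the catalog classes above:

```python
import typing

import pydantic
from pydantic import ConfigDict


class Thing(pydantic.BaseModel):
    # v1's inner `class Config` becomes `model_config`; `validate_all`
    # is now spelled `validate_default`.
    model_config = ConfigDict(validate_default=True, validate_assignment=True)

    # v2 no longer infers Optional from a `None` default, so it is explicit.
    format: typing.Optional[str] = None
    format_column_name: typing.Optional[str] = None

    # v1's `@root_validator` becomes `@model_validator`; with mode='after'
    # the hook receives the constructed model instead of a values dict.
    @pydantic.model_validator(mode='after')
    def _exactly_one_source(self):
        if (self.format is None) == (self.format_column_name is None):
            raise ValueError('Set exactly one of format or format_column_name')
        return self


# v1's `parse_obj` is replaced by `model_validate`.
thing = Thing.model_validate({'format': 'zarr'})
```

Note that `QueryModel` keeps `validate_assignment=False` for the reason given in its inline TODO: its after-validator mutates fields, and with assignment validation enabled that mutation would re-trigger validation recursively.
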
38 changes: 21 additions & 17 deletions intake_esm/core.py
@@ -329,7 +329,11 @@ def _ipython_key_completions_(self):
return self.__dir__()

@pydantic.validate_arguments
def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: typing.Any):
def search(
self,
require_all_on: typing.Optional[typing.Union[str, list[str]]] = None,
**query: typing.Any,
):
"""Search for entries in the catalog.
Parameters
@@ -443,11 +447,11 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t
def serialize(
self,
name: pydantic.StrictStr,
directory: typing.Union[pydantic.DirectoryPath, pydantic.StrictStr] = None,
directory: typing.Optional[typing.Union[pydantic.DirectoryPath, pydantic.StrictStr]] = None,
catalog_type: str = 'dict',
to_csv_kwargs: dict[typing.Any, typing.Any] = None,
json_dump_kwargs: dict[typing.Any, typing.Any] = None,
storage_options: dict[str, typing.Any] = None,
to_csv_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None,
json_dump_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None,
storage_options: typing.Optional[dict[str, typing.Any]] = None,
) -> None:
"""Serialize catalog to corresponding json and csv files.
@@ -536,12 +540,12 @@ def unique(self) -> pd.Series:
@pydantic.validate_arguments
def to_dataset_dict(
self,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
preprocess: typing.Callable = None,
storage_options: dict[pydantic.StrictStr, typing.Any] = None,
progressbar: pydantic.StrictBool = None,
aggregate: pydantic.StrictBool = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None,
progressbar: typing.Optional[pydantic.StrictBool] = None,
aggregate: typing.Optional[pydantic.StrictBool] = None,
skip_on_error: pydantic.StrictBool = False,
**kwargs,
) -> dict[str, xr.Dataset]:
@@ -686,12 +690,12 @@ def to_dataset_dict(
@pydantic.validate_arguments
def to_datatree(
self,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
preprocess: typing.Callable = None,
storage_options: dict[pydantic.StrictStr, typing.Any] = None,
progressbar: pydantic.StrictBool = None,
aggregate: pydantic.StrictBool = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None,
progressbar: typing.Optional[pydantic.StrictBool] = None,
aggregate: typing.Optional[pydantic.StrictBool] = None,
skip_on_error: pydantic.StrictBool = False,
**kwargs,
):
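
The signature changes in `intake_esm/core.py` are driven by the same v2 rule: annotations are taken literally, so pydantic no longer widens `x: T = None` to `Optional[T]`, and the `None` defaults in these `@pydantic.validate_arguments`-decorated signatures need an explicit `typing.Optional`. A small model-level illustration of the rule (`ToDatasetOptions` is a toy model, not a class in the library):

```python
import typing

import pydantic


class ToDatasetOptions(pydantic.BaseModel):
    # pydantic v1 widened `dict[...] = None` to Optional[dict[...]] automatically;
    # v2 takes the annotation literally, so the Optional must be spelled out.
    storage_options: typing.Optional[dict[str, typing.Any]] = None
    progressbar: typing.Optional[pydantic.StrictBool] = None


ToDatasetOptions()                              # OK: both fields default to None
ToDatasetOptions(progressbar=True)              # OK: StrictBool accepts a real bool
try:
    ToDatasetOptions(storage_options='anon')    # not a dict -> ValidationError
except pydantic.ValidationError as err:
    print(err)
```
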
4 changes: 2 additions & 2 deletions intake_esm/derived.py
@@ -17,7 +17,7 @@ class DerivedVariable(pydantic.BaseModel):
query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]]
prefer_derived: bool

@pydantic.validator('query')
@pydantic.field_validator('query')
def validate_query(cls, values):
_query = values.copy()
for key, value in _query.items():
@@ -46,7 +46,7 @@ def __call__(self, *args, variable_key_name: str = None, **kwargs) -> xr.Dataset
class DerivedVariableRegistry:
"""Registry of derived variables"""

def __post_init_post_parse__(self):
def __post_init__(self):
self._registry = {}

@classmethod
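
In `intake_esm/derived.py` the per-field hook moves from v1's `@pydantic.validator` to v2's `@pydantic.field_validator`, and the pydantic dataclass hook `__post_init_post_parse__` is replaced by the plain `__post_init__`. A minimal sketch of the `field_validator` form, using a made-up `DerivedQuery` model rather than the real `DerivedVariable`:

```python
import typing

import pydantic


class DerivedQuery(pydantic.BaseModel):
    query: dict[str, typing.Any]

    # v1: @pydantic.validator('query')  ->  v2: @pydantic.field_validator('query')
    # The hook still receives the field value and returns the normalized value.
    @pydantic.field_validator('query')
    @classmethod
    def _listify(cls, value: dict) -> dict:
        return {k: v if isinstance(v, list) else [v] for k, v in value.items()}


dq = DerivedQuery(query={'variable': 'tas'})
assert dq.query == {'variable': ['tas']}
```
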
12 changes: 6 additions & 6 deletions intake_esm/source.py
@@ -136,12 +136,12 @@ def __init__(
*,
variable_column_name: typing.Optional[pydantic.StrictStr] = None,
aggregations: typing.Optional[list[Aggregation]] = None,
requested_variables: list[str] = None,
preprocess: typing.Callable = None,
storage_options: dict[str, typing.Any] = None,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
intake_kwargs: dict[str, typing.Any] = None,
requested_variables: typing.Optional[list[str]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[str, typing.Any]] = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
intake_kwargs: typing.Optional[dict[str, typing.Any]] = None,
):
"""An intake compatible Data Source for ESM data.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,4 +6,4 @@ netCDF4>=1.5.5
requests>=2.24.0
xarray>=2022.06
zarr>=2.12
pydantic>=1.9
pydantic>=2.0
