From 92e771abbb89116ae7f43198a141d2098aaab48d Mon Sep 17 00:00:00 2001 From: Dougie Squire <42455466+dougiesquire@users.noreply.github.com> Date: Fri, 7 Jul 2023 12:14:37 +1000 Subject: [PATCH] Updates for pydantic v2 (#619) --- ci/environment-docs.yml | 4 +- ci/environment-upstream-dev.yml | 2 +- ci/environment.yml | 2 +- docs/source/conf.py | 4 -- docs/source/reference/api.md | 10 +++- intake_esm/cat.py | 95 ++++++++++++++------------------- intake_esm/core.py | 38 +++++++------ intake_esm/derived.py | 4 +- intake_esm/source.py | 12 ++--- requirements.txt | 2 +- 10 files changed, 83 insertions(+), 90 deletions(-) diff --git a/ci/environment-docs.yml b/ci/environment-docs.yml index d92ad17d..14200b89 100644 --- a/ci/environment-docs.yml +++ b/ci/environment-docs.yml @@ -14,7 +14,7 @@ dependencies: - myst-nb - netcdf4!=1.6.1 - pip - - pydantic>=1.9 + - pydantic>=2.0 - python-graphviz - python=3.11 - s3fs >=2023.05 @@ -27,6 +27,6 @@ dependencies: - zarr>=2.12 - furo>=2022.09.15 - pip: + - git+https://github.com/ncar-xdev/ecgtools - sphinxext-opengraph - - autodoc_pydantic - -e .. diff --git a/ci/environment-upstream-dev.yml b/ci/environment-upstream-dev.yml index 27c2e03e..594e855c 100644 --- a/ci/environment-upstream-dev.yml +++ b/ci/environment-upstream-dev.yml @@ -18,7 +18,7 @@ dependencies: - pooch - pre-commit - psutil - - pydantic>=1.9 + - pydantic>=2.0 - pydap - pyproj - pytest diff --git a/ci/environment.yml b/ci/environment.yml index dadae80b..2e2fdd43 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -16,7 +16,7 @@ dependencies: - pip - pooch - pre-commit - - pydantic>=1.9 + - pydantic>=2.0 - pydap - pytest - pytest-cov diff --git a/docs/source/conf.py b/docs/source/conf.py index 83e48394..c2a2effb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,6 @@ 'myst_nb', 'sphinxext.opengraph', 'sphinx_copybutton', - 'sphinxcontrib.autodoc_pydantic', 'sphinx_design', ] @@ -29,9 +28,6 @@ copybutton_prompt_text = r'>>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: ' copybutton_prompt_is_regexp = True -autodoc_pydantic_model_show_json = True -autodoc_pydantic_model_show_config = False - nb_execution_mode = 'cache' nb_execution_timeout = 600 nb_execution_raise_on_error = True diff --git a/docs/source/reference/api.md b/docs/source/reference/api.md index 958da5fa..50c71ac6 100644 --- a/docs/source/reference/api.md +++ b/docs/source/reference/api.md @@ -24,13 +24,19 @@ For more details and examples, refer to the relevant chapters in the main part o ## ESM Catalog ```{eval-rst} -.. autopydantic_model:: intake_esm.cat.ESMCatalogModel +.. autoclass:: intake_esm.cat.ESMCatalogModel + :members: + :noindex: + :special-members: __init__ ``` ## Query Model ```{eval-rst} -.. autopydantic_model:: intake_esm.cat.QueryModel +.. autoclass:: intake_esm.cat.QueryModel + :members: + :noindex: + :special-members: __init__ ``` ## Derived Variable Registry diff --git a/intake_esm/cat.py b/intake_esm/cat.py index c5290003..94a5b58f 100644 --- a/intake_esm/cat.py +++ b/intake_esm/cat.py @@ -9,6 +9,7 @@ import pandas as pd import pydantic import tlz +from pydantic import ConfigDict from ._search import search, search_apply_require_all_on @@ -40,9 +41,7 @@ class AggregationType(str, enum.Enum): join_existing = 'join_existing' union = 'union' - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) class DataFormat(str, enum.Enum): @@ -51,47 +50,39 @@ class DataFormat(str, enum.Enum): reference = 'reference' opendap = 'opendap' - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) class Attribute(pydantic.BaseModel): column_name: pydantic.StrictStr vocabulary: pydantic.StrictStr = '' - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) class Assets(pydantic.BaseModel): column_name: pydantic.StrictStr - format: typing.Optional[DataFormat] - format_column_name: typing.Optional[pydantic.StrictStr] + format: typing.Optional[DataFormat] = None + format_column_name: typing.Optional[pydantic.StrictStr] = None - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) - @pydantic.root_validator - def _validate_data_format(cls, values): - data_format, format_column_name = values.get('format'), values.get('format_column_name') + @pydantic.model_validator(mode='after') + def _validate_data_format(cls, model): + data_format, format_column_name = model.format, model.format_column_name if data_format is not None and format_column_name is not None: raise ValueError('Cannot set both format and format_column_name') elif data_format is None and format_column_name is None: raise ValueError('Must set one of format or format_column_name') - return values + return model class Aggregation(pydantic.BaseModel): type: AggregationType attribute_name: pydantic.StrictStr - options: typing.Optional[dict] = {} + options: dict = {} - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) class AggregationControl(pydantic.BaseModel): @@ -99,9 +90,7 @@ class AggregationControl(pydantic.BaseModel): groupby_attrs: list[pydantic.StrictStr] aggregations: list[Aggregation] = [] - class Config: - validate_all = True - validate_assignment = True + model_config = ConfigDict(validate_default=True, validate_assignment=True) class ESMCatalogModel(pydantic.BaseModel): @@ -113,27 +102,25 @@ class ESMCatalogModel(pydantic.BaseModel): attributes: list[Attribute] assets: Assets aggregation_control: typing.Optional[AggregationControl] = None - id: typing.Optional[str] = '' + id: str = '' catalog_dict: typing.Optional[list[dict]] = None - catalog_file: pydantic.StrictStr = None - description: pydantic.StrictStr = None - title: pydantic.StrictStr = None + catalog_file: typing.Optional[pydantic.StrictStr] = None + description: typing.Optional[pydantic.StrictStr] = None + title: typing.Optional[pydantic.StrictStr] = None last_updated: typing.Optional[typing.Union[datetime.datetime, datetime.date]] = None - _df: typing.Optional[pd.DataFrame] = pydantic.PrivateAttr() + _df: pd.DataFrame = pydantic.PrivateAttr() - class Config: - arbitrary_types_allowed = True - underscore_attrs_are_private = True - validate_all = True - validate_assignment = True + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_default=True, validate_assignment=True + ) - @pydantic.root_validator - def validate_catalog(cls, values): - catalog_dict, catalog_file = values.get('catalog_dict'), values.get('catalog_file') + @pydantic.model_validator(mode='after') + def validate_catalog(cls, model): + catalog_dict, catalog_file = model.catalog_dict, model.catalog_file if catalog_dict is not None and catalog_file is not None: raise ValueError('catalog_dict and catalog_file cannot be set at the same time') - return values + return model @classmethod def from_dict(cls, data: dict) -> 'ESMCatalogModel': @@ -141,7 +128,7 @@ def from_dict(cls, data: dict) -> 'ESMCatalogModel': df = data['df'] if 'last_updated' not in esmcat: esmcat['last_updated'] = None - cat = cls.parse_obj(esmcat) + cat = cls.model_validate(esmcat) cat._df = df return cat @@ -254,7 +241,7 @@ def load( data = json.loads(fobj.read()) if 'last_updated' not in data: data['last_updated'] = None - cat = cls.parse_obj(data) + cat = cls.model_validate(data) if cat.catalog_file: if _mapper.fs.exists(cat.catalog_file): csv_path = cat.catalog_file @@ -417,26 +404,26 @@ class QueryModel(pydantic.BaseModel): query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]] columns: list[str] - require_all_on: typing.Union[str, list[typing.Any]] = None + require_all_on: typing.Optional[typing.Union[str, list[typing.Any]]] = None - class Config: - validate_all = True - validate_assignment = True + # TODO: Seem to be unable to modify fields in model_validator with + # validate_assignment=True since it leads to recursion + model_config = ConfigDict(validate_default=True, validate_assignment=False) - @pydantic.root_validator(pre=False) - def validate_query(cls, values): - query = values.get('query', {}) - columns = values.get('columns') - require_all_on = values.get('require_all_on', []) + @pydantic.model_validator(mode='after') + def validate_query(cls, model): + query = model.query + columns = model.columns + require_all_on = model.require_all_on if query: for key in query: if key not in columns: raise ValueError(f'Column {key} not in columns {columns}') if isinstance(require_all_on, str): - values['require_all_on'] = [require_all_on] + model.require_all_on = [require_all_on] if require_all_on is not None: - for key in values['require_all_on']: + for key in model.require_all_on: if key not in columns: raise ValueError(f'Column {key} not in columns {columns}') _query = query.copy() @@ -444,5 +431,5 @@ def validate_query(cls, values): if isinstance(value, (str, int, float, bool)) or value is None or value is pd.NA: _query[key] = [value] - values['query'] = _query - return values + model.query = _query + return model diff --git a/intake_esm/core.py b/intake_esm/core.py index ce0dd7ac..ca899e3b 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -329,7 +329,11 @@ def _ipython_key_completions_(self): return self.__dir__() @pydantic.validate_arguments - def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: typing.Any): + def search( + self, + require_all_on: typing.Optional[typing.Union[str, list[str]]] = None, + **query: typing.Any, + ): """Search for entries in the catalog. Parameters @@ -443,11 +447,11 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t def serialize( self, name: pydantic.StrictStr, - directory: typing.Union[pydantic.DirectoryPath, pydantic.StrictStr] = None, + directory: typing.Optional[typing.Union[pydantic.DirectoryPath, pydantic.StrictStr]] = None, catalog_type: str = 'dict', - to_csv_kwargs: dict[typing.Any, typing.Any] = None, - json_dump_kwargs: dict[typing.Any, typing.Any] = None, - storage_options: dict[str, typing.Any] = None, + to_csv_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None, + json_dump_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None, + storage_options: typing.Optional[dict[str, typing.Any]] = None, ) -> None: """Serialize catalog to corresponding json and csv files. @@ -536,12 +540,12 @@ def unique(self) -> pd.Series: @pydantic.validate_arguments def to_dataset_dict( self, - xarray_open_kwargs: dict[str, typing.Any] = None, - xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None, - preprocess: typing.Callable = None, - storage_options: dict[pydantic.StrictStr, typing.Any] = None, - progressbar: pydantic.StrictBool = None, - aggregate: pydantic.StrictBool = None, + xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None, + xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None, + preprocess: typing.Optional[typing.Callable] = None, + storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None, + progressbar: typing.Optional[pydantic.StrictBool] = None, + aggregate: typing.Optional[pydantic.StrictBool] = None, skip_on_error: pydantic.StrictBool = False, **kwargs, ) -> dict[str, xr.Dataset]: @@ -686,12 +690,12 @@ def to_dataset_dict( @pydantic.validate_arguments def to_datatree( self, - xarray_open_kwargs: dict[str, typing.Any] = None, - xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None, - preprocess: typing.Callable = None, - storage_options: dict[pydantic.StrictStr, typing.Any] = None, - progressbar: pydantic.StrictBool = None, - aggregate: pydantic.StrictBool = None, + xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None, + xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None, + preprocess: typing.Optional[typing.Callable] = None, + storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None, + progressbar: typing.Optional[pydantic.StrictBool] = None, + aggregate: typing.Optional[pydantic.StrictBool] = None, skip_on_error: pydantic.StrictBool = False, **kwargs, ): diff --git a/intake_esm/derived.py b/intake_esm/derived.py index ec375d5c..5d70d6e4 100644 --- a/intake_esm/derived.py +++ b/intake_esm/derived.py @@ -17,7 +17,7 @@ class DerivedVariable(pydantic.BaseModel): query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]] prefer_derived: bool - @pydantic.validator('query') + @pydantic.field_validator('query') def validate_query(cls, values): _query = values.copy() for key, value in _query.items(): @@ -46,7 +46,7 @@ def __call__(self, *args, variable_key_name: str = None, **kwargs) -> xr.Dataset class DerivedVariableRegistry: """Registry of derived variables""" - def __post_init_post_parse__(self): + def __post_init__(self): self._registry = {} @classmethod diff --git a/intake_esm/source.py b/intake_esm/source.py index b7bef525..af6d94ee 100644 --- a/intake_esm/source.py +++ b/intake_esm/source.py @@ -136,12 +136,12 @@ def __init__( *, variable_column_name: typing.Optional[pydantic.StrictStr] = None, aggregations: typing.Optional[list[Aggregation]] = None, - requested_variables: list[str] = None, - preprocess: typing.Callable = None, - storage_options: dict[str, typing.Any] = None, - xarray_open_kwargs: dict[str, typing.Any] = None, - xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None, - intake_kwargs: dict[str, typing.Any] = None, + requested_variables: typing.Optional[list[str]] = None, + preprocess: typing.Optional[typing.Callable] = None, + storage_options: typing.Optional[dict[str, typing.Any]] = None, + xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None, + xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None, + intake_kwargs: typing.Optional[dict[str, typing.Any]] = None, ): """An intake compatible Data Source for ESM data. diff --git a/requirements.txt b/requirements.txt index 895305b9..09f63315 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ netCDF4>=1.5.5 requests>=2.24.0 xarray>=2022.06 zarr>=2.12 -pydantic>=1.9 +pydantic>=2.0