Skip to content

Commit

Permalink
[MINOR] Support Pydantic model title (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
faph authored May 23, 2024
2 parents 9b74e38 + 3114ed5 commit 8653c75
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
34 changes: 32 additions & 2 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import decimal
import enum
import inspect
import re
import sys
import types
import uuid
Expand Down Expand Up @@ -108,8 +109,14 @@ class Option(enum.Flag):
#: Do not populate ``doc`` schema attributes based on Python docstrings
NO_DOC = enum.auto()

#: Use the alias specified in a classes ``Field`` instead of the field's name.
#: This currently only affects Pydantic Models
#: Use an alias specified as part of a class instead of the class name itself.
#: This currently affects Pydantic models only.
#: See https://docs.pydantic.dev/dev/api/config/#pydantic.config.ConfigDict.title
USE_CLASS_ALIAS = enum.auto()

#: Use the alias specified in a class field instead of the field/attribute name itself.
#: This currently affects Pydantic models only.
#: See https://docs.pydantic.dev/dev/api/fields/#pydantic.fields.Field
USE_FIELD_ALIAS = enum.auto()


Expand Down Expand Up @@ -162,6 +169,17 @@ def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option
raise TypeNotSupportedError(f"Cannot generate Avro schema for Python type {py_type}")


# See https://avro.apache.org/docs/1.11.1/specification/#names
_AVRO_NAME_PATTERN = re.compile(r"^[A-Za-z]([A-Za-z0-9_])*$")


def validate_name(value: str) -> str:
"""Validate (and return) whether a given string is a valid Avro name"""
if not re.match(_AVRO_NAME_PATTERN, value):
raise ValueError(f"'{value}' is not a valid Avro name")
return value


class Schema(abc.ABC):
"""Schema base"""

Expand Down Expand Up @@ -690,6 +708,16 @@ def __str__(self):
"""Human rendering of the schema"""
return self.fullname

@property
def name(self):
"""Return the schema name"""
return self._name

@name.setter
def name(self, value: str):
"""Validate and set the schema name"""
self._name = validate_name(value)

@property
def fullname(self):
"""The schema's full name including the namespace if set"""
Expand Down Expand Up @@ -897,6 +925,8 @@ def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] =
:param options: Schema generation options.
"""
super().__init__(py_type, namespace=namespace, options=options)
if Option.USE_CLASS_ALIAS in self.options:
self.name = py_type.model_config.get("title") or self.name
self.py_fields = py_type.model_fields
self.record_fields = [self._record_field(name, field) for name, field in self.py_fields.items()]

Expand Down
34 changes: 33 additions & 1 deletion tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def test_field_alias_generator():
class PyType(pydantic.BaseModel):
field_a: str

model_config = {"alias_generator": lambda x: x.upper()}
model_config = pydantic.ConfigDict(alias_generator=lambda x: x.upper())

expected = {
"type": "record",
Expand All @@ -441,6 +441,38 @@ class PyType(pydantic.BaseModel):
assert_schema(PyType, expected, options=pas.Option.USE_FIELD_ALIAS)


def test_class_title():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict(title="PyTitle")

expected = {
"type": "record",
"name": "PyTitle",
"fields": [],
}
assert_schema(PyType, expected, options=pas.Option.USE_CLASS_ALIAS)


def test_class_title_not_set():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict()

expected = {
"type": "record",
"name": "PyType",
"fields": [],
}
assert_schema(PyType, expected, options=pas.Option.USE_CLASS_ALIAS)


def test_class_title_with_space():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict(title="Py Title")

with pytest.raises(ValueError, match="'Py Title' is not a valid Avro name"):
assert_schema(PyType, {}, options=pas.Option.USE_CLASS_ALIAS)


def test_annotated_decimal():
class PyType(pydantic.BaseModel):
field_a: Annotated[
Expand Down

0 comments on commit 8653c75

Please sign in to comment.