From 90ee9173010421a18d043ae731fb5901b16af254 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:10:43 -0500 Subject: [PATCH] Make config wrap validators --- src/config.rs | 327 +++++++++++++++++++++++++++++ src/lib.rs | 1 + src/validators/config.rs | 45 ++++ src/validators/dataclass.rs | 23 +- src/validators/float.rs | 58 ++--- src/validators/mod.rs | 4 + src/validators/model.rs | 60 ++---- src/validators/validation_state.rs | 15 +- tests/validators/test_model.py | 3 +- 9 files changed, 449 insertions(+), 87 deletions(-) create mode 100644 src/config.rs create mode 100644 src/validators/config.rs diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 000000000..767428f4d --- /dev/null +++ b/src/config.rs @@ -0,0 +1,327 @@ +use pyo3::prelude::*; +use pyo3::types::PyType; +use pyo3::{exceptions::PyValueError, types::PyDict, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; + +#[derive(Debug, Clone, Default)] +pub struct CoreConfig { + pub title: Option, + pub strict: Option, + pub extra_fields_behavior: Option, + pub typed_dict_total: Option, + pub from_attributes: Option, + pub loc_by_alias: Option, + pub revalidate_instances: Option, + pub validate_default: Option, + pub populate_by_name: Option, + pub str_max_length: Option, + pub str_min_length: Option, + pub str_strip_whitespace: Option, + pub str_to_lower: Option, + pub str_to_upper: Option, + pub allow_inf_nan: Option, + pub ser_json_timedelta: Option, + pub ser_json_bytes: Option, + pub ser_json_inf_nan: Option, + pub hide_input_in_errors: Option, + pub validation_error_cause: Option, + pub coerce_numbers_to_str: Option, + pub regex_engine: Option, + pub cache_strings: Option, +} + +impl TryFrom> for CoreConfig { + type Error = PyErr; + fn try_from(value: Bound<'_, PyDict>) -> Result { + Ok(CoreConfig { + title: value.get_item("title")?.map(|v| v.extract()).transpose()?, + strict: value.get_item("strict")?.map(|v| v.extract()).transpose()?, + extra_fields_behavior: value + .get_item("extra_fields_behavior")? + .map(|v| v.extract()) + .transpose()?, + typed_dict_total: value.get_item("typed_dict_total")?.map(|v| v.extract()).transpose()?, + from_attributes: value.get_item("from_attributes")?.map(|v| v.extract()).transpose()?, + loc_by_alias: value.get_item("loc_by_alias")?.map(|v| v.extract()).transpose()?, + revalidate_instances: value + .get_item("revalidate_instances")? + .map(|v| v.extract()) + .transpose()?, + validate_default: value.get_item("validate_default")?.map(|v| v.extract()).transpose()?, + populate_by_name: value.get_item("populate_by_name")?.map(|v| v.extract()).transpose()?, + str_max_length: value.get_item("str_max_length")?.map(|v| v.extract()).transpose()?, + str_min_length: value.get_item("str_min_length")?.map(|v| v.extract()).transpose()?, + str_strip_whitespace: value + .get_item("str_strip_whitespace")? + .map(|v| v.extract()) + .transpose()?, + str_to_lower: value.get_item("str_to_lower")?.map(|v| v.extract()).transpose()?, + str_to_upper: value.get_item("str_to_upper")?.map(|v| v.extract()).transpose()?, + allow_inf_nan: value.get_item("allow_inf_nan")?.map(|v| v.extract()).transpose()?, + ser_json_timedelta: value.get_item("ser_json_timedelta")?.map(|v| v.extract()).transpose()?, + ser_json_bytes: value.get_item("ser_json_bytes")?.map(|v| v.extract()).transpose()?, + ser_json_inf_nan: value.get_item("ser_json_inf_nan")?.map(|v| v.extract()).transpose()?, + hide_input_in_errors: value + .get_item("hide_input_in_errors")? + .map(|v| v.extract()) + .transpose()?, + validation_error_cause: value + .get_item("validation_error_cause")? + .map(|v| v.extract()) + .transpose()?, + coerce_numbers_to_str: value + .get_item("coerce_numbers_to_str")? + .map(|v| v.extract()) + .transpose()?, + regex_engine: value.get_item("regex_engine")?.map(|v| v.extract()).transpose()?, + cache_strings: value.get_item("cache_strings")?.map(|v| v.extract()).transpose()?, + }) + } +} + +impl IntoPy for CoreConfig { + fn into_py(self, py: Python<'_>) -> PyObject { + let dict = PyDict::new_bound(py); + if let Some(title) = self.title { + dict.set_item("title", title).unwrap(); + }; + if let Some(strict) = self.strict { + dict.set_item("strict", strict).unwrap(); + }; + if let Some(extra_fields_behavior) = self.extra_fields_behavior { + let value = match extra_fields_behavior { + ExtraBehavior::Allow => "allow", + ExtraBehavior::Ignore => "ignore", + ExtraBehavior::Error => "error", + }; + dict.set_item("extra_fields_behavior", value).unwrap(); + }; + if let Some(typed_dict_total) = self.typed_dict_total { + dict.set_item("typed_dict_total", typed_dict_total).unwrap(); + }; + if let Some(from_attributes) = self.from_attributes { + dict.set_item("from_attributes", from_attributes).unwrap(); + }; + if let Some(loc_by_alias) = self.loc_by_alias { + dict.set_item("loc_by_alias", loc_by_alias).unwrap(); + }; + if let Some(revalidate_instances) = self.revalidate_instances { + let value = match revalidate_instances { + RevalidateInstances::Always => "always", + RevalidateInstances::Never => "never", + RevalidateInstances::SubclassInstances => "subclass_instances", + }; + dict.set_item("revalidate_instances", value).unwrap(); + }; + if let Some(validate_default) = self.validate_default { + dict.set_item("validate_default", validate_default).unwrap(); + }; + if let Some(populate_by_name) = self.populate_by_name { + dict.set_item("populate_by_name", populate_by_name).unwrap(); + }; + if let Some(str_max_length) = self.str_max_length { + dict.set_item("str_max_length", str_max_length).unwrap(); + }; + if let Some(str_min_length) = self.str_min_length { + dict.set_item("str_min_length", str_min_length).unwrap(); + }; + if let Some(str_strip_whitespace) = self.str_strip_whitespace { + dict.set_item("str_strip_whitespace", str_strip_whitespace).unwrap(); + }; + if let Some(str_to_lower) = self.str_to_lower { + dict.set_item("str_to_lower", str_to_lower).unwrap(); + }; + if let Some(str_to_upper) = self.str_to_upper { + dict.set_item("str_to_upper", str_to_upper).unwrap(); + }; + if let Some(allow_inf_nan) = self.allow_inf_nan { + dict.set_item("allow_inf_nan", allow_inf_nan).unwrap(); + }; + if let Some(ser_json_timedelta) = self.ser_json_timedelta { + let value = match ser_json_timedelta { + SerJsonTimedelta::Iso8601 => "iso8601", + SerJsonTimedelta::Float => "float", + }; + dict.set_item("ser_json_timedelta", value).unwrap(); + }; + if let Some(ser_json_bytes) = self.ser_json_bytes { + let value = match ser_json_bytes { + SerJsonBytes::Utf8 => "utf8", + SerJsonBytes::Base64 => "base64", + SerJsonBytes::Hex => "hex", + }; + dict.set_item("ser_json_bytes", value).unwrap(); + }; + if let Some(ser_json_inf_nan) = self.ser_json_inf_nan { + let value = match ser_json_inf_nan { + SerJsonInfNan::Null => "null", + SerJsonInfNan::Constants => "constants", + SerJsonInfNan::Strings => "strings", + }; + dict.set_item("ser_json_inf_nan", value).unwrap(); + }; + if let Some(hide_input_in_errors) = self.hide_input_in_errors { + dict.set_item("hide_input_in_errors", hide_input_in_errors).unwrap(); + }; + if let Some(validation_error_cause) = self.validation_error_cause { + dict.set_item("validation_error_cause", validation_error_cause).unwrap(); + }; + if let Some(coerce_numbers_to_str) = self.coerce_numbers_to_str { + dict.set_item("coerce_numbers_to_str", coerce_numbers_to_str).unwrap(); + }; + if let Some(regex_engine) = self.regex_engine { + let value = match regex_engine { + RegexEngine::RustRegex => "rust_regex", + RegexEngine::PythonRe => "python_re", + }; + dict.set_item("regex_engine", value).unwrap(); + }; + if let Some(cache_strings) = self.cache_strings { + let value = match cache_strings { + CacheStrings::All => "all", + CacheStrings::Keys => "keys", + CacheStrings::None => "none", + }; + dict.set_item("cache_strings", value).unwrap(); + }; + dict.into() + } +} + +#[derive(Debug, Clone)] +pub enum ExtraBehavior { + Allow, + Ignore, + Error, +} + +impl FromPyObject<'_> for ExtraBehavior { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "allow" => Ok(ExtraBehavior::Allow), + "ignore" => Ok(ExtraBehavior::Ignore), + "error" => Ok(ExtraBehavior::Error), + _ => Err(PyValueError::new_err("Invalid value for extra_fields_behavior")), + } + } +} + +#[derive(Debug, Clone)] +pub enum RevalidateInstances { + Always, + Never, + SubclassInstances, +} + +impl FromPyObject<'_> for RevalidateInstances { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "always" => Ok(RevalidateInstances::Always), + "never" => Ok(RevalidateInstances::Never), + "subclass_instances" => Ok(RevalidateInstances::SubclassInstances), + _ => Err(PyValueError::new_err("Invalid value for revalidate_instances")), + } + } +} + +impl RevalidateInstances { + pub fn should_revalidate(&self, input: &Bound<'_, PyAny>, class: &Bound<'_, PyType>) -> bool { + match self { + RevalidateInstances::Always => true, + RevalidateInstances::Never => false, + RevalidateInstances::SubclassInstances => !input.is_exact_instance(class), + } + } +} + +#[derive(Debug, Clone)] +pub enum SerJsonTimedelta { + Iso8601, + Float, +} + +impl FromPyObject<'_> for SerJsonTimedelta { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "iso8601" => Ok(SerJsonTimedelta::Iso8601), + "float" => Ok(SerJsonTimedelta::Float), + _ => Err(PyValueError::new_err("Invalid value for ser_json_timedelta")), + } + } +} + +#[derive(Debug, Clone)] +pub enum SerJsonBytes { + Utf8, + Base64, + Hex, +} + +impl FromPyObject<'_> for SerJsonBytes { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "utf8" => Ok(SerJsonBytes::Utf8), + "base64" => Ok(SerJsonBytes::Base64), + "hex" => Ok(SerJsonBytes::Hex), + _ => Err(PyValueError::new_err("Invalid value for ser_json_bytes")), + } + } +} + +#[derive(Debug, Clone)] +pub enum SerJsonInfNan { + Null, + Constants, + Strings, +} + +impl FromPyObject<'_> for SerJsonInfNan { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "null" => Ok(SerJsonInfNan::Null), + "constants" => Ok(SerJsonInfNan::Constants), + "strings" => Ok(SerJsonInfNan::Strings), + _ => Err(PyValueError::new_err("Invalid value for ser_json_inf_nan")), + } + } +} + +#[derive(Debug, Clone)] +pub enum RegexEngine { + RustRegex, + PythonRe, +} + +impl FromPyObject<'_> for RegexEngine { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "rust_regex" => Ok(RegexEngine::RustRegex), + "python_re" => Ok(RegexEngine::PythonRe), + _ => Err(PyValueError::new_err("Invalid value for regex_engine")), + } + } +} + +#[derive(Debug, Clone)] +pub enum CacheStrings { + All, + Keys, + None, +} + +impl FromPyObject<'_> for CacheStrings { + fn extract(ob: &PyAny) -> PyResult { + let value = ob.extract::()?; + match value.as_str() { + "all" => Ok(CacheStrings::All), + "keys" => Ok(CacheStrings::Keys), + "none" => Ok(CacheStrings::None), + _ => Err(PyValueError::new_err("Invalid value for cache_strings")), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index d55e83645..999d5d45f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ mod py_gc; mod argument_markers; mod build_tools; +mod config; mod definitions; mod errors; mod input; diff --git a/src/validators/config.rs b/src/validators/config.rs new file mode 100644 index 000000000..9c2ee1e96 --- /dev/null +++ b/src/validators/config.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use pyo3::prelude::*; +use pyo3::types::PyDict; + +use crate::config::CoreConfig; +use crate::errors::ValResult; +use crate::input::Input; + +use super::{CombinedValidator, ValidationState, Validator}; + +/// A validator that sets the current configuration. +#[derive(Debug, Clone)] +pub struct ConfigValidator { + config: CoreConfig, + inner: Arc, +} + +impl ConfigValidator { + pub fn try_new(config: Bound<'_, PyDict>, inner: Arc) -> PyResult { + Ok(Self { + config: config.try_into()?, + inner, + }) + } +} + +impl_py_gc_traverse!(ConfigValidator {}); + +impl Validator for ConfigValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let mut state = + ValidationState::new_with_config(state.extra().clone(), state.recursion_guard, self.config.clone()); + self.inner.validate(py, input, &mut state) + } + + fn get_name(&self) -> &str { + "config" + } +} diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 72d738440..eeece750a 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -6,7 +6,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType}; use ahash::AHashSet; use crate::build_tools::py_schema_err; -use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; +use crate::build_tools::{schema_or_config_same, ExtraBehavior}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ input_as_python_instance, Arguments, BorrowInput, Input, InputType, KeywordArgs, PositionalArgs, ValidationMatch, @@ -15,7 +15,7 @@ use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; use crate::validators::function::convert_err; -use super::model::{create_class, force_setattr, Revalidate}; +use super::model::{create_class, force_setattr}; use super::validation_state::Exactness; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -435,12 +435,10 @@ impl Validator for DataclassArgsValidator { #[derive(Debug)] pub struct DataclassValidator { - strict: bool, validator: Box, class: Py, fields: Vec>, post_init: Option>, - revalidate: Revalidate, name: String, frozen: bool, slots: bool, @@ -481,17 +479,10 @@ impl BuildValidator for DataclassValidator { .collect::>>()?; Ok(Self { - strict: is_strict(schema, config)?, validator: Box::new(validator), class: class.into(), fields, post_init, - revalidate: Revalidate::from_str( - schema_or_config_same::>(schema, config, intern!(py, "revalidate_instances"))? - .as_ref() - .map(|s| s.to_str()) - .transpose()?, - )?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false), @@ -517,7 +508,13 @@ impl Validator for DataclassValidator { // same logic as on models let class = self.class.bind(py); if let Some(py_input) = input_as_python_instance(input, class) { - if self.revalidate.should_revalidate(py_input, class) { + if state + .config + .revalidate_instances + .as_ref() + .unwrap_or(&crate::config::RevalidateInstances::Never) + .should_revalidate(py_input, class) + { let input_dict = self.dataclass_to_dict(py_input)?; let val_output = self.validator.validate(py, input_dict.as_any(), state)?; let dc = create_class(self.class.bind(py))?; @@ -526,7 +523,7 @@ impl Validator for DataclassValidator { } else { Ok(input.to_object(py)) } - } else if state.strict_or(self.strict) && state.extra().input_type == InputType::Python { + } else if state.strict_or(false) && state.extra().input_type == InputType::Python { Err(ValError::new( ErrorType::DataclassExactType { class_name: self.get_name().to_string(), diff --git a/src/validators/float.rs b/src/validators/float.rs index 3a79de9fe..a85061310 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -1,14 +1,15 @@ use std::cmp::Ordering; +use std::sync::Arc; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::build_tools::{is_strict, schema_or_config_same}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; +use super::config::ConfigValidator; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; pub struct FloatBuilder; @@ -26,38 +27,43 @@ impl BuildValidator for FloatBuilder { || schema.get_item(intern!(py, "lt"))?.is_some() || schema.get_item(intern!(py, "ge"))?.is_some() || schema.get_item(intern!(py, "gt"))?.is_some(); - if use_constrained { - ConstrainedFloatValidator::build(schema, config, definitions) + + let mut validator = if use_constrained { + ConstrainedFloatValidator::build(schema, config, definitions)? } else { - Ok(FloatValidator { - strict: is_strict(schema, config)?, - allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true), + FloatValidator::build(schema, config, definitions)? + }; + + let strict: Option = schema.get_as(intern!(py, "strict"))?; + let allow_inf_nan: Option = schema.get_as(intern!(py, "allow_inf_nan"))?; + + if strict.is_some() | allow_inf_nan.is_some() { + let config = PyDict::new_bound(py); + if let Some(strict) = strict { + config.set_item("strict", strict)?; + } + if let Some(allow_inf_nan) = allow_inf_nan { + config.set_item("allow_inf_nan", allow_inf_nan)?; } - .into()) + validator = CombinedValidator::Config(ConfigValidator::try_new(config, Arc::new(validator))?); } + + Ok(validator) } } #[derive(Debug, Clone)] -pub struct FloatValidator { - strict: bool, - allow_inf_nan: bool, -} +pub struct FloatValidator; impl BuildValidator for FloatValidator { const EXPECTED_TYPE: &'static str = "float"; fn build( - schema: &Bound<'_, PyDict>, - config: Option<&Bound<'_, PyDict>>, + _schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - let py = schema.py(); - Ok(Self { - strict: is_strict(schema, config)?, - allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true), - } - .into()) + Ok(Self.into()) } } @@ -70,8 +76,8 @@ impl Validator for FloatValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); - if !self.allow_inf_nan && !either_float.as_f64().is_finite() { + let either_float = input.validate_float(state.strict_or(false))?.unpack(state); + if !state.config.allow_inf_nan.unwrap_or(false) && !either_float.as_f64().is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } Ok(either_float.into_py(py)) @@ -84,8 +90,6 @@ impl Validator for FloatValidator { #[derive(Debug, Clone)] pub struct ConstrainedFloatValidator { - strict: bool, - allow_inf_nan: bool, multiple_of: Option, le: Option, lt: Option, @@ -102,9 +106,9 @@ impl Validator for ConstrainedFloatValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); + let either_float = input.validate_float(state.strict_or(false))?.unpack(state); let float: f64 = either_float.as_f64(); - if !self.allow_inf_nan && !float.is_finite() { + if !state.config.allow_inf_nan.unwrap_or(false) && !float.is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } if let Some(multiple_of) = self.multiple_of { @@ -176,13 +180,11 @@ impl BuildValidator for ConstrainedFloatValidator { const EXPECTED_TYPE: &'static str = "float"; fn build( schema: &Bound<'_, PyDict>, - config: Option<&Bound<'_, PyDict>>, + _config: Option<&Bound<'_, PyDict>>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); Ok(Self { - strict: is_strict(schema, config)?, - allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true), multiple_of: schema.get_as(intern!(py, "multiple_of"))?, le: schema.get_as(intern!(py, "le"))?, lt: schema.get_as(intern!(py, "lt"))?, diff --git a/src/validators/mod.rs b/src/validators/mod.rs index ede6489ab..b50b589da 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -24,6 +24,7 @@ mod bytes; mod call; mod callable; mod chain; +mod config; mod custom_error; mod dataclass; mod date; @@ -732,6 +733,9 @@ pub enum CombinedValidator { DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + // config applies the a CoreConfig to the current validation context + // for example if we hit a Model that has an attached config this validator will apply that config + Config(config::ConfigValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/model.rs b/src/validators/model.rs index 2c0cef6fd..91d60c80d 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -1,17 +1,17 @@ use std::ptr::null_mut; +use std::sync::Arc; use pyo3::exceptions::PyTypeError; use pyo3::ffi; use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType}; use pyo3::{intern, prelude::*}; +use super::config::ConfigValidator; use super::function::convert_err; use super::validation_state::Exactness; use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -use crate::build_tools::py_schema_err; -use crate::build_tools::schema_or_config_same; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::{input_as_python_instance, py_error_on_minusone, Input}; use crate::tools::{py_err, SchemaDict}; @@ -23,36 +23,8 @@ const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__"; const DUNDER_MODEL_EXTRA_KEY: &str = "__pydantic_extra__"; const DUNDER_MODEL_PRIVATE_KEY: &str = "__pydantic_private__"; -#[derive(Debug, Clone)] -pub(super) enum Revalidate { - Always, - Never, - SubclassInstances, -} - -impl Revalidate { - pub fn from_str(s: Option<&str>) -> PyResult { - match s { - None => Ok(Self::Never), - Some("always") => Ok(Self::Always), - Some("never") => Ok(Self::Never), - Some("subclass-instances") => Ok(Self::SubclassInstances), - Some(s) => py_schema_err!("Invalid revalidate_instances value: {}", s), - } - } - - pub fn should_revalidate(&self, input: &Bound<'_, PyAny>, class: &Bound<'_, PyType>) -> bool { - match self { - Revalidate::Always => true, - Revalidate::Never => false, - Revalidate::SubclassInstances => !input.is_exact_instance(class), - } - } -} - #[derive(Debug)] pub struct ModelValidator { - revalidate: Revalidate, validator: Box, class: Py, post_init: Option>, @@ -80,17 +52,7 @@ impl BuildValidator for ModelValidator { let validator = build_validator(&sub_schema, config.as_ref(), definitions)?; let name = class.getattr(intern!(py, "__name__"))?.extract()?; - Ok(Self { - revalidate: Revalidate::from_str( - schema_or_config_same::>( - schema, - config.as_ref(), - intern!(py, "revalidate_instances"), - )? - .as_ref() - .map(|s| s.to_str()) - .transpose()?, - )?, + let mut validator = Self { validator: Box::new(validator), class: class.into(), post_init: schema.get_as(intern!(py, "post_init"))?, @@ -101,7 +63,13 @@ impl BuildValidator for ModelValidator { // Get the class's `__name__`, not using `class.qualname()` name, } - .into()) + .into(); + + if let Some(config) = config { + validator = CombinedValidator::Config(ConfigValidator::try_new(config, Arc::new(validator))?); + } + + Ok(validator) } } @@ -125,7 +93,13 @@ impl Validator for ModelValidator { // but use from attributes to create a new instance of the model field type let class = self.class.bind(py); if let Some(py_input) = input_as_python_instance(input, class) { - if self.revalidate.should_revalidate(py_input, class) { + if state + .config + .revalidate_instances + .as_ref() + .unwrap_or(&crate::config::RevalidateInstances::Never) + .should_revalidate(py_input, class) + { let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?; if self.root_model { let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?; diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index b125cd316..f584a2bee 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -3,6 +3,7 @@ use pyo3::types::PyString; use jiter::StringCacheMode; +use crate::config::CoreConfig; use crate::recursion_guard::{ContainsRecursionState, RecursionState}; use crate::tools::new_py_string; @@ -25,6 +26,7 @@ pub struct ValidationState<'a, 'py> { pub fields_set_count: Option, // deliberately make Extra readonly extra: Extra<'a, 'py>, + pub config: CoreConfig, } impl<'a, 'py> ValidationState<'a, 'py> { @@ -34,6 +36,17 @@ impl<'a, 'py> ValidationState<'a, 'py> { exactness: None, fields_set_count: None, extra, + config: CoreConfig::default(), + } + } + + pub fn new_with_config(extra: Extra<'a, 'py>, recursion_guard: &'a mut RecursionState, config: CoreConfig) -> Self { + Self { + recursion_guard, + exactness: None, + fields_set_count: None, + extra, + config, } } @@ -54,7 +67,7 @@ impl<'a, 'py> ValidationState<'a, 'py> { } pub fn strict_or(&self, default: bool) -> bool { - self.extra.strict.unwrap_or(default) + self.extra.strict.unwrap_or(self.config.strict.unwrap_or(default)) } /// Sets the exactness to the lower of the current exactness diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index bdcfbb2e9..8b02041d9 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -630,11 +630,10 @@ def __init__(self): }, } ) - assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: Never' m = MyModel() m2 = v.validate_python(m) assert isinstance(m, MyModel) - assert m is m2 + assert m is m2 # revalidate_instances should default to `Never` assert m.field_a == 'init_a' # note that since dict validation was not run here, there has been no check this is an int assert m.field_b == 'init_b'