From 08fad2a9218e0a6938dd8cd89f5052b55ec16432 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:44:54 -0500 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=90=9B=20fix:=20not=20resolving=20ali?= =?UTF-8?q?ases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 10 ++++++++- tests/test_fn_exec.py | 1 + tests/test_services.py | 51 ++++++++++++++++++++++++++++++++---------- 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 7f27091..0304cb9 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -509,7 +509,7 @@ def _get_resolvers_for_parameters( # but at least Optional could be supported in the future raise UnsupportedUnionTypeException(param_name, concrete_type) - if param_type is _empty: + if param_type is _empty or param_type not in services._map: if services.strict: raise CannotResolveParameterException(param_name, concrete_type) @@ -521,6 +521,14 @@ def _get_resolvers_for_parameters( else: aliases = services._aliases[param_name] + if not aliases: + cls_name = class_name(param_type) + aliases = ( + services._aliases[cls_name] + or services._aliases[cls_name.lower()] + or services._aliases[to_standard_param_name(cls_name)] + ) + if aliases: assert ( len(aliases) == 1 diff --git a/tests/test_fn_exec.py b/tests/test_fn_exec.py index 541e800..91ff62e 100644 --- a/tests/test_fn_exec.py +++ b/tests/test_fn_exec.py @@ -2,6 +2,7 @@ Functions exec tests. exec functions are designed to enable executing any function injecting parameters. """ + import pytest from rodi import Container, inject diff --git a/tests/test_services.py b/tests/test_services.py index 1015049..cdfd71a 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -687,6 +687,33 @@ def __init__(self, cats_controller, service_settings): assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) +def test_alias_dep_resolving(): + container = arrange_cats_example() + + class BaseClass: + pass + + class DerivedClass(BaseClass): + pass + + class UsingAliasByType: + def __init__(self, example: BaseClass): + self.example = example + + def resolve_derived_class(_) -> DerivedClass: + return DerivedClass() + + container.add_scoped_by_factory(resolve_derived_class, DerivedClass) + container.add_alias("BaseClass", DerivedClass) + container.add_scoped(UsingAliasByType) + + provider = container.build_provider() + u = provider.get(UsingAliasByType) + + assert isinstance(u, UsingAliasByType) + assert isinstance(u.example, DerivedClass) + + def test_get_service_by_name_or_alias(): container = arrange_cats_example() container.add_alias("k", CatsController) @@ -2323,7 +2350,7 @@ def factory() -> annotation: def test_factory_without_locals_raises(): def factory_without_context() -> None: - ... + pass with pytest.raises(FactoryMissingContextException): _get_factory_annotations_or_throw(factory_without_context) @@ -2332,7 +2359,7 @@ def factory_without_context() -> None: def test_factory_with_locals_get_annotations(): @inject() def factory_without_context() -> "Cat": - ... + pass annotations = _get_factory_annotations_or_throw(factory_without_context) @@ -2350,20 +2377,20 @@ def test_deps_github_scenario(): """ class HTTPClient: - ... + pass class CommentsService: - ... + pass class ChecksService: - ... + pass class CLAHandler: comments_service: CommentsService checks_service: ChecksService class GitHubSettings: - ... + pass class GitHubAuthHandler: settings: GitHubSettings @@ -2478,7 +2505,7 @@ class B: def test_provide_protocol_with_attribute_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2506,7 +2533,7 @@ def foo(self) -> Any: def test_provide_protocol_with_init_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2536,10 +2563,10 @@ def test_provide_protocol_generic() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Impl(P[A]): def foo(self, t: A) -> A: @@ -2562,10 +2589,10 @@ def test_provide_protocol_generic_with_inner_dependency() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Dependency: pass From 18987269bf9ec7528d66672843af52583fd11702 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:58:23 -0500 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=90=9B=20fix:=20handle=20resolving=20?= =?UTF-8?q?alias=20directly?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 7 +++++++ tests/test_services.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/rodi/__init__.py b/rodi/__init__.py index 0304cb9..b240411 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -744,6 +744,13 @@ def get( scope = ActivationScope(self) resolver = self._map.get(desired_type) + if not resolver: + cls_name = class_name(desired_type) + resolver = ( + self._map.get(cls_name) + or self._map.get(cls_name.lower()) + or self._map.get(to_standard_param_name(cls_name)) + ) scoped_service = scope.scoped_services.get(desired_type) if scope else None if not resolver and not scoped_service: diff --git a/tests/test_services.py b/tests/test_services.py index cdfd71a..942b5b2 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -713,6 +713,9 @@ def resolve_derived_class(_) -> DerivedClass: assert isinstance(u, UsingAliasByType) assert isinstance(u.example, DerivedClass) + b = provider.get(BaseClass) + assert isinstance(b, DerivedClass) + def test_get_service_by_name_or_alias(): container = arrange_cats_example()