From 1dfb73a914497b87c6f57430271b96cb287852f2 Mon Sep 17 00:00:00 2001 From: Kai <450507+neko-kai@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:35:23 +0000 Subject: [PATCH] Add test for disambiguation of parameters with similar types in factories, add progression test for missing disambiguation by Id annotation --- .../ReflectionProviderDefaultImpl.scala | 9 ++++--- .../constructors/ConstructorUtil.scala | 26 +++++++++--------- .../izumi/distage/fixtures/FactoryCases.scala | 11 ++++++++ .../distage/injector/FactoriesTest.scala | 27 +++++++++++++++++-- 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/distage/distage-core-api/src/main/scala-2/izumi/distage/reflection/macros/universe/ReflectionProviderDefaultImpl.scala b/distage/distage-core-api/src/main/scala-2/izumi/distage/reflection/macros/universe/ReflectionProviderDefaultImpl.scala index f0822c850d..ce7a492847 100644 --- a/distage/distage-core-api/src/main/scala-2/izumi/distage/reflection/macros/universe/ReflectionProviderDefaultImpl.scala +++ b/distage/distage-core-api/src/main/scala-2/izumi/distage/reflection/macros/universe/ReflectionProviderDefaultImpl.scala @@ -138,14 +138,15 @@ trait ReflectionProviderDefaultImpl extends ReflectionProvider { val alreadyInSignature = factoryMethod.paramLists.flatten.map(symbol => brp.keyFromSymbol(MacroSymbolInfo.Runtime(symbol, tpe, wasGeneric = false))) val resultTypeWiring = mkConstructorWiring(factoryMethod, resultType) + val requiredKeys = resultTypeWiring.requiredKeys - val excessiveTypes = alreadyInSignature.toSet -- resultTypeWiring.requiredKeys - if (excessiveTypes.nonEmpty) { + val excessiveKeys = alreadyInSignature.toSet[MacroDIKey] -- requiredKeys + if (excessiveKeys.nonEmpty) { throw new UnsupportedDefinitionException( s"""Augmentation failure. | * Type $tpe has been considered a factory because of abstract method `${factoryMethodSymb.name}: ${factoryMethodSymb.typeSignatureInDefiningClass}` with result type `$resultType` - | * But method signature contains types not required by constructor of the result type: $excessiveTypes - | * Only the following types are required: ${resultTypeWiring.requiredKeys} + | * But method signature contains keys not required by constructor of the result type: $excessiveKeys + | * Only the following keys are required: $requiredKeys | * This may happen in case you unintentionally bind an abstract type (trait, etc) as implementation type.""".stripMargin ) } diff --git a/distage/distage-core-api/src/main/scala-3/izumi/distage/constructors/ConstructorUtil.scala b/distage/distage-core-api/src/main/scala-3/izumi/distage/constructors/ConstructorUtil.scala index 628204b48e..31373c2476 100644 --- a/distage/distage-core-api/src/main/scala-3/izumi/distage/constructors/ConstructorUtil.scala +++ b/distage/distage-core-api/src/main/scala-3/izumi/distage/constructors/ConstructorUtil.scala @@ -510,7 +510,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => requireConcreteTypeConstructor(resultTpe, "FactoryConstructor") - val getFactoryProductType = { + val getFactoryProductType: List[TypeTree] => TypeRepr = { (methodTypeArgs: List[TypeTree]) => val rettAppliedProperly = methodType match { @@ -542,7 +542,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => res.dealias.simplified } - val factoryProductType = getFactoryProductType(Nil) + val factoryProductType: TypeRepr = getFactoryProductType(Nil) val isTrait = symbolIsTraitOrAbstract(factoryProductType.typeSymbol) @@ -555,7 +555,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => ctxUntyped.assertIsWireableTrait(isInFactoryConstructor = true) } - val factoryProductCtorParamLists = if (isTrait) { + val factoryProductCtorParamLists: ParamReprLists = if (isTrait) { val byNameMethodArgs = ctxUntyped.methodDecls.map { case MemberRepr(n, _, s, t, _) => ParamRepr(n, s, returnTypeOfMethodOrByName(t)) } // become byName later via ensureByName if they're InjectedDependencyParameter @@ -565,13 +565,13 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => } assertSignatureIsAcceptableForFactory(factoryProductCtorParamLists.flatten, resultTpe, s"implementation constructor ${factoryProductType.show}") - val methodParams = extractMethodParamLists(methodType, mbMethodSym.getOrElse(Symbol.noSymbol)).flatten + val methodParams: List[ParamRepr] = extractMethodParamLists(methodType, mbMethodSym.getOrElse(Symbol.noSymbol)).flatten assertSignatureIsAcceptableForFactory(methodParams, resultTpe, s"factory method $methodName") val indexedMethodParams = methodParams.zipWithIndex val methodParamIndex = indexedMethodParams.map { case (ParamRepr(n, _, t), idx) => (t, (n, idx)) } - val factoryProductParamss = factoryProductCtorParamLists.zipWithIndex.map { + val factoryProductParamss: List[List[FactoryProductParameter]] = factoryProductCtorParamLists.zipWithIndex.map { case (params, paramListIdx) => params.map { case ParamRepr(paramName, symbol, paramType) => @@ -589,14 +589,14 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => InjectedDependencyParameter(ParamRepr(newName, symbol, ensureByName(paramType)), curIndex) case multiple => - val (_, (_, idx)) = multiple - .find { case (_, (n, _)) => n == paramName } - .getOrElse( + val idx = multiple + .collectFirst { case (_, (n, idx)) if n == paramName => idx } + .getOrElse { report.errorAndAbort( s"""Couldn't disambiguate between multiple arguments with the same type available for parameter $paramName: ${paramType.show} of ${factoryProductType.show} constructor |Expected one of the arguments to be named `$paramName` or for the type to be unique among factory method arguments""".stripMargin ) - ) + } MethodParameter(idx) } } @@ -626,10 +626,10 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self => } FactoryProductData( - getFactoryProductType, - factoryProductParamss.flatten.collect { case p: InjectedDependencyParameter => p.depByNameParamRepr }, - hackySecretTraitImpl, - factoryProductParamss, + getFactoryProductType = getFactoryProductType, + byNameDependencies = factoryProductParamss.flatten.collect { case p: InjectedDependencyParameter => p.depByNameParamRepr }, + hackyTraitImpl = hackySecretTraitImpl, + factoryProductParameterLists = factoryProductParamss, ) } diff --git a/distage/distage-core/src/test/scala/izumi/distage/fixtures/FactoryCases.scala b/distage/distage-core/src/test/scala/izumi/distage/fixtures/FactoryCases.scala index fb47ac853c..728c5adecb 100644 --- a/distage/distage-core/src/test/scala/izumi/distage/fixtures/FactoryCases.scala +++ b/distage/distage-core/src/test/scala/izumi/distage/fixtures/FactoryCases.scala @@ -27,6 +27,7 @@ object FactoryCases { final case class AssistedTestClass(b: Dependency, a: Int) final case class NamedAssistedTestClass(@Id("special") b: Dependency, a: Int) final case class GenericAssistedTestClass[T, S](a: List[T], b: List[S], c: Dependency) + final case class AmbiguousTestClass(a: Dependency @Id("special"), b: Dependency @Id("veryspecial")) trait Factory { def wiringTargetForDependency: Dependency @@ -56,6 +57,16 @@ object FactoryCases { def x[T, S](t: List[T], s: List[S]): GenericAssistedTestClass[T, S] } + trait AmbiguousOnlyParamNamesFactory { + def x(a: Dependency @Id("special"), b: Dependency @Id("veryspecial")): AmbiguousTestClass + def y(b: Dependency @Id("veryspecial"), a: Dependency @Id("special")): AmbiguousTestClass + } + + trait AmbiguousOnlyIdFactory { + def x(special: Dependency @Id("special"), veryspecial: Dependency @Id("veryspecial")): AmbiguousTestClass + def y(veryspecial: Dependency @Id("veryspecial"), special: Dependency @Id("special")): AmbiguousTestClass + } + trait AbstractDependency case class AbstractDependencyImpl(a: Dependency) extends AbstractDependency diff --git a/distage/distage-core/src/test/scala/izumi/distage/injector/FactoriesTest.scala b/distage/distage-core/src/test/scala/izumi/distage/injector/FactoriesTest.scala index a9bfd8cc1d..0a221b802b 100644 --- a/distage/distage-core/src/test/scala/izumi/distage/injector/FactoriesTest.scala +++ b/distage/distage-core/src/test/scala/izumi/distage/injector/FactoriesTest.scala @@ -23,6 +23,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards { makeFactory[OverridingFactory] makeFactory[AssistedFactory] makeFactory[AbstractFactory] + makeFactory[AmbiguousOnlyParamNamesFactory] }) val injector = mkNoCyclesInjector() @@ -50,6 +51,28 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards { val assistedFactory = context.get[AssistedFactory] assert(assistedFactory.x(1).a == 1) assert(assistedFactory.x(1).b.isInstanceOf[Dependency]) + + val ambiguousFactory = context.get[AmbiguousOnlyParamNamesFactory] + val d1 = ConcreteDep() + val d2 = ConcreteDep() + val amb1 = ambiguousFactory.x(d1, d2) + val amb2 = ambiguousFactory.y(d1, d2) + assert(d1 ne d2) + assert(amb1.a.eq(d1) && amb1.b.eq(d2)) + assert(amb2.a.eq(d2) && amb2.b.eq(d1)) + } + + "progression test: should, but doesn't support disambiguation of factory parameters based on Id annotation, not param names" in { + def test(): Unit = assertCompiles(""" + import FactoryCase1.AmbiguousOnlyIdFactory + + val definition = PlannerInput.everything(new ModuleDef { + makeFactory[AmbiguousOnlyIdFactory] + }) + """) + + val exc = intercept[TestFailedException](test()) + assert(exc.getMessage contains "Couldn't disambiguate between multiple arguments with the same type") } "handle generic arguments in factory methods" in { @@ -282,7 +305,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards { makeFactory[InvalidImplicitFactory] }""")) assert( - res.getMessage.contains("contains types not required by constructor of the result type") || + res.getMessage.contains("contains keys not required by constructor of the result type") || res.getMessage.contains("has arguments which were not consumed by implementation constructor") ) assert(res.getMessage.contains("UnrelatedTC[")) @@ -443,7 +466,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards { "support polymorphic factory types" in { import FactoryCase8.* - def definition[F[+_, +_]: TagKK] = PlannerInput.everything(new ModuleDef { + def definition[F[+_, +_]: TagKK]: PlannerInput = PlannerInput.everything(new ModuleDef { makeFactory[XFactory[F]] make[XContext[F]] })