Skip to content

Commit

Permalink
Add test for disambiguation of parameters with similar types in facto…
Browse files Browse the repository at this point in the history
…ries, add progression test for missing disambiguation by Id annotation
  • Loading branch information
neko-kai committed Jan 24, 2025
1 parent c0dbc1e commit 1dfb73a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,15 @@ trait ReflectionProviderDefaultImpl extends ReflectionProvider {

val alreadyInSignature = factoryMethod.paramLists.flatten.map(symbol => brp.keyFromSymbol(MacroSymbolInfo.Runtime(symbol, tpe, wasGeneric = false)))
val resultTypeWiring = mkConstructorWiring(factoryMethod, resultType)
val requiredKeys = resultTypeWiring.requiredKeys

val excessiveTypes = alreadyInSignature.toSet -- resultTypeWiring.requiredKeys
if (excessiveTypes.nonEmpty) {
val excessiveKeys = alreadyInSignature.toSet[MacroDIKey] -- requiredKeys
if (excessiveKeys.nonEmpty) {
throw new UnsupportedDefinitionException(
s"""Augmentation failure.
| * Type $tpe has been considered a factory because of abstract method `${factoryMethodSymb.name}: ${factoryMethodSymb.typeSignatureInDefiningClass}` with result type `$resultType`
| * But method signature contains types not required by constructor of the result type: $excessiveTypes
| * Only the following types are required: ${resultTypeWiring.requiredKeys}
| * But method signature contains keys not required by constructor of the result type: $excessiveKeys
| * Only the following keys are required: $requiredKeys
| * This may happen in case you unintentionally bind an abstract type (trait, etc) as implementation type.""".stripMargin
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>

requireConcreteTypeConstructor(resultTpe, "FactoryConstructor")

val getFactoryProductType = {
val getFactoryProductType: List[TypeTree] => TypeRepr = {
(methodTypeArgs: List[TypeTree]) =>

val rettAppliedProperly = methodType match {
Expand Down Expand Up @@ -542,7 +542,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>

res.dealias.simplified
}
val factoryProductType = getFactoryProductType(Nil)
val factoryProductType: TypeRepr = getFactoryProductType(Nil)

val isTrait = symbolIsTraitOrAbstract(factoryProductType.typeSymbol)

Expand All @@ -555,7 +555,7 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>
ctxUntyped.assertIsWireableTrait(isInFactoryConstructor = true)
}

val factoryProductCtorParamLists = if (isTrait) {
val factoryProductCtorParamLists: ParamReprLists = if (isTrait) {
val byNameMethodArgs = ctxUntyped.methodDecls.map {
case MemberRepr(n, _, s, t, _) => ParamRepr(n, s, returnTypeOfMethodOrByName(t))
} // become byName later via ensureByName if they're InjectedDependencyParameter
Expand All @@ -565,13 +565,13 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>
}
assertSignatureIsAcceptableForFactory(factoryProductCtorParamLists.flatten, resultTpe, s"implementation constructor ${factoryProductType.show}")

val methodParams = extractMethodParamLists(methodType, mbMethodSym.getOrElse(Symbol.noSymbol)).flatten
val methodParams: List[ParamRepr] = extractMethodParamLists(methodType, mbMethodSym.getOrElse(Symbol.noSymbol)).flatten
assertSignatureIsAcceptableForFactory(methodParams, resultTpe, s"factory method $methodName")

val indexedMethodParams = methodParams.zipWithIndex
val methodParamIndex = indexedMethodParams.map { case (ParamRepr(n, _, t), idx) => (t, (n, idx)) }

val factoryProductParamss = factoryProductCtorParamLists.zipWithIndex.map {
val factoryProductParamss: List[List[FactoryProductParameter]] = factoryProductCtorParamLists.zipWithIndex.map {
case (params, paramListIdx) =>
params.map {
case ParamRepr(paramName, symbol, paramType) =>
Expand All @@ -589,14 +589,14 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>
InjectedDependencyParameter(ParamRepr(newName, symbol, ensureByName(paramType)), curIndex)

case multiple =>
val (_, (_, idx)) = multiple
.find { case (_, (n, _)) => n == paramName }
.getOrElse(
val idx = multiple
.collectFirst { case (_, (n, idx)) if n == paramName => idx }
.getOrElse {
report.errorAndAbort(
s"""Couldn't disambiguate between multiple arguments with the same type available for parameter $paramName: ${paramType.show} of ${factoryProductType.show} constructor
|Expected one of the arguments to be named `$paramName` or for the type to be unique among factory method arguments""".stripMargin
)
)
}
MethodParameter(idx)
}
}
Expand Down Expand Up @@ -626,10 +626,10 @@ class ConstructorUtil[Q <: Quotes](using val qctx: Q) { self =>
}

FactoryProductData(
getFactoryProductType,
factoryProductParamss.flatten.collect { case p: InjectedDependencyParameter => p.depByNameParamRepr },
hackySecretTraitImpl,
factoryProductParamss,
getFactoryProductType = getFactoryProductType,
byNameDependencies = factoryProductParamss.flatten.collect { case p: InjectedDependencyParameter => p.depByNameParamRepr },
hackyTraitImpl = hackySecretTraitImpl,
factoryProductParameterLists = factoryProductParamss,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ object FactoryCases {
final case class AssistedTestClass(b: Dependency, a: Int)
final case class NamedAssistedTestClass(@Id("special") b: Dependency, a: Int)
final case class GenericAssistedTestClass[T, S](a: List[T], b: List[S], c: Dependency)
final case class AmbiguousTestClass(a: Dependency @Id("special"), b: Dependency @Id("veryspecial"))

trait Factory {
def wiringTargetForDependency: Dependency
Expand Down Expand Up @@ -56,6 +57,16 @@ object FactoryCases {
def x[T, S](t: List[T], s: List[S]): GenericAssistedTestClass[T, S]
}

trait AmbiguousOnlyParamNamesFactory {
def x(a: Dependency @Id("special"), b: Dependency @Id("veryspecial")): AmbiguousTestClass
def y(b: Dependency @Id("veryspecial"), a: Dependency @Id("special")): AmbiguousTestClass
}

trait AmbiguousOnlyIdFactory {
def x(special: Dependency @Id("special"), veryspecial: Dependency @Id("veryspecial")): AmbiguousTestClass
def y(veryspecial: Dependency @Id("veryspecial"), special: Dependency @Id("special")): AmbiguousTestClass
}

trait AbstractDependency

case class AbstractDependencyImpl(a: Dependency) extends AbstractDependency
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards {
makeFactory[OverridingFactory]
makeFactory[AssistedFactory]
makeFactory[AbstractFactory]
makeFactory[AmbiguousOnlyParamNamesFactory]
})

val injector = mkNoCyclesInjector()
Expand Down Expand Up @@ -50,6 +51,28 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards {
val assistedFactory = context.get[AssistedFactory]
assert(assistedFactory.x(1).a == 1)
assert(assistedFactory.x(1).b.isInstanceOf[Dependency])

val ambiguousFactory = context.get[AmbiguousOnlyParamNamesFactory]
val d1 = ConcreteDep()
val d2 = ConcreteDep()
val amb1 = ambiguousFactory.x(d1, d2)
val amb2 = ambiguousFactory.y(d1, d2)
assert(d1 ne d2)
assert(amb1.a.eq(d1) && amb1.b.eq(d2))
assert(amb2.a.eq(d2) && amb2.b.eq(d1))
}

"progression test: should, but doesn't support disambiguation of factory parameters based on Id annotation, not param names" in {
def test(): Unit = assertCompiles("""
import FactoryCase1.AmbiguousOnlyIdFactory
val definition = PlannerInput.everything(new ModuleDef {
makeFactory[AmbiguousOnlyIdFactory]
})
""")

val exc = intercept[TestFailedException](test())
assert(exc.getMessage contains "Couldn't disambiguate between multiple arguments with the same type")
}

"handle generic arguments in factory methods" in {
Expand Down Expand Up @@ -282,7 +305,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards {
makeFactory[InvalidImplicitFactory]
}"""))
assert(
res.getMessage.contains("contains types not required by constructor of the result type") ||
res.getMessage.contains("contains keys not required by constructor of the result type") ||
res.getMessage.contains("has arguments which were not consumed by implementation constructor")
)
assert(res.getMessage.contains("UnrelatedTC["))
Expand Down Expand Up @@ -443,7 +466,7 @@ class FactoriesTest extends AnyWordSpec with MkInjector with ScalatestGuards {
"support polymorphic factory types" in {
import FactoryCase8.*

def definition[F[+_, +_]: TagKK] = PlannerInput.everything(new ModuleDef {
def definition[F[+_, +_]: TagKK]: PlannerInput = PlannerInput.everything(new ModuleDef {
makeFactory[XFactory[F]]
make[XContext[F]]
})
Expand Down

0 comments on commit 1dfb73a

Please sign in to comment.