Skip to content

Commit

Permalink
Support typeclass method with multiple arguments (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
iRevive authored Nov 13, 2024
1 parent 94cf1f2 commit 853deb0
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,132 @@ object UnionDerivation {
import quotes.reflect.*

def deriveImpl[A: Type]: Expr[F[A]] = {
given Diagnostic = Diagnostic(TypeRepr.of[F], TypeRepr.of[A].dealias)

val tpe: TypeRepr = TypeRepr.of[A]

tpe.dealias match {
case o: OrType =>
val paramType = findParamType
val abstractMethod = findAbstractMethod
val collectedTypes = collectTypes(o)
val mt = MethodType(List("a"))(_ => List(tpe), _ => abstractMethod.returnTpt.tpe)
val params = collectParams(abstractMethod, paramType.tpe)

val lambdaType = MethodType(params.map(_.name))(
_ => params.map(p => if (p.isPoly) tpe else p.typeRepr),
_ => abstractMethod.returnTpt.tpe
)

val lambda =
Lambda(Symbol.spliceOwner, mt, (meth, arg) => body(arg.head.asExprOf[A], collectedTypes, abstractMethod.name))
val lambda = Lambda(
Symbol.spliceOwner,
lambdaType,
(_, args) => body(collectedTypes, params, args, abstractMethod.name)
)

// transform lambda to an instance of the typeclass
val instanceTree = lambda match {
case Block(body, Closure(meth, _)) =>
Block(body, Closure(meth, Some(TypeRepr.of[F].appliedTo(tpe))))
case Block(body, Closure(method, _)) =>
Block(body, Closure(method, Some(TypeRepr.of[F].appliedTo(tpe))))
}

instanceTree.asExprOf[F[A]]

case other =>
report.errorAndAbort(s"Cannot derive a typeclass for the ${tpe.show}. Only Union type is supported")
errorAndAbort("only Union type is supported.")
}
}

private final case class Diagnostic(typeclass: TypeRepr, targetType: TypeRepr)

private final case class MethodParam(
name: String,
typeRepr: TypeRepr,
isPoly: Boolean // whether param appear in the polymorphic position, e.g. (a: A)
)

private def collectParams(method: DefDef, paramType: TypeRepr)(using Diagnostic): List[MethodParam] =
method.paramss match {
case TermParamClause(params) :: Nil =>
val all = params.map { param =>
MethodParam(param.name, param.tpt.tpe, param.tpt.tpe == paramType)
}

val typed = all.filter(_.isPoly)

if (typed.size == 1) {
all
} else if (typed.isEmpty) {
errorAndAbort(
"the abstract method without the polymorphic param isn't supported.",
Some(
"""check the example below where the instance cannot be derived
|
|trait Typeclass[A] {
| def magic(a: Int): String
| // ^
| // Polymorphic param of type A is missing
|}""".stripMargin
)
)
} else {
errorAndAbort(
s"the abstract method has multiple polymorphic params of the same parametrized type: ${typed.map(_.name).mkString(", ")}.",
Some("""check the example below where the instance cannot be derived
|
|trait Typeclass[A] {
| def magic(a1: A, b: Int, a2: A): String
| // ^ ^
| // Polymorphic type A appears in two positions
|}""".stripMargin)
)
}

case Nil =>
errorAndAbort(
"the abstract method without the polymorphic param isn't supported.",
Some(
"""check the example below where the instance cannot be derived
|
|trait Typeclass[A] {
| def magic: String
| // ^
| // Polymorphic param of type A is missing
|}""".stripMargin
)
)

case _ =>
errorAndAbort(
"the curried abstract method isn't supported.",
Some(
"""check the example below where the instance cannot be derived
|
|trait Typeclass[A] {
| def magic(a: A)(b: Int): String
| // ^
| // Curried functions aren't supported
|}""".stripMargin
)
)
}

// required exactly one type param
private def findParamType(using Diagnostic): TypeTree =
TypeRepr.of[F].typeSymbol.declaredTypes match {
case head :: Nil =>
TypeIdent(head)

case Nil =>
errorAndAbort("The typeclass doesn't have a type parameter")

case _ =>
errorAndAbort("The typeclass has multiple type parameters")
}

/**
* Looks-up for an abstract method in F[_]
*/
private def findAbstractMethod: DefDef = {
private def findAbstractMethod(using Diagnostic): DefDef = {
val tcl: TypeRepr = TypeRepr.of[F]

val methods = tcl.typeSymbol.declaredMethods.filter(_.isDefDef).map(_.tree).collect {
Expand All @@ -51,16 +149,17 @@ object UnionDerivation {

methods match {
case Nil =>
report.errorAndAbort(
s"""Cannot detect an abstract method in ${tcl.typeSymbol}. `scalacOptions += "-Yretain-trees"` may solve the issue"""
errorAndAbort(
"cannot detect an abstract method in the typeclass.",
Some("""`scalacOptions += "-Yretain-trees"` may solve the issue.""")
)

case head :: Nil =>
head

case other =>
report.errorAndAbort(
s"More than one abstract method detected in ${tcl.typeSymbol}: ${other.map(_.name).mkString(", ")}. Automatic derivation is impossible"
errorAndAbort(
s"more than one abstract method is detected: ${other.map(_.name).mkString(", ")}."
)
}
}
Expand All @@ -70,33 +169,50 @@ object UnionDerivation {
*
* The
* {{{
* if (value.isInstanceOf[Int]) summon[Show[Int]].show(value)
* else if (value.isInstanceOf[String]) summon[Show[String]].show(value)
* if (value.isInstanceOf[Int]) summon[Typeclass[Int]].magic(value)
* else if (value.isInstanceOf[String]) summon[Typeclass[String]].magic(value)
* else sys.error("Impossible") // impossible state
* }}}
*
* @param t
* the input value of the method
* @param knownTypes
* the known member types of the union
* @param params
* the list of function parameter
* @param lambdaArgs
* the list of lambda args
* @param method
* the name of the typeclass method to apply
* @tparam A
* the input type
* @tparam R
* the output type of the method
* @return
*/
private def body[A](t: Expr[A], knownTypes: List[TypeRepr], method: String): Term = {
val selector: Term = t.asTerm
private def body[A: Type](
knownTypes: List[TypeRepr],
params: List[MethodParam],
lambdaArgs: List[Tree],
method: String
)(using Diagnostic): Term = {

val selector: Term = params
.zip(lambdaArgs)
.collectFirst { case (param, arg) if param.isPoly => arg }
.getOrElse(errorAndAbort("cannot find poly param in the list of lambda arguments."))
.asExprOf[A]
.asTerm

val ifBranches: List[(Term, Term)] = knownTypes.map { tpe =>
val identifier = TypeIdent(tpe.typeSymbol)
val condition = TypeApply(Select.unique(selector, "isInstanceOf"), identifier :: Nil)
val tcl = lookupImplicit(tpe)
val castedValue = Select.unique(selector, "asInstanceOf").appliedToType(tpe)
val identifier = TypeIdent(tpe.typeSymbol)
val condition = TypeApply(Select.unique(selector, "isInstanceOf"), identifier :: Nil)
val tcl = lookupImplicit(tpe)

val action: Term = Apply(Select.unique(tcl, method), castedValue :: Nil)
val args: List[Term] = params.zip(lambdaArgs).map {
case (param, arg) if param.isPoly =>
Select.unique(selector, "asInstanceOf").appliedToType(tpe)

case (_, arg) =>
arg.asExpr.asTerm
}

val action: Term = Select.unique(tcl, method).appliedToArgs(args)

(condition, action)
}
Expand All @@ -122,12 +238,12 @@ object UnionDerivation {
/**
* Looks-up for an instance of `F[A]` for the provided type
*/
private def lookupImplicit(t: TypeRepr): Term = {
private def lookupImplicit(t: TypeRepr)(using Diagnostic): Term = {
val typeclassTpe = TypeRepr.of[F]
val tclTpe = typeclassTpe.appliedTo(t)
Implicits.search(tclTpe) match {
case success: ImplicitSearchSuccess => success.tree
case failure: ImplicitSearchFailure => report.errorAndAbort(failure.explanation)
case failure: ImplicitSearchFailure => errorAndAbort(failure.explanation)
}
}

Expand All @@ -141,5 +257,11 @@ object UnionDerivation {
case Nil =>
('{ throw RuntimeException("Unhandled condition encountered during derivation") }).asTerm
}

private def errorAndAbort(reason: String, hint: Option[String] = None)(using d: Diagnostic): Nothing =
report.errorAndAbort(
s"""UnionDerivation cannot derive an instance of ${d.typeclass.typeSymbol} for the type `${d.targetType.show}`.
|Reason: $reason""".stripMargin + hint.map(fix => s"\nHint: $fix").getOrElse("") + "\n\n"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ class ShowDerivationSuite extends munit.FunSuite {
test("fail derivation for a non-union type") {
val expected =
"""
|error: Cannot derive a typeclass for the scala.Int. Only Union type is supported
|error:
|UnionDerivation cannot derive an instance of trait Show for the type `scala.Int`.
|Reason: only Union type is supported.
|
| assertNoDiff(compileErrors("Show.deriveUnion[Int]"), expected)
| ^
|
Expand All @@ -36,7 +39,10 @@ class ShowDerivationSuite extends munit.FunSuite {
test("fail derivation if an instance of a typeclass is missing for a member type") {
val expected =
"""
|error: no implicit values were found that match type io.github.irevive.union.derivation.ShowDerivationSuite.Show[Double]
|error:
|UnionDerivation cannot derive an instance of trait Show for the type `scala.Int | scala.Predef.String | scala.Double`.
|Reason: no implicit values were found that match type io.github.irevive.union.derivation.ShowDerivationSuite.Show[Double]
|
| assertNoDiff(compileErrors("Show.deriveUnion[Int | String | Double]"), expected)
| ^
|""".stripMargin
Expand Down
Loading

0 comments on commit 853deb0

Please sign in to comment.