Skip to content

Commit

Permalink
Unification benchmark using 'map' and 'stream' (#798)
Browse files Browse the repository at this point in the history
Resolves #796 

It's an OK (not great, not terrible) benchmark:
1. generates two nested, deep pair types like `(...((XLLL, XLLR), (XLRL,
XLRR)), ...` for `X = { A, B }`,
2. then unifies them,
3. and checks that the substitutions are like `ALLL -> BLLL`.

Configurable with a given `N` if some tweaking is needed.

### Perf:

#### My machine

On my computer with `N = 12`:
- JS compilation only: 4.5s
- LLVM compilation only: 5.5s
- then hyperfined:

```
$ hyperfine --warmup 5 --min-runs 20 './out/unify-js' './out/unify-llvm'
Benchmark 1: ./out/unify-js
  Time (mean ± σ):     703.0 ms ±   7.8 ms    [User: 1035.2 ms, System: 77.1 ms]
  Range (min … max):   690.9 ms … 721.1 ms    20 runs

Benchmark 2: ./out/unify-llvm
  Time (mean ± σ):      37.6 ms ±   0.5 ms    [User: 36.1 ms, System: 1.0 ms]
  Range (min … max):    37.1 ms …  39.7 ms    69 runs

Summary
  ./out/unify-llvm ran
   18.68 ± 0.32 times faster than ./out/unify-js
```

On my computer with `N = 16` (16x more work than `N = 12`)
- JS compilation only: 4.7s
- LLVM compilation only: 5.6s
- then hyperfined:

```
Benchmark 1: ./out/unify-js
  Time (mean ± σ):     19.982 s ±  0.608 s    [User: 26.011 s, System: 1.525 s]
  Range (min … max):   19.300 s … 21.267 s    20 runs

Benchmark 2: ./out/unify-llvm
  Time (mean ± σ):     928.2 ms ±  20.8 ms    [User: 906.3 ms, System: 16.1 ms]
  Range (min … max):   913.3 ms … 1006.2 ms    20 runs

  Warning: Statistical outliers were detected. Consider re-running this benchmark on a quiet system without any interferences from other programs. It might help to use the '--warmup' or '--prepare' options.

Summary
  ./out/unify-llvm ran
   21.53 ± 0.81 times faster than ./out/unify-js
```

#### CI

With `N = 12` in CI (measured imprecisely in two CI rounds):
- 9.2-12.0 seconds on Chez backends
- 8.0-8.9 seconds on JS
- 9.9-12.7 seconds on LLVM

With `N = 16` in CI (measured imprecisely in two CI rounds):
- 21-29 seconds on Chez backends
- 55-57 seconds on JS
- 79-86 seconds on LLVM

---------

Co-authored-by: Jonathan Brachthäuser <[email protected]>
  • Loading branch information
jiribenes and b-studios authored Jan 24, 2025
1 parent 2be5198 commit 94eb038
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ object Normalizer { normal =>
case None => false
}

private def isUnused(id: Id)(using ctx: Context): Boolean =
ctx.usage.get(id).forall { u => u == Usage.Never }

def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int, preserveBoxing: Boolean): ModuleDecl = {
// usage information is used to detect recursive functions (and not inline them)
val usage = Reachable(entrypoints, m)
Expand Down Expand Up @@ -160,6 +163,10 @@ object Normalizer { normal =>

def normalize(s: Stmt)(using C: Context): Stmt = s match {

// see #798 for context (led to stack overflow)
case Stmt.Def(id, block, body) if isUnused(id) =>
normalize(body)

case Stmt.Def(id, block, body) =>
val normalized = active(block).dealiased
Stmt.Def(id, normalized, normalize(body)(using C.bind(id, normalized)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ object TransformerCps extends Transformer {
js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body).stmts)
}

def argumentToJS(b: cps.Block)(using TransformerContext): js.Expr = b match {
case cps.BlockLit(vps, bps, ks, k, body) => toJS(b)(using nonrecursive(ks))
case other => toJS(b)
}

def toJS(handler: cps.Implementation)(using TransformerContext): js.Expr = handler match {
case cps.Implementation(interface, operations) =>
js.Object(operations.map {
Expand All @@ -191,7 +196,7 @@ object TransformerCps extends Transformer {
case Pure.Literal(s: String) => JsString(escape(s))
case literal: Pure.Literal => js.RawExpr(literal.value.toString)
case DirectApp(id, vargs, Nil) => inlineExtern(id, vargs)
case DirectApp(id, vargs, bargs) => js.Call(nameRef(id), vargs.map(toJS) ++ bargs.map(toJS))
case DirectApp(id, vargs, bargs) => js.Call(nameRef(id), vargs.map(toJS) ++ bargs.map(argumentToJS))
case Pure.PureApp(id, vargs) => inlineExtern(id, vargs)
case Pure.Make(data, tag, vargs) => js.New(nameRef(tag), vargs map toJS)
case Pure.Box(b) => toJS(b)
Expand Down Expand Up @@ -331,7 +336,7 @@ object TransformerCps extends Transformer {
stmts.append(js.Assign(nameRef(param), toJS(substitutions.substitute(arg)(using subst))))
}
(bparams zip bargs).foreach { (param, arg) =>
stmts.append(js.Assign(nameRef(param), toJS(substitutions.substitute(arg)(using subst))))
stmts.append(js.Assign(nameRef(param), argumentToJS(substitutions.substitute(arg)(using subst))))
}

// Restore metacont if needed
Expand All @@ -343,11 +348,11 @@ object TransformerCps extends Transformer {
}

case cps.Stmt.App(callee, vargs, bargs, ks, k) =>
pure(js.Return(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks),
pure(js.Return(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(argumentToJS) ++ List(toJS(ks),
requiringThunk { toJS(k) }))))

case cps.Stmt.Invoke(callee, method, vargs, bargs, ks, k) =>
val args = vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks), toJS(k))
val args = vargs.map(toJS) ++ bargs.map(argumentToJS) ++ List(toJS(ks), toJS(k))
pure(js.Return(MethodCall(toJS(callee), memberNameRef(method), args:_*)))

// const r = ks.arena.newRegion(); body
Expand Down
4 changes: 4 additions & 0 deletions examples/benchmarks/other/unify.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Unification successful!
a -> List(Int)
Unification successful!
4096
183 changes: 183 additions & 0 deletions examples/benchmarks/other/unify.effekt
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/// Robinson-style Unification Algorithm
module examples/benchmarks/unify

import examples/benchmarks/runner
import map
import result
import bytearray
import stream

type Type {
Var(name: String)
Con(name: String, args: List[Type])
}

type Substitution = Map[String, Type]

type UnificationError {
OccursCheckFailure(variable: String, tpe: Type)
UnificationFailure(tpe1: Type, tpe2: Type)
UnificationManyFailure(tps1: List[Type], tps2: List[Type])
}

// Check if a type variable occurs in another type
def occurs(variable: String, ty: Type): Bool = ty match {
case Var(name) => name == variable
case Con(_, args) => args.any { arg => variable.occurs(arg) }
}

// Apply a substitution to a type
def apply(subst: Substitution, ty: Type): Type = ty match {
case Var(name) =>
subst.getOrElse(name) { () => ty }
case Con(name, args) =>
Con(name, args.map { arg => subst.apply(arg) })
}

def unify(ty1: Type, ty2: Type, subst: Substitution): Substitution / Exception[UnificationError] = {
val substTy1 = subst.apply(ty1)
val substTy2 = subst.apply(ty2)

(substTy1, substTy2) match {
// If both are the same variable, return current substitution
case (Var(x), Var(y)) and x == y =>
subst

// If first is a variable, try to bind it
case (Var(x), _) =>
if (x.occurs(substTy2)) {
do raise(OccursCheckFailure(x, substTy2), "")
} else {
subst.put(x, substTy2)
}

// If second is a variable, try to bind it
case (_, Var(y)) =>
if (occurs(y, substTy1)) {
do raise(OccursCheckFailure(y, substTy1), "")
} else {
subst.put(y, substTy1)
}

// If both are constructors, unify their arguments
case (Con(name1, args1), Con(name2, args2)) =>
if (name1 != name2) {
do raise(UnificationFailure(substTy1, substTy2), "Different constructors!")
} else if (args1.size != args2.size) {
do raise(UnificationFailure(substTy1, substTy2), "Different number of arguments!")
} else {
unifyMany(args1, args2, subst)
}
}
}

// Unify a list of arguments with a current substitution
def unifyMany(args1: List[Type], args2: List[Type], subst: Substitution): Substitution / Exception[UnificationError] =
(args1, args2) match {
case (Nil(), Nil()) => subst
case (Cons(a1, rest1), Cons(a2, rest2)) =>
val newSubst = unify(a1, a2, subst)
unifyMany(rest1, rest2, newSubst)
case _ => do raise(UnificationManyFailure(args1, args2), "Different numbers of types on each side!")
}

def unify(ty1: Type, ty2: Type): Substitution / Exception[UnificationError] =
unify(ty1, ty2, map::empty(box bytearray::compareStringBytes))

def showType(ty: Type): String = ty match {
case Var(name) => name
case Con(name, Nil()) => name
case Con(name, args) =>
name ++ "(" ++ args.map { t => showType(t) }.join(", ") ++ ")"
}

def show(err: UnificationError): String = err match {
case OccursCheckFailure(variable, ty) =>
"Occurs check failed: " ++ variable ++ " occurs in " ++ showType(ty)
case UnificationFailure(ty1, ty2) =>
"Cannot unify " ++ showType(ty1) ++ " with " ++ showType(ty2)
case UnificationManyFailure(tps1, tps2) =>
"Cannot unify " ++ tps2.map { showType }.join(", ") ++ " with " ++ tps1.map { showType }.join(", ")
}

/// Worker wrapper
def reporting { body : => Substitution / Exception[UnificationError] }: Unit / emit[(String, Type)] = {
val res = result[Substitution, UnificationError] {
body()
}

res match {
case Success(subst) => {
println("Unification successful!")
subst.each
}
case Error(err, msg) =>
println("Unification failed: " ++ show(err))
if (msg.length > 0) {
println(msg)
}
}
}

/// Used for testing to generate two `depth`-deep, nested types of the shape:
/// ```
/// (Nested
/// (Nested
/// ...
/// XLLLLLLLL
/// XLLLLLLLR)
/// ```
/// for `baseVar = X`.
def generateDeepType(depth: Int, baseVar: String): Type = {
def recur(currentDepth: Int, varSuffix: String): Type =
if (currentDepth == 0) {
Var(baseVar ++ varSuffix)
} else {
Con("Nested", [
recur(currentDepth - 1, varSuffix ++ "L"),
recur(currentDepth - 1, varSuffix ++ "R")
])
}

recur(depth, "")
}

def run(N: Int) = {
def printBinding(pair: (String, Type)): Unit =
println(" " ++ pair.first ++ " -> " ++ showType(pair.second))

// sanity check
for {
reporting {
val intType = Con("Int", [])
val listType = Con("List", [intType])
val typeVar = Var("a")

unify(typeVar, listType)
}
} { printBinding }

// the actual test
var found = 0

for {
reporting {
val deepType1 = generateDeepType(N, "A")
val deepType2 = generateDeepType(N, "B")
unify(deepType1, deepType2)
}
} {
case (l, Var(r)) and l.substring(1) == r.substring(1) =>
found = found + 1
case (l, r) =>
println("error! " ++ l ++ " -> " ++ showType(r))
}

val expected = 2.0.pow(N).toInt
if (found != expected) {
panic("found: " ++ found.show ++ ", but expected: " ++ expected.show)
}
found
}

def main() = benchmark(12){run}

0 comments on commit 94eb038

Please sign in to comment.