Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Jan 26, 2024
1 parent 92fb910 commit 935c2a6
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import is.hail.expr.ir.lowering.LoweringPipeline
import is.hail.types.physical.PTuple
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
import is.hail.types.virtual._
import is.hail.utils.{log, FastSeq}
import is.hail.utils.FastSeq

import org.apache.spark.sql.Row

Expand Down
34 changes: 20 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,15 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr)))

case let: Let =>
println(Pretty.sexprStyle(let))
val newEnv = emitLetBindings(
emitI = (ir, cb, env, r) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, r, env = env).toI(cb)
if (ir.typ.isInstanceOf[TStream])
EmitStream.produce(this, ir, cb, cb.emb, r, env, container)
else emitI(ir, cb = cb, env = env, region = r),
emitVoid = (ir, cb, env, r) => emitVoid(ir, env = env, region = r)
emitVoid = (ir, cb, env, r) => emitVoid(ir, env = env, region = r, cb = cb),
)(let, cb, env, region)
emitVoid(let.body, cb, env = newEnv)
emitVoid(let.body, env = newEnv)

case StreamFor(a, valueName, body) =>
emitStream(a, region).toI(cb).consume(
Expand Down Expand Up @@ -1004,7 +1006,8 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
): IEmitCode =
this.emitI(ir, cb, region, env, container, loopEnv)

def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region], env: EmitEnv = env): IEmitCode =
def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region], env: EmitEnv = env)
: IEmitCode =
EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container)

def emitVoid(
Expand Down Expand Up @@ -1109,10 +1112,11 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
case let: Let =>
val newEnv = emitLetBindings(
emitI = (ir, cb, env, r) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, cb, region, env = env)
if (ir.typ.isInstanceOf[TStream]) // emitStream(ir, cb, region, env = env)
EmitStream.produce(this, ir, cb, cb.emb, r, env, container)
else emitInNewBuilder(cb, ir, region = r, env = env),
emitVoid = (ir, cb, env, r) =>
emitVoid(ir, cb = cb, env = env, region = r)
emitVoid(ir, cb = cb, env = env, region = r),
)(let, cb, env, region)
emitI(let.body, env = newEnv)

Expand Down Expand Up @@ -3646,12 +3650,14 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
case None => mutable.Set.empty
}

def emitChunk(cb: EmitCodeBuilder, bindings: Seq[(String, IR)], r: Value[Region]): EmitEnv =
def emitChunk(cb: EmitCodeBuilder, bindings: Seq[(String, IR)], env: EmitEnv, r: Value[Region])
: EmitEnv =
bindings.foldLeft(env) { case (newEnv, (name, ir)) =>
if (!uses.contains(name)) newEnv
else if (ir.typ == TVoid) {
if (ir.typ == TVoid) {
emitVoid(ir, cb, newEnv, r)
newEnv
} else if (IsPure(ir) && !uses.contains(name)) {
newEnv
} else {
val value = emitI(ir, cb, newEnv, r)
val memo = cb.memoizeMaybeStreamValue(value, s"let_$name")
Expand All @@ -3660,7 +3666,9 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
}

if (
!ctx.inLoopCriticalPath.contains(let) && let.bindings.forall(x => !ctx.inLoopCriticalPath.contains(x._2))
!ctx.inLoopCriticalPath.contains(let) &&
let.bindings.forall(x => !ctx.inLoopCriticalPath.contains(x._2))
// false
) {
var newEnv = env
let.bindings.grouped(16).zipWithIndex.foreach { case (group, idx) =>
Expand All @@ -3669,14 +3677,12 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
FastSeq[ParamType](classInfo[Region]),
UnitInfo,
)
mb.voidWithBuilder { cb =>
newEnv = emitChunk(cb, group, mb.getCodeParam[Region](1))
}
mb.voidWithBuilder(cb => newEnv = emitChunk(cb, group, newEnv, mb.getCodeParam[Region](1)))
cb.invokeVoid(mb, cb.this_, r)
}
newEnv
} else
emitChunk(cb, let.bindings, r)
emitChunk(cb, let.bindings, env, r)
}
}

Expand Down
20 changes: 11 additions & 9 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ object ForwardLets {
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: IR): Boolean = {
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
(IsConstant(value) && !value.isInstanceOf[Str]) ||
refs.isEmpty ||
(refs.size == 1 &&
nestingDepth.lookup(refs.head) == nestingDepth.lookup(base) &&
!ContainsScan(value) &&
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
(IsConstant(value) && !value.isInstanceOf[Str]) ||
refs.isEmpty ||
(refs.size == 1 &&
nestingDepth.lookup(refs.head) == nestingDepth.lookup(base) &&
!ContainsScan(value) &&
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
)
}

ir match {
Expand Down
3 changes: 1 addition & 2 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -919,12 +919,11 @@ final case class RunAggScan(
) extends IR

object Begin {
def apply(xs: IndexedSeq[IR]): IR = {
def apply(xs: IndexedSeq[IR]): IR =
if (xs.isEmpty)
Void()
else
Let(xs.init.map(x => ("__void", x)), xs.last)
}
}

final case class Begin(xs: IndexedSeq[IR]) extends IR
Expand Down
11 changes: 11 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/IsPure.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package is.hail.expr.ir

import is.hail.types.virtual.TVoid

object IsPure {
def apply(x: IR): Boolean = x match {
case _ if x.typ == TVoid => false
case _: WritePartition | _: WriteValue => false
case _ => true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ object EmitStream {
case let: Let =>
val newEnv = emitter.emitLetBindings(
emitI = (ir, cb, env, r) => emit(ir, cb, region = r, env = env),
emitVoid = (ir, cb, env, r) => emitVoid(ir, cb, region = r, env = env)
emitVoid = (ir, cb, env, r) => emitVoid(ir, cb, region = r, env = env),
)(let, cb, env, outerRegion)
produce(let.body, cb, env = newEnv)

Expand Down

0 comments on commit 935c2a6

Please sign in to comment.