From e4acd1363524d42a70568b6b0842eaad661c05e8 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 26 Jan 2024 14:14:54 -0500 Subject: [PATCH] fix? --- .../is/hail/expr/ir/CompileAndEvaluate.scala | 2 +- .../src/main/scala/is/hail/expr/ir/Emit.scala | 33 +++++++++++-------- .../scala/is/hail/expr/ir/ForwardLets.scala | 20 ++++++----- .../main/scala/is/hail/expr/ir/IsPure.scala | 11 +++++++ 4 files changed, 42 insertions(+), 24 deletions(-) create mode 100644 hail/src/main/scala/is/hail/expr/ir/IsPure.scala diff --git a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala index f0f8b7efbfe3..56b253402b96 100644 --- a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala @@ -7,7 +7,7 @@ import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ -import is.hail.utils.{log, FastSeq} +import is.hail.utils.FastSeq import org.apache.spark.sql.Row diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index e9fd2e9e575f..caa55ec1524c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -819,13 +819,15 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr))) case let: Let => + println(Pretty.sexprStyle(let)) val newEnv = emitLetBindings( emitI = (ir, cb, env, r) => - if (ir.typ.isInstanceOf[TStream]) emitStream(ir, r, env = env).toI(cb) + if (ir.typ.isInstanceOf[TStream]) + EmitStream.produce(this, ir, cb, cb.emb, r, env, container) else emitI(ir, cb = cb, env = env, region = r), - emitVoid = (ir, cb, env, r) => emitVoid(ir, env = env, region = r) + emitVoid = (ir, cb, env, r) => emitVoid(ir, env = env, region = r, cb = cb), )(let, cb, env, region) - emitVoid(let.body, cb, env = newEnv) + emitVoid(let.body, env = newEnv) case StreamFor(a, valueName, body) => emitStream(a, region).toI(cb).consume( @@ -1004,7 +1006,8 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { ): IEmitCode = this.emitI(ir, cb, region, env, container, loopEnv) - def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region], env: EmitEnv = env): IEmitCode = + def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region], env: EmitEnv = env) + : IEmitCode = EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container) def emitVoid( @@ -1109,10 +1112,11 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { case let: Let => val newEnv = emitLetBindings( emitI = (ir, cb, env, r) => - if (ir.typ.isInstanceOf[TStream]) emitStream(ir, cb, region, env = env) + if (ir.typ.isInstanceOf[TStream]) // emitStream(ir, cb, region, env = env) + EmitStream.produce(this, ir, cb, cb.emb, r, env, container) else emitInNewBuilder(cb, ir, region = r, env = env), emitVoid = (ir, cb, env, r) => - emitVoid(ir, cb = cb, env = env, region = r) + emitVoid(ir, cb = cb, env = env, region = r), )(let, cb, env, region) emitI(let.body, env = newEnv) @@ -3646,12 +3650,13 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { case None => mutable.Set.empty } - def emitChunk(cb: EmitCodeBuilder, bindings: Seq[(String, IR)], r: Value[Region]): EmitEnv = + def emitChunk(cb: EmitCodeBuilder, bindings: Seq[(String, IR)], env: EmitEnv, r: Value[Region]): EmitEnv = bindings.foldLeft(env) { case (newEnv, (name, ir)) => - if (!uses.contains(name)) newEnv - else if (ir.typ == TVoid) { + if (ir.typ == TVoid) { emitVoid(ir, cb, newEnv, r) newEnv + } else if (IsPure(ir) && !uses.contains(name)) { + newEnv } else { val value = emitI(ir, cb, newEnv, r) val memo = cb.memoizeMaybeStreamValue(value, s"let_$name") @@ -3660,7 +3665,9 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { } if ( - !ctx.inLoopCriticalPath.contains(let) && let.bindings.forall(x => !ctx.inLoopCriticalPath.contains(x._2)) + !ctx.inLoopCriticalPath.contains(let) && + let.bindings.forall(x => !ctx.inLoopCriticalPath.contains(x._2)) +// false ) { var newEnv = env let.bindings.grouped(16).zipWithIndex.foreach { case (group, idx) => @@ -3669,14 +3676,12 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { FastSeq[ParamType](classInfo[Region]), UnitInfo, ) - mb.voidWithBuilder { cb => - newEnv = emitChunk(cb, group, mb.getCodeParam[Region](1)) - } + mb.voidWithBuilder(cb => newEnv = emitChunk(cb, group, newEnv, mb.getCodeParam[Region](1))) cb.invokeVoid(mb, cb.this_, r) } newEnv } else - emitChunk(cb, let.bindings, r) + emitChunk(cb, let.bindings, env, r) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala index d1c594b2f3b7..b6d22042ea50 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala @@ -15,15 +15,17 @@ object ForwardLets { def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = { def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: IR): Boolean = { - value.isInstanceOf[Ref] || - value.isInstanceOf[In] || - (IsConstant(value) && !value.isInstanceOf[Str]) || - refs.isEmpty || - (refs.size == 1 && - nestingDepth.lookup(refs.head) == nestingDepth.lookup(base) && - !ContainsScan(value) && - !ContainsAgg(value)) && - !ContainsAggIntermediate(value) + IsPure(value) && ( + value.isInstanceOf[Ref] || + value.isInstanceOf[In] || + (IsConstant(value) && !value.isInstanceOf[Str]) || + refs.isEmpty || + (refs.size == 1 && + nestingDepth.lookup(refs.head) == nestingDepth.lookup(base) && + !ContainsScan(value) && + !ContainsAgg(value)) && + !ContainsAggIntermediate(value) + ) } ir match { diff --git a/hail/src/main/scala/is/hail/expr/ir/IsPure.scala b/hail/src/main/scala/is/hail/expr/ir/IsPure.scala new file mode 100644 index 000000000000..b4295d2b9ef7 --- /dev/null +++ b/hail/src/main/scala/is/hail/expr/ir/IsPure.scala @@ -0,0 +1,11 @@ +package is.hail.expr.ir + +import is.hail.types.virtual.TVoid + +object IsPure { + def apply(x: IR): Boolean = x match { + case _ if x.typ == TVoid => false + case _: WritePartition | _: WriteValue => false + case _ => true + } +}