From bdb9b477c1838ee9a5475e2531bbee752caad7f5 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Thu, 19 Dec 2024 14:47:19 -0800 Subject: [PATCH 1/7] support Zhlt::BackOp in picus mode --- zirgen/compiler/picus/picus.cpp | 13 ++++++++++++- zirgen/compiler/picus/test/back.zir | 13 +++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 zirgen/compiler/picus/test/back.zir diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index 8b1e8887..87ce3250 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -139,7 +139,8 @@ class PicusPrinter { PackOp, ReturnOp, GetGlobalLayoutOp, - AliasLayoutOp>([&](auto op) { visitOp(op); }) + AliasLayoutOp, + zirgen::Zhlt::BackOp>([&](auto op) { visitOp(op); }) .Case([](auto) { /* no-op */ }) .Default([](Operation* op) { llvm::errs() << "unhandled op: " << *op << "\n"; }); } @@ -288,6 +289,16 @@ class PicusPrinter { } } + void visitOp(zirgen::Zhlt::BackOp back) { + auto callee = back.getCallee(); + size_t distance = back.getDistance().getZExtValue(); + AnySignal signal = signalize(freshName(), back.getType()); + if (distance > 0) { + declareSignals(signal, /*isInput=*/true); + } + valuesToSignals.insert({back.getOut(), signal}); + } + // Constructs a fresh signal structure corresponding to the given type AnySignal signalize(std::string prefix, Type type) { if (isa(type) || isa(type)) { diff --git a/zirgen/compiler/picus/test/back.zir b/zirgen/compiler/picus/test/back.zir new file mode 100644 index 00000000..c5aa0922 --- /dev/null +++ b/zirgen/compiler/picus/test/back.zir @@ -0,0 +1,13 @@ +// RUN: zirgen %s --emit=picus | FileCheck %s + +// CHECK: (prime-number 2013265921) +// CHECK: (begin-module Count) +// CHECK: (input x0) +// CHECK: (end-module) + +#[picus] +component Count(first: Val) { + public a : Reg; + a := Reg(a@1); +} + From 0e138d6e8544504a7ac47e9eff4b39a07348bf84 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Thu, 9 Jan 2025 13:41:53 -0800 Subject: [PATCH 2/7] support signals with `assume-deterministic` operator --- zirgen/compiler/picus/picus.cpp | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index 87ce3250..d28daa7e 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -37,6 +37,12 @@ using SignalArray = ArrayAttr; using SignalStruct = DictionaryAttr; using AnySignal = Attribute; +enum class SignalType { + Input, + Output, + AssumeDeterministic, +}; + template void visit(AnySignal signal, F f) { if (!signal) { // no-op @@ -95,7 +101,7 @@ class PicusPrinter { if (isa(param.getType()) || isa(param.getType())) continue; AnySignal signal = signalize(freshName(), param.getType()); - declareSignals(signal, /*isInput=*/true); + declareSignals(signal, SignalType::Input); valuesToSignals.insert({param, signal}); workQueue.push(lookupConstructor(param.getType())); } @@ -103,13 +109,13 @@ class PicusPrinter { // The layout is an output if (auto layout = component.getLayout()) { AnySignal layoutSignal = signalize("layout", layout.getType()); - declareSignals(layoutSignal, /*isInput=*/false); + declareSignals(layoutSignal, SignalType::Output); valuesToSignals.insert({layout, layoutSignal}); } // The result is an output AnySignal result = signalize("result", component.getOutType()); - declareSignals(result, /*isInput=*/false); + declareSignals(result, SignalType::Output); valuesToSignals.insert({Value(), result}); for (Operation& op : component.getBody().front()) { @@ -277,7 +283,7 @@ class PicusPrinter { void visitOp(GetGlobalLayoutOp get) { // This is sound but presumably not complete? AnySignal signal = signalize(freshName(), get.getType()); - declareSignals(signal, /*isInput=*/false); + declareSignals(signal, SignalType::Output); valuesToSignals.insert({get.getOut(), signal}); } @@ -294,7 +300,7 @@ class PicusPrinter { size_t distance = back.getDistance().getZExtValue(); AnySignal signal = signalize(freshName(), back.getType()); if (distance > 0) { - declareSignals(signal, /*isInput=*/true); + declareSignals(signal, SignalType::AssumeDeterministic); } valuesToSignals.insert({back.getOut(), signal}); } @@ -341,12 +347,18 @@ class PicusPrinter { return flattened; } - void declareSignals(AnySignal signal, bool isInput) { - visit(signal, [&](Signal s) { declareSignal(s, isInput); }); + void declareSignals(AnySignal signal, SignalType type) { + visit(signal, [&](Signal s) { declareSignal(s, type); }); } - void declareSignal(Signal signal, bool isInput) { - os << "(" << (isInput ? "input " : "output ") << signal.str() << ")\n"; + void declareSignal(Signal signal, SignalType type) { + std::string op; + switch (type) { + case SignalType::Input: op = "input"; break; + case SignalType::Output: op = "output"; break; + case SignalType::AssumeDeterministic: op = "assume-deterministic"; break; + } + os << "(" << op << " " << signal.str() << ")\n"; } ComponentOp lookupConstructor(Type type) { From 27fd33e24fad1dfd6b3cbcd0c6057629a9052f84 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Thu, 9 Jan 2025 13:44:46 -0800 Subject: [PATCH 3/7] update test results now that back values are assume-deterministic --- zirgen/compiler/picus/test/back.zir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zirgen/compiler/picus/test/back.zir b/zirgen/compiler/picus/test/back.zir index c5aa0922..f8d78253 100644 --- a/zirgen/compiler/picus/test/back.zir +++ b/zirgen/compiler/picus/test/back.zir @@ -2,7 +2,8 @@ // CHECK: (prime-number 2013265921) // CHECK: (begin-module Count) -// CHECK: (input x0) +// CHECK: (assume-deterministic x1__super__super) +// CHECK: (assume-deterministic x1_reg__super) // CHECK: (end-module) #[picus] From a88e09e171d42e668424833608a4c838c1784bfd Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Mon, 13 Jan 2025 11:17:16 -0800 Subject: [PATCH 4/7] add test for back with distance zero --- zirgen/compiler/picus/picus.cpp | 1 - zirgen/compiler/picus/test/back-zero.zir | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 zirgen/compiler/picus/test/back-zero.zir diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index d28daa7e..4d328bfb 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -296,7 +296,6 @@ class PicusPrinter { } void visitOp(zirgen::Zhlt::BackOp back) { - auto callee = back.getCallee(); size_t distance = back.getDistance().getZExtValue(); AnySignal signal = signalize(freshName(), back.getType()); if (distance > 0) { diff --git a/zirgen/compiler/picus/test/back-zero.zir b/zirgen/compiler/picus/test/back-zero.zir new file mode 100644 index 00000000..6921176a --- /dev/null +++ b/zirgen/compiler/picus/test/back-zero.zir @@ -0,0 +1,19 @@ +// RUN: zirgen %s --emit=picus | FileCheck %s + +// CHECK: (prime-number 2013265921) +// CHECK: (begin-module Count) +// CHECK: (input x0) +// CHECK-NEXT: (output layout_a__super__super) +// CHECK-NEXT: (output result_a__super__super) +// CHECK-NEXT: (output result_a_reg__super) +// CHECK-NEXT: (call [{{.*}}] Reg [{{.*}}]) +// CHECK-NEXT: (assert (= result_a__super__super x2__super__super)) +// CHECK-NEXT: (assert (= result_a_reg__super x2_reg__super)) +// CHECK-NEXT: (end-module) + +#[picus] +component Count(first: Val) { + public a : Reg; + a := Reg(a@0); +} + From d09d6dbcdf70d0b3cf51c2331477adfd365422f9 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Mon, 13 Jan 2025 14:16:30 -0800 Subject: [PATCH 5/7] when back distance is zero, mark operand as input --- zirgen/compiler/picus/picus.cpp | 2 ++ zirgen/compiler/picus/test/back-zero.zir | 2 ++ 2 files changed, 4 insertions(+) diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index 4d328bfb..8e144162 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -300,6 +300,8 @@ class PicusPrinter { AnySignal signal = signalize(freshName(), back.getType()); if (distance > 0) { declareSignals(signal, SignalType::AssumeDeterministic); + } else { + declareSignals(signal, SignalType::Input); } valuesToSignals.insert({back.getOut(), signal}); } diff --git a/zirgen/compiler/picus/test/back-zero.zir b/zirgen/compiler/picus/test/back-zero.zir index 6921176a..bc8b722c 100644 --- a/zirgen/compiler/picus/test/back-zero.zir +++ b/zirgen/compiler/picus/test/back-zero.zir @@ -6,6 +6,8 @@ // CHECK-NEXT: (output layout_a__super__super) // CHECK-NEXT: (output result_a__super__super) // CHECK-NEXT: (output result_a_reg__super) +// CHECK-NEXT: (input x1__super__super) +// CHECK-NEXT: (input x1_reg__super) // CHECK-NEXT: (call [{{.*}}] Reg [{{.*}}]) // CHECK-NEXT: (assert (= result_a__super__super x2__super__super)) // CHECK-NEXT: (assert (= result_a_reg__super x2_reg__super)) From cb1eeae5d9c4e1d1fe1aaa374942ce67c8fcc15a Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Tue, 14 Jan 2025 12:08:09 -0800 Subject: [PATCH 6/7] explicitly punt on the zero-distance back case (it will be inlined away) --- zirgen/compiler/picus/picus.cpp | 9 ++++----- zirgen/compiler/picus/test/back-zero.zir | 21 --------------------- 2 files changed, 4 insertions(+), 26 deletions(-) delete mode 100644 zirgen/compiler/picus/test/back-zero.zir diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index 8e144162..2994e09e 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -298,11 +298,10 @@ class PicusPrinter { void visitOp(zirgen::Zhlt::BackOp back) { size_t distance = back.getDistance().getZExtValue(); AnySignal signal = signalize(freshName(), back.getType()); - if (distance > 0) { - declareSignals(signal, SignalType::AssumeDeterministic); - } else { - declareSignals(signal, SignalType::Input); - } + // We cannot handle the zero-distance case this way, so we expect that + // all zero-distance backs will have been converted & inlined already. + assert (distance > 0); + declareSignals(signal, SignalType::AssumeDeterministic); valuesToSignals.insert({back.getOut(), signal}); } diff --git a/zirgen/compiler/picus/test/back-zero.zir b/zirgen/compiler/picus/test/back-zero.zir deleted file mode 100644 index bc8b722c..00000000 --- a/zirgen/compiler/picus/test/back-zero.zir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: zirgen %s --emit=picus | FileCheck %s - -// CHECK: (prime-number 2013265921) -// CHECK: (begin-module Count) -// CHECK: (input x0) -// CHECK-NEXT: (output layout_a__super__super) -// CHECK-NEXT: (output result_a__super__super) -// CHECK-NEXT: (output result_a_reg__super) -// CHECK-NEXT: (input x1__super__super) -// CHECK-NEXT: (input x1_reg__super) -// CHECK-NEXT: (call [{{.*}}] Reg [{{.*}}]) -// CHECK-NEXT: (assert (= result_a__super__super x2__super__super)) -// CHECK-NEXT: (assert (= result_a_reg__super x2_reg__super)) -// CHECK-NEXT: (end-module) - -#[picus] -component Count(first: Val) { - public a : Reg; - a := Reg(a@0); -} - From 48040687778f5275a675773545c1bf12e0b7d2f6 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Tue, 14 Jan 2025 12:49:11 -0800 Subject: [PATCH 7/7] update copyright year and apply clang-format --- zirgen/compiler/picus/picus.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/zirgen/compiler/picus/picus.cpp b/zirgen/compiler/picus/picus.cpp index 2994e09e..91966b84 100644 --- a/zirgen/compiler/picus/picus.cpp +++ b/zirgen/compiler/picus/picus.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -300,7 +300,7 @@ class PicusPrinter { AnySignal signal = signalize(freshName(), back.getType()); // We cannot handle the zero-distance case this way, so we expect that // all zero-distance backs will have been converted & inlined already. - assert (distance > 0); + assert(distance > 0); declareSignals(signal, SignalType::AssumeDeterministic); valuesToSignals.insert({back.getOut(), signal}); } @@ -354,9 +354,15 @@ class PicusPrinter { void declareSignal(Signal signal, SignalType type) { std::string op; switch (type) { - case SignalType::Input: op = "input"; break; - case SignalType::Output: op = "output"; break; - case SignalType::AssumeDeterministic: op = "assume-deterministic"; break; + case SignalType::Input: + op = "input"; + break; + case SignalType::Output: + op = "output"; + break; + case SignalType::AssumeDeterministic: + op = "assume-deterministic"; + break; } os << "(" << op << " " << signal.str() << ")\n"; }