Skip to content

Commit

Permalink
ZIR-330: [Picus] Muxes, backs, and more (#164)
Browse files Browse the repository at this point in the history
* Lowering of muxes, and conditionalization of alias layouts in muxes
* Generate back functions for Picus, and inline backs with distance 0
* Inline more builtin components, including Reg
* Add picus_inline attribute so the programmer can force inlining in Picus code (necessary for Po2)
* Incorporate access modifiers to reduce number of signals

Co-authored-by: Mars Saxman <[email protected]>
  • Loading branch information
jacobdweightman and mars-risc0 authored Jan 17, 2025
1 parent 495f0ca commit b11c5cb
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 64 deletions.
5 changes: 3 additions & 2 deletions zirgen/Conversions/Typing/ZhlComponent.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 RISC Zero, Inc.
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -552,7 +552,8 @@ Zhlt::ComponentOp LoweringImpl::gen(ComponentOp component,

for (NamedAttribute attr : component->getDiscardableAttrs()) {
StringRef name = attr.getName();
if (name == "function" || name == "argument" || name == "generic" || name == "picus") {
if (name == "function" || name == "argument" || name == "generic" || name == "picus_analyze" ||
name == "picus_inline") {
ctor->setAttr(name, attr.getValue());
} else {
ctor->emitError() << "unknown attribute `" << name << "`";
Expand Down
3 changes: 2 additions & 1 deletion zirgen/Dialect/ZStruct/IR/Types.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 RISC Zero, Inc.
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +62,7 @@ void printFields(mlir::AsmPrinter& p, llvm::ArrayRef<FieldInfo> fields) {
llvm::interleaveComma(fields, p, [&](const FieldInfo& field) {
if (field.isPrivate) {
p.printKeywordOrString("private");
p << " ";
}
p.printKeywordOrString(field.name.getValue());
p << ": ";
Expand Down
1 change: 1 addition & 0 deletions zirgen/circuit/keccak2/pack.zir
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import bits;
import arr;

// The max we can pack into one BB element is 30 bits
#[picus_inline]
component Po2(n: Val) {
arr := [ 0x00000001, 0x00000002, 0x00000004, 0x00000008,
0x00000010, 0x00000020, 0x00000040, 0x00000080,
Expand Down
3 changes: 2 additions & 1 deletion zirgen/circuit/keccak2/top.zir
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ component ExtractBits(steps: Array<DoShaStep, 8>, i: Val) {

component LoadWin(kf: Array<Val, 100>, round: OneHot<8>, block: Val) {
blockSelect := OneHot<4>(block);
other := reduce for i : 0..6 { round[2 + i] } init 0 with Add;
other := reduce for i : 0..6 { round[2 + i] } init 0 with Add;
[round[0], round[1], other] -> (
blockSelect -> (
for i : 0..8 { for j : 0..2 { kf[(i + 0) * 2 + j] }},
Expand Down Expand Up @@ -230,6 +230,7 @@ component ExtractBits2(a: Array<Array<Val, 32>, 4>, e: Array<Array<Val, 32>, 4>,
)
}

#[picus_analyze]
component ShaNextBlockCycle(back1: TopState) {
Log("ShaNextBlockCycle");
// Extract current a + e values in packed format
Expand Down
202 changes: 179 additions & 23 deletions zirgen/compiler/picus/picus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
#include <set>

using namespace mlir;
using namespace zirgen::Zhlt;
using namespace zirgen::ZStruct;
using namespace zirgen::Zll;
using namespace zirgen;
using namespace Zhlt;
using namespace ZStruct;
using namespace Zll;

namespace {

Expand All @@ -53,11 +54,26 @@ template <typename F> void visit(AnySignal signal, F f) {
visit(elem, f);
} else if (auto str = dyn_cast<SignalStruct>(signal)) {
for (auto field : str) {
assert(field.getName() != "@layout");
visit(field.getValue(), f);
}
}
}

AnySignal getSuperSignal(AnySignal signal) {
if (auto arr = dyn_cast<SignalArray>(signal)) {
SmallVector<AnySignal> supers;
for (auto elem : arr)
supers.push_back(getSuperSignal(elem));
return SignalArray::get(signal.getContext(), supers);
} else if (auto str = dyn_cast<SignalStruct>(signal)) {
return str.getNamed("@super")->getValue();
} else {
// Signal, nullptr, etc
return nullptr;
}
}

std::string canonicalizeIdentifier(std::string ident) {
for (char& ch : ident) {
if (ch == '$' || ch == '@' || ch == ' ' || ch == ':' || ch == '<' || ch == '>' || ch == ',') {
Expand All @@ -76,7 +92,7 @@ class PicusPrinter {
this->mod = mod;
os << "(prime-number 2013265921)\n";
for (auto component : mod.getOps<ComponentOp>()) {
if (component->hasAttr("picus")) {
if (component->hasAttr("picus_analyze")) {
workQueue.push(component);
}
}
Expand Down Expand Up @@ -131,26 +147,37 @@ class PicusPrinter {
// the mapping between MLIR values and Picus signals. Because Picus doesn't
// have control flow, MapOp/ReduceOp are unrolled.
llvm::TypeSwitch<Operation*>(op)
.Case<InvOp, BitAndOp, ModOp, InRangeOp, ExternOp>([&](auto op) { visitNondetOp(op); })
.Case<AddOp, SubOp, MulOp>([&](auto op) { visitBinaryPolyOp(op); })
.Case<ConstOp,
StringOp,
SubOp,
VariadicPackOp,
ExternOp,
LoadOp,
LookupOp,
SubscriptOp,
ConstructOp,
EqualZeroOp,
Zhlt::BackOp,
SwitchOp,
ArrayOp,
PackOp,
ReturnOp,
GetGlobalLayoutOp,
AliasLayoutOp,
zirgen::Zhlt::BackOp>([&](auto op) { visitOp(op); })
.Case<StoreOp, arith::ConstantOp>([](auto) { /* no-op */ })
.Case<StoreOp, YieldOp, arith::ConstantOp>([](auto) { /* no-op */ })
.Default([](Operation* op) { llvm::errs() << "unhandled op: " << *op << "\n"; });
}

// For nondeterministic operations, mark all results as fresh signals.
void visitNondetOp(Operation* op) {
for (Value result : op->getResults()) {
Signal signal = Signal::get(ctx, freshName());
valuesToSignals.insert({result, signal});
}
}

void visitOp(ConstOp constant) {
assert(constant.getCoefficients().size() == 1 && "not implemented");
auto signal = Signal::get(ctx, freshName());
Expand All @@ -175,13 +202,6 @@ class PicusPrinter {
valuesToSignals.insert({pack.getOut(), nullptr});
}

void visitOp(ExternOp ext) {
for (Value result : ext.getOut()) {
Signal signal = Signal::get(ctx, freshName());
valuesToSignals.insert({result, signal});
}
}

void visitOp(LoadOp load) {
auto signal = cast<Signal>(valuesToSignals.at(load.getRef()));
valuesToSignals.insert({load.getOut(), signal});
Expand All @@ -198,7 +218,9 @@ class PicusPrinter {

SmallVector<OpFoldResult> results;
if (failed(subscript.getIndex().getDefiningOp()->fold(results))) {
llvm::errs() << "failed to resolve subscript index\n";
auto diag = subscript->emitError("failed to resolve subscript index\n");
llvm::errs() << "index: " << *subscript.getIndex().getDefiningOp() << "\n";
return;
}
uint64_t index = cast<PolynomialAttr>(results[0].get<Attribute>())[0];
auto subSignal = signal[index];
Expand Down Expand Up @@ -231,6 +253,25 @@ class PicusPrinter {
os << "])\n";
}

void visitBinaryPolyOp(Operation* op) {
auto symbol = llvm::TypeSwitch<Operation*, const char*>(op)
.Case<AddOp>([](auto) { return "+"; })
.Case<SubOp>([](auto) { return "-"; })
.Case<MulOp>([](auto) { return "*"; })
.Default([&](auto) {
op->emitError("unknown binary poly op");
return nullptr;
});

auto signal = Signal::get(ctx, freshName());
valuesToSignals.insert({op->getResult(0), signal});

os << "(assert (= " << signal.str() << " (" << symbol << " ";
os << cast<Signal>(valuesToSignals.at(op->getOperand(0))).str() << " ";
os << cast<Signal>(valuesToSignals.at(op->getOperand(1))).str();
os << ")))\n";
}

void visitOp(SubOp sub) {
auto signal = Signal::get(ctx, freshName());
valuesToSignals.insert({sub.getOut(), signal});
Expand All @@ -247,6 +288,54 @@ class PicusPrinter {
os << " 0))\n";
}

void visitOp(SwitchOp mux) {
os << "; begin mux\n";

SmallVector<Signal> selectorSignals;
for (Value selector : mux.getSelector()) {
selectorSignals.push_back(cast<Signal>(valuesToSignals.at(selector)));
}

SmallVector<SmallVector<Signal>> armSignals;

for (Region& arm : mux.getArms()) {
// Probably need to "turn off" AliasLayoutOps, since different arms may
// write different values to the common super
assert(arm.hasOneBlock());
for (Operation& op : arm.front()) {
visitOp(&op);
}
// Collect values yielded by each arm
Value yielded = cast<YieldOp>(arm.front().getTerminator()).getValue();
AnySignal signal = valuesToSignals.at(yielded);
Type type = yielded.getType();
while (type != mux.getType()) {
signal = getSuperSignal(signal);
type = Zhlt::getSuperType(type);
}
armSignals.push_back(flatten(signal));
os << "; mark mux arm\n";
}

AnySignal outSignal = signalize("mux_" + freshName(), mux.getType());
valuesToSignals.insert({mux.getOut(), outSignal});

SmallVector<Signal> outSignals = flatten(outSignal);
for (size_t i = 0; i < outSignals.size(); i++) {
os << "(assert (= " << outSignals[i].str();
for (size_t j = 0; j < armSignals.size(); j++) {
if (j != armSignals.size() - 1)
os << " (+";
os << " (* " << selectorSignals[j].str() << " " << armSignals[j][i].str() << ")";
}
for (size_t j = 0; j < armSignals.size(); j++) {
os << ")";
}
os << ")\n";
}
os << "; end mux\n";
}

void visitOp(ArrayOp arr) {
SmallVector<AnySignal> elements;
for (auto arg : arr.getElements()) {
Expand All @@ -260,7 +349,7 @@ class PicusPrinter {
void visitOp(PackOp pack) {
SmallVector<NamedAttribute> fields;
for (auto [field, arg] : llvm::zip(pack.getOut().getType().getFields(), pack.getMembers())) {
if (field.name.strref() == "@layout")
if (field.isPrivate || field.name.strref() == "@layout")
continue;
AnySignal member = valuesToSignals.at(arg);
fields.emplace_back(field.name, member);
Expand All @@ -275,7 +364,10 @@ class PicusPrinter {
// of the component. Unify those signals with those of the return value.
AnySignal outputSignal = valuesToSignals.at(Value());
AnySignal returnSignal = valuesToSignals.at(ret.getValue());
for (auto [outs, rets] : llvm::zip(flatten(outputSignal), flatten(returnSignal))) {
SmallVector<Signal> outs = flatten(outputSignal);
SmallVector<Signal> rets = flatten(returnSignal);
assert(outs.size() == rets.size());
for (auto [outs, rets] : llvm::zip(outs, rets)) {
os << "(assert (= " << outs.str() << " " << rets.str() << "))\n";
}
}
Expand All @@ -288,14 +380,75 @@ class PicusPrinter {
}

void visitOp(AliasLayoutOp alias) {
auto lhs = valuesToSignals.at(alias.getLhs());
auto rhs = valuesToSignals.at(alias.getRhs());
for (auto [sl, sr] : llvm::zip(flatten(lhs), flatten(rhs))) {
os << "(assert (= " << sl.str() << " " << sr.str() << "))\n";
// If lhs and rhs have the same lifetime, then aliasing them is
// straightforward: simply constrain lhs = rhs. If they have different
// lifetimes, it's more complicated. Luckily, the only way for things to
// have different lifetimes is because of muxing: values in a mux arm have
// a lifetime strictly contained by the enclosing scope, and values in
// different arms of the same mux have strictly non-overlapping lifetimes.
// We assert that we never see aliases of the second type, as there
// currently is no way to produce such aliases in Zirgen and it's not clear
// how this ought to be handled. For the first case, we require the
// corresponding signals to be equal only when the shorter-lived value is
// live, so we produce Picus assertions of the form s * lhs = s * rhs: if
// s = 0, this is trivially satisified, and if s = 1 it implies lhs = rhs.
// If there are multiple intervening muxes, we take the product of all the
// intervening selectors: product_i(s_i) * lhs = product_i(s_i) * rhs.

Value lhs = alias.getLhs();
Value rhs = alias.getRhs();
Operation* lhsOp = lhs.getDefiningOp();
Operation* rhsOp = rhs.getDefiningOp();
Block* lhsBlock = lhs.getParentBlock();
Block* rhsBlock = rhs.getParentBlock();

// Find the longer-lived value
Value shortLived;
Value longLived;
if (lhsBlock->findAncestorOpInBlock(*rhsOp) != nullptr) {
shortLived = rhs;
longLived = lhs;
} else if (rhsBlock->findAncestorOpInBlock(*lhsOp) != nullptr) {
shortLived = lhs;
longLived = rhs;
} else {
assert(false && "cannot resolve relative lifetimes of aliased values");
}

SmallVector<Value> interveningSelectors;
Region* x = shortLived.getParentRegion();
while (x != longLived.getParentRegion()) {
// All loops are unrolled and no other ops contain regions, so the
// ancestor must be a SwitchOp.
auto mux = cast<SwitchOp>(x->getParentOp());
interveningSelectors.push_back(mux.getSelector()[x->getRegionNumber()]);
x = x->getParentRegion();
}

auto conditionalize = [&](Signal signal) {
if (interveningSelectors.empty()) {
os << signal.str();
} else {
for (Value s : interveningSelectors)
os << "(* " << cast<Signal>(valuesToSignals.at(s)).str() << " ";
os << signal.str();
for (size_t i = 0; i < interveningSelectors.size(); i++)
os << ")";
}
};

auto lhsSignal = valuesToSignals.at(lhs);
auto rhsSignal = valuesToSignals.at(rhs);
for (auto [sl, sr] : llvm::zip(flatten(lhsSignal), flatten(rhsSignal))) {
os << "(assert (= ";
conditionalize(sl);
os << " ";
conditionalize(sr);
os << "))\n";
}
}

void visitOp(zirgen::Zhlt::BackOp back) {
void visitOp(Zhlt::BackOp back) {
size_t distance = back.getDistance().getZExtValue();
AnySignal signal = signalize(freshName(), back.getType());
// We cannot handle the zero-distance case this way, so we expect that
Expand All @@ -321,8 +474,10 @@ class PicusPrinter {
for (auto field : str.getFields()) {
if (field.name.strref() == "@layout")
continue;
std::string name = prefix + "_" + canonicalizeIdentifier(field.name.str());
fields.emplace_back(field.name, signalize(name, field.type));
if (!field.isPrivate) {
std::string name = prefix + "_" + canonicalizeIdentifier(field.name.str());
fields.emplace_back(field.name, signalize(name, field.type));
}
}
return SignalStruct::get(ctx, fields);
} else if (auto str = dyn_cast<LayoutType>(type)) {
Expand Down Expand Up @@ -387,6 +542,7 @@ class PicusPrinter {

void printPicus(ModuleOp mod, llvm::raw_ostream& os) {
PassManager pm(mod->getContext());
pm.addPass(zirgen::dsl::createGenerateBackPass());
pm.addPass(zirgen::dsl::createInlineForPicusPass());
pm.addPass(createUnrollPass());
pm.addPass(createCanonicalizerPass());
Expand Down
13 changes: 6 additions & 7 deletions zirgen/compiler/picus/test/alias_layout_1.zir
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
// CHECK-NEXT: (output layout_b__super__super)
// CHECK-NEXT: (output result_a__super)
// CHECK-NEXT: (output result_b__super__super)
// CHECK-NEXT: (output result_b_reg__super)
// CHECK-NEXT: (assert (= x0 0))
// CHECK-NEXT: (call [layout_b__super__super x1__super__super x1_reg__super] Reg [x0])
// CHECK-NEXT: (assert (= x1 (- x0 layout_b__super__super)))
// CHECK-NEXT: (assert (= x1 0))
// CHECK-NEXT: (assert (= layout_a__super layout_b__super__super))
// CHECK-NEXT: (assert (= result_a__super layout_a__super))
// CHECK-NEXT: (assert (= result_b__super__super x1__super__super))
// CHECK-NEXT: (assert (= result_b_reg__super x1_reg__super))
// CHECK-NEXT: (assert (= result_b__super__super layout_b__super__super))
// CHECK-NEXT: (end-module)

#[picus]
#[picus_analyze]
component Top() {
a := NondetReg(0);
b := Reg(0);
public a := NondetReg(0);
public b := Reg(0);
AliasLayout!(a, b);
}
Loading

0 comments on commit b11c5cb

Please sign in to comment.