Skip to content

Commit

Permalink
Add bytecode interpreter for validity polynomial calculation for keccak circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
shkoo committed Jan 21, 2025
1 parent 969982e commit 5b00c26
Show file tree
Hide file tree
Showing 56 changed files with 3,134 additions and 61 deletions.
75 changes: 75 additions & 0 deletions zirgen/Dialect/ByteCode/IR/Attrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "llvm/ADT/TypeSwitch.h"

#include "zirgen/Dialect/ByteCode/IR/ByteCode.h"

using namespace mlir;

namespace zirgen::ByteCode {

// Builds the DispatchKeyAttr describing `op`.
//
// The key is made of:
//   * the operation's name,
//   * its operand and result types,
//   * one int kind per integer argument reported by ByteCodeOpInterface (ops
//     that do not implement the interface contribute none), and
//   * the argument numbers of every operand that is a block argument, in
//     operand order.
DispatchKeyAttr getDispatchKey(Operation* op) {
  mlir::MLIRContext* ctx = op->getContext();
  llvm::StringRef opName = op->getName().getStringRef();

  // Ask the op for its encoded integer arguments.  Only how many there are is
  // used below; the values themselves do not take part in the key.
  SmallVector<size_t> encodedInts;
  if (auto iface = llvm::dyn_cast<ByteCodeOpInterface>(op))
    iface.getByteCodeIntArgs(encodedInts);

  // Each encoded integer gets its own int kind named "<op name>_<position>".
  SmallVector<mlir::Attribute> intKinds;
  for (size_t pos = 0, count = encodedInts.size(); pos != count; ++pos) {
    std::string kindName = (opName + "_" + std::to_string(pos)).str();
    intKinds.push_back(StringAttr::get(ctx, kindName));
  }

  // Record which block arguments are used directly as operands.
  SmallVector<size_t> blockArgNums;
  for (Value operand : op->getOperands())
    if (auto arg = llvm::dyn_cast<BlockArgument>(operand))
      blockArgNums.push_back(arg.getArgNumber());

  auto operandTypes = llvm::to_vector(op->getOperandTypes());
  auto resultTypes = llvm::to_vector(op->getResultTypes());

  return DispatchKeyAttr::get(ctx, opName, operandTypes, resultTypes, intKinds, blockArgNums);
}

// Produces a textual name for an int kind attribute:
//   * a StringAttr yields its string value,
//   * a UnitAttr yields "unit",
//   * anything else yields the attribute's printed form with every double
//     quote and space character removed.
std::string getNameForIntKind(mlir::Attribute intKind) {
  if (auto asString = llvm::dyn_cast<StringAttr>(intKind))
    return asString.getValue().str();

  if (llvm::isa<UnitAttr>(intKind))
    return "unit";

  // Fall back to the attribute's printed form.
  std::string printed;
  llvm::raw_string_ostream stream(printed);
  stream << intKind;

  // Strip the characters that are not wanted in a name.
  llvm::erase_if(printed, [](char ch) { return ch == ' ' || ch == '"'; });
  return printed;
}

} // namespace zirgen::ByteCode
77 changes: 77 additions & 0 deletions zirgen/Dialect/ByteCode/IR/Attrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BYTECODE_ATTRS
#define BYTECODE_ATTRS

include "mlir/IR/AttrTypeBase.td"
include "zirgen/Dialect/ByteCode/IR/Dialect.td"

// Base class for attributes of the ByteCode dialect.  `name` is forwarded to
// AttrDef, and `attrMnemonic` becomes the attribute's mnemonic.
class ByteCodeAttr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<ByteCodeDialect, name, traits> {
let mnemonic = attrMnemonic;
}

// Describes one temporary buffer; EncodedAttr carries an array of these in its
// `tempBufs` parameter.  Written as: <intKind> `size` <size>.
def TempBufInfoAttr : ByteCodeAttr<"TempBufInfo", "temp_buf"> {
let parameters = (ins
// Attribute identifying which temporary buffer this entry describes.
// NOTE(review): presumably an int kind, as in IntKindInfoAttr -- confirm.
"mlir::Attribute": $intKind,
// Size of the temporary buffer.
// NOTE(review): units (elements vs. bytes) are not stated here -- confirm.
"size_t": $size
);
let assemblyFormat = [{ $intKind `size` $size }];
}

// Pairs an int kind with the number of bits used to encode integers of that
// kind.  Written as: <intKind> `u` <encodedBits>.
def IntKindInfoAttr : ByteCodeAttr<"IntKindInfo", "int_kind_info"> {
let summary = "Information on a set of encoded integers in a bytecode encoding";
let parameters = (ins
// Arbitrary attribute used as a uniquing key for integer kinds.
"mlir::Attribute": $intKind,
// Number of bits used to encode this kind of int
"size_t": $encodedBits
);
let assemblyFormat = [{ $intKind `u` $encodedBits }];
}

// Array attribute whose elements are all IntKindInfoAttr.
def IntKindInfoArrayAttr : TypedArrayAttrBase<IntKindInfoAttr, "Array of integer kinds">;

// Holds an encoded bytecode string together with the temporary buffers needed
// to execute it.  Written as: <encoded> `temps` <tempBufs>.
def EncodedAttr : ByteCodeAttr<"Encoded", "encoded"> {
let summary = "A bytecode encoding for passing to an executor";
let parameters = (ins
// The encoded bytecode, stored as a string.
// NOTE(review): presumably raw bytes rather than printable text -- confirm.
StringRefParameter<>: $encoded,
// Sizes of any temporary buffers that need to be allocated to execute this bytecode.
ArrayRefParameter<"zirgen::ByteCode::TempBufInfoAttr">: $tempBufs
);
let assemblyFormat = [{ $encoded `temps` $tempBufs }];
}

// Key identifying a byte code operation specialized on concrete operand and
// result types.
//
// Note that the assembly format lists the fields in a different order from the
// parameter list: operationName `(` intKinds `,` operandTypes `,`
// blockArgNums `)` `->` resultTypes.
def DispatchKeyAttr : ByteCodeAttr<"DispatchKey", "dispatch_key"> {
let summary = "Unique identifier for a byte code operation on concrete types";
let parameters = (ins
// Name of the operation this key describes.
StringRefParameter<>:$operationName,
// Types of the operation's operands.
ArrayRefParameter<"mlir::Type">:$operandTypes,
// Types of the operation's results.
ArrayRefParameter<"mlir::Type">:$resultTypes,
// Int kinds of any additional decoded arguments, not including value operands.
ArrayRefParameter<"mlir::Attribute">:$intKinds,
// Further unique by any captured block arguments from containing
// block in case they happen to be the same type.
ArrayRefParameter<"size_t">:$blockArgNums
);
let assemblyFormat = [{
$operationName `(` $intKinds `,` $operandTypes `,` $blockArgNums `)`
`->` $resultTypes
}];
}

#endif // BYTECODE_ATTRS
129 changes: 129 additions & 0 deletions zirgen/Dialect/ByteCode/IR/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

# All targets in this package are visible to every other package by default.
package(
default_visibility = ["//visibility:public"],
)

# TableGen sources of the ByteCode dialect IR, together with the .td libraries
# they depend on.  The *IncGen targets below all build on this.
td_library(
name = "TdFiles",
srcs = [
"Attrs.td",
"Dialect.td",
"Ops.td",
"Types.td",
],
deps = [
"//zirgen/Dialect/ByteCode/Interfaces:TdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)

# Runs mlir-tblgen over Attrs.td to produce the attribute declarations
# (Attrs.h.inc) and definitions (Attrs.cpp.inc).
gentbl_cc_library(
name = "AttrsIncGen",
tbl_outs = [
(
[
"-gen-attrdef-decls",
],
"Attrs.h.inc",
),
(
[
"-gen-attrdef-defs",
],
"Attrs.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = ":Attrs.td",
deps = [":TdFiles"],
)

# Runs mlir-tblgen over Types.td to produce the type declarations
# (Types.h.inc) and definitions (Types.cpp.inc).
gentbl_cc_library(
name = "TypesIncGen",
tbl_outs = [
(
[
"-gen-typedef-decls",
],
"Types.h.inc",
),
(
[
"-gen-typedef-defs",
],
"Types.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = ":Types.td",
deps = [":TdFiles"],
)

# Runs mlir-tblgen over Ops.td to produce both the dialect declarations and
# definitions (Dialect.h.inc / Dialect.cpp.inc, for the "zbytecode" dialect)
# and the op declarations and definitions (Ops.h.inc / Ops.cpp.inc).
gentbl_cc_library(
name = "OpsIncGen",
tbl_outs = [
(
[
"-gen-dialect-decls",
"-dialect=zbytecode",
],
"Dialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=zbytecode",
],
"Dialect.cpp.inc",
),
(
["-gen-op-decls"],
"Ops.h.inc",
),
(
["-gen-op-defs"],
"Ops.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = ":Ops.td",
deps = [":TdFiles"],
)

# C++ library for the ByteCode dialect IR.  ByteCode.h is the public header;
# the generated .inc files come from the *IncGen targets listed in deps.
cc_library(
name = "IR",
srcs = [
"Attrs.cpp",
"Dialect.cpp",
"Ops.cpp",
],
hdrs = [
"ByteCode.h",
],
deps = [
":AttrsIncGen",
":OpsIncGen",
":TypesIncGen",
"//zirgen/Dialect/ByteCode/Interfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
],
)

# Code generation library (Codegen.cpp / Codegen.h), built on top of the
# ByteCode IR library above and the Zll dialect IR.
cc_library(
name = "Codegen",
srcs = [
"Codegen.cpp",
],
hdrs = [
"Codegen.h",
],
deps = [
":IR",
"//zirgen/Dialect/Zll/IR",
],
)
51 changes: 51 additions & 0 deletions zirgen/Dialect/ByteCode/IR/ByteCode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "zirgen/Dialect/ByteCode/Interfaces/Interfaces.h"

namespace zirgen {

// Forward declaration only; the class is defined elsewhere.
// NOTE(review): presumably needed by the generated .inc files included below
// -- confirm.
class Interpreter;

} // namespace zirgen

#include "zirgen/Dialect/ByteCode/IR/Dialect.h.inc"

#define GET_TYPEDEF_CLASSES
#include "zirgen/Dialect/ByteCode/IR/Types.h.inc"

#define GET_ATTRDEF_CLASSES
#include "zirgen/Dialect/ByteCode/IR/Attrs.h.inc"

#define GET_OP_CLASSES
#include "zirgen/Dialect/ByteCode/IR/Ops.h.inc"

namespace zirgen::ByteCode {

// Builds the dispatch key for `op` from its name, operand and result types,
// one int kind per integer argument reported through ByteCodeOpInterface, and
// the argument numbers of any operands that are block arguments.
DispatchKeyAttr getDispatchKey(mlir::Operation* op);

// Returns an attribute for the given context.
// NOTE(review): presumably the int kind under which dispatch keys themselves
// are encoded; the definition is not visible from this header -- confirm.
mlir::Attribute getDispatchKeyIntKind(mlir::MLIRContext* ctx);

// Returns a textual name for `intKind`: the string value for a StringAttr,
// "unit" for a UnitAttr, and otherwise the attribute's printed form with
// double quotes and spaces removed.
std::string getNameForIntKind(mlir::Attribute intKind);

} // namespace zirgen::ByteCode
Loading

0 comments on commit 5b00c26

Please sign in to comment.