From de6f134db72f24b5a2982fa4ae4f81933cec9c87 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 22 Nov 2024 15:41:43 -0800 Subject: [PATCH 01/34] Draft modmul accelerator --- zirgen/circuit/bigint/bigint2c.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index cdb73a20..43cecdcd 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -42,6 +42,7 @@ enum class Program { ModPow_65537, EC_Double, EC_Add, + ModMul, }; } // namespace @@ -50,7 +51,8 @@ static cl::opt cl::desc("The program to compile"), cl::values(clEnumValN(Program::ModPow_65537, "modpow_65537", "ModPow_65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), - clEnumValN(Program::EC_Add, "ec_add", "EC_Add")), + clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), + clEnumValN(Program::ModMul, "modmul", "ModMul")), // TODO: Don't hardcode bitwidth cl::Required); const APInt secp256k1_prime = APInt::getAllOnes(256) - APInt::getOneBitSet(256, 32) - @@ -417,6 +419,15 @@ void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder) { builder.create(loc, x, 13, 0); } +void genModMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto modulus = builder.create(loc, bitwidth, 13, 0); + auto prod = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, prod, modulus); + builder.create(loc, result, 14, 0); +} + void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; @@ -482,6 +493,8 @@ int main(int argc, char* argv[]) { case Program::EC_Add: genECAdd(loc, builder, 256); // TODO: Selectable bitwidth break; + case Program::ModMul: + genModMul(loc, builder, 256); // TODO: Selectable bitwidth } builder.create(loc); From 59ac726c03ab7a0dabbdaeb60189592db0f1f138 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Wed, 27 Nov 2024 16:46:12 -0800 Subject: [PATCH 02/34] Draft field & field extension ops --- zirgen/circuit/bigint/bigint2c.cpp | 60 ++++++++++++++++ zirgen/circuit/bigint/field.cpp | 107 +++++++++++++++++++++++++++++ zirgen/circuit/bigint/field.h | 36 ++++++++++ 3 files changed, 203 insertions(+) create mode 100644 zirgen/circuit/bigint/field.cpp create mode 100644 zirgen/circuit/bigint/field.h diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 43cecdcd..7fb935d9 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -462,6 +462,66 @@ void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { builder.create(loc, result.y(), 14, chunkwidth); } +void genModAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto result = BigInt::field::modAdd(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +void genModInv(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 + auto inp = builder.create(loc, bitwidth, 11, 0); + auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); + auto result = BigInt::field::modInv(builder, loc, inp, prime); + builder.create(loc, result, 13, 0); +} + +void genModMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto result = BigInt::field::modMul(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +void genModSub(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +void genExtFieldAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + + + +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? +Value extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +Value extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); +Value extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); + + + + + int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); mlir::registerAsmPrinterCLOptions(); diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp new file mode 100644 index 00000000..ad18378d --- /dev/null +++ b/zirgen/circuit/bigint/field.cpp @@ -0,0 +1,107 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "zirgen/circuit/bigint/field.h" + +namespace zirgen::BigInt::field { + +// Prime field operations + +Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { + auto sum = builder.create(lhs, rhs); + auto result = builder.create(sum, prime); + return result; +} + +Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime) { + return builder.create(inp, prime); +} + +Value modMul(mlir::OpBuilder builder, mlir::Location loc, size_t bits, Value lhs, Value rhs, Value prime) { + auto prod = builder.create(lhs, rhs); + auto result = builder.create(prod, prime); + return result; +} + +Value modSub(mlir::OpBuilder builder, mlir::Location loc, size_t bits, Value lhs, Value rhs, Value prime) { + auto diff = builder.create(lhs, rhs); + auto result = builder.create(diff, prime); + return result; +} + +// Extension field operations + +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { + auto deg = lhs.size(); + assert(rhs.size() === deg); + llvm::SmallVector result(deg); + + for (size_t i = 0; i < deg; i++) { + auto sum = builder.create(lhs[i], rhs[i]); + result[i] = builder.create(sum, prime); + } + return result; +} + +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly) { + // TODO: We could have a simplified version for nth roots x^n - a + // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 + auto deg = lhs.size(); + // Note: The field is not an extension field if deg <= 1 + assert(deg > 1); + assert(rhs.size() === deg); + assert(monic_irred_poly.size() == deg); + llvm::SmallVector result(2 * deg - 1); + + // Compute product of polynomials + for (size_t i = 0; i < deg; i++) { + for (size_t j = 0; j < deg; j++) { + size_t idx = i + j; + auto prod = builder.create(lhs[i], rhs[j]); + auto reduced_prod = builder.create(prod, prime); + result[idx] = TODO init or sum + } + auto sum = builder.create(lhs[i], rhs[i]); + result[i] = builder.create(sum, prime); + } + // Reduce using the monic irred polynomial of the extension field + for (size_t i = 2 * deg - 2; i >= deg; k--) { + for (size_t j = 0; j < deg; j++) { + auto prod = builder.create(result[i], monic_irred_poly[j]); + result[i - deg + j] = builder.create(result[i - deg + j], prod); + result[i - deg + j] = builder.create(result[i - deg + j], prime); + } + // No need to zero out result[i], it will just get dropped + } + // Result's degree is just `deg`, drop the coefficients beyond that + result.truncate(deg); + + return result; +} + +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { + auto deg = lhs.size(); + assert(rhs.size() === deg); + llvm::SmallVector result(deg); + + for (size_t i = 0; i < deg; i++) { + auto diff = builder.create(lhs[i], rhs[i]); + result[i] = builder.create(diff[i], prime); + } + return result; +} + + + +} // namespace zirgen::BigInt::field diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h new file mode 100644 index 00000000..0c52f447 --- /dev/null +++ b/zirgen/circuit/bigint/field.h @@ -0,0 +1,36 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "zirgen/Dialect/BigInt/IR/BigInt.h" + +using namespace mlir; + +namespace zirgen::BigInt::field { + +// Prime field arithmetic (aka modular arithmetic) +Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); +Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime); +Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); +Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); + +// Extension field arithmetic +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); + +} // namespace zirgen::BigInt::field From 32ad2a871db45275af6be6c80d738bb75817f5e7 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Mon, 2 Dec 2024 16:34:38 -0800 Subject: [PATCH 03/34] Update field op draft with codegen --- zirgen/bootstrap/src/main.rs | 4 +- zirgen/circuit/bigint/BUILD.bazel | 10 +++ zirgen/circuit/bigint/bigint2c.cpp | 100 +++++++++++++++++++---------- zirgen/circuit/bigint/field.cpp | 59 +++++++++-------- 4 files changed, 114 insertions(+), 59 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 4cd1a6db..c1215359 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -652,9 +652,11 @@ impl Bootstrap { let risc0_root = risc0_root.join("risc0"); let bazel_bin = get_bazel_bin(); let src_path = bazel_bin.join("zirgen/circuit/bigint"); - let rsa_path = risc0_root.join("bigint2/src/rsa"); let ec_path = risc0_root.join("bigint2/src/ec"); + let field_path = risc0_root.join("bigint2/src/field"); + let rsa_path = risc0_root.join("bigint2/src/rsa"); + self.copy_file(&src_path, &field_path, "modmul.blob"); self.copy_file(&src_path, &rsa_path, "modpow_65537.blob"); self.copy( &src_path.join("ec_double.blob"), diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 58f1d344..67284600 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -9,11 +9,13 @@ cc_library( name = "lib", srcs = [ "elliptic_curve.cpp", + "field.cpp", "op_tests.cpp", "rsa.cpp", ], hdrs = [ "elliptic_curve.h", + "field.h", "op_tests.h", "rsa.h", "//zirgen/circuit/recursion", @@ -96,6 +98,7 @@ BLOBS = [ "modpow_65537", "ec_double", "ec_add", + "modmul", ] genrule( @@ -119,6 +122,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add > $(OUTS)" ) +genrule( + name = "modmul", + outs = ["modmul.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modmul > $(OUTS)" +) + pkg_zip( name = "bigint_blob", srcs = [x + ".blob" for x in BLOBS], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index d0eb49f6..9d3dc59b 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -30,6 +30,7 @@ #include "zirgen/Dialect/BigInt/IR/BigInt.h" #include "zirgen/Dialect/BigInt/Transforms/Passes.h" #include "zirgen/circuit/bigint/elliptic_curve.h" +#include "zirgen/circuit/bigint/field.h" using namespace zirgen; namespace cl = llvm::cl; @@ -402,7 +403,7 @@ std::vector polySplit(mlir::func::FuncOp func) { } void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder) { - const size_t bits = 4096; + const size_t bits = 3072; // Check if (S^e = M (mod N)), where e = 65537 auto S = builder.create(loc, bits, 11, 0); auto N = builder.create(loc, bits, 12, 0); @@ -419,15 +420,7 @@ void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder) { builder.create(loc, x, 13, 0); } -void genModMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - auto lhs = builder.create(loc, bitwidth, 11, 0); - auto rhs = builder.create(loc, bitwidth, 12, 0); - auto modulus = builder.create(loc, bitwidth, 13, 0); - auto prod = builder.create(loc, lhs, rhs); - auto result = builder.create(loc, prod, modulus); - builder.create(loc, result, 14, 0); -} - +// TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; @@ -444,6 +437,7 @@ void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) builder.create(loc, doubled.y(), 13, chunkwidth); } +// TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; @@ -462,65 +456,105 @@ void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { builder.create(loc, result.y(), 14, chunkwidth); } +// Finite Field arithmetic +// +// These functions accelerate finite field arithmetic +// - The `Mod` versions are for prime order fields +// - The `FieldExt` versions are for simple extensions +// - Every finite extension of a finite field is simple, so in a sense this covers every finite +// field, but to use these functions you must represent the extension as the adjunction of a +// primitive element to a prime order field, which is not always convenient (i.e. when you have +// a tower of extensions) +// +// We do not use integer quotients in these functions, so minBits does not give us performance gains +// and we therefore do not require the prime to be full bitwidth, enabling simpler generalization +// (i.e., there's no need to make sure the bitwidth is minimal for your use case) + void genModAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 auto lhs = builder.create(loc, bitwidth, 11, 0); auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto prime = builder.create(loc, bitwidth, 13, 0); auto result = BigInt::field::modAdd(builder, loc, lhs, rhs, prime); builder.create(loc, result, 14, 0); } void genModInv(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 auto inp = builder.create(loc, bitwidth, 11, 0); - auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); + auto prime = builder.create(loc, bitwidth, 12, 0); auto result = BigInt::field::modInv(builder, loc, inp, prime); builder.create(loc, result, 13, 0); } void genModMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 auto lhs = builder.create(loc, bitwidth, 11, 0); auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto prime = builder.create(loc, bitwidth, 13, 0); auto result = BigInt::field::modMul(builder, loc, lhs, rhs, prime); builder.create(loc, result, 14, 0); } void genModSub(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - // TODO: Examine what happens on bitwidths not a multiple of 8, of 32, of 128 auto lhs = builder.create(loc, bitwidth, 11, 0); auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto prime = builder.create(loc, bitwidth, 13, 0); auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); builder.create(loc, result, 14, 0); } +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? void genExtFieldAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; llvm::SmallVector lhs(degree); llvm::SmallVector rhs(degree); - auto lhs = builder.create(loc, bitwidth, 11, 0); - auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); - auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); - builder.create(loc, result, 14, 0); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } } +void genExtFieldMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + llvm::SmallVector monic_irred_poly(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 14, 0); + auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 15, i * chunkwidth); + } +} - -// Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? -Value extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); -Value extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); -Value extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); - - - - +void genExtFieldSub(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index ad18378d..732cea12 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -19,24 +19,24 @@ namespace zirgen::BigInt::field { // Prime field operations Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { - auto sum = builder.create(lhs, rhs); - auto result = builder.create(sum, prime); + auto sum = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, sum, prime); return result; } Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime) { - return builder.create(inp, prime); + return builder.create(loc, inp, prime); } -Value modMul(mlir::OpBuilder builder, mlir::Location loc, size_t bits, Value lhs, Value rhs, Value prime) { - auto prod = builder.create(lhs, rhs); - auto result = builder.create(prod, prime); +Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { + auto prod = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, prod, prime); return result; } -Value modSub(mlir::OpBuilder builder, mlir::Location loc, size_t bits, Value lhs, Value rhs, Value prime) { - auto diff = builder.create(lhs, rhs); - auto result = builder.create(diff, prime); +Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { + auto diff = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, diff, prime); return result; } @@ -44,43 +44,51 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, size_t bits, Value lhs llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { auto deg = lhs.size(); - assert(rhs.size() === deg); + assert(rhs.size() == deg); llvm::SmallVector result(deg); for (size_t i = 0; i < deg; i++) { - auto sum = builder.create(lhs[i], rhs[i]); - result[i] = builder.create(sum, prime); + auto sum = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, sum, prime); } return result; } llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly) { + // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... // TODO: We could have a simplified version for nth roots x^n - a // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 auto deg = lhs.size(); // Note: The field is not an extension field if deg <= 1 assert(deg > 1); - assert(rhs.size() === deg); + assert(rhs.size() == deg); assert(monic_irred_poly.size() == deg); llvm::SmallVector result(2 * deg - 1); + llvm::SmallVector first_write(2 * deg - 1, true); // Compute product of polynomials for (size_t i = 0; i < deg; i++) { for (size_t j = 0; j < deg; j++) { size_t idx = i + j; - auto prod = builder.create(lhs[i], rhs[j]); - auto reduced_prod = builder.create(prod, prime); - result[idx] = TODO init or sum + auto prod = builder.create(loc, lhs[i], rhs[j]); + auto reduced_prod = builder.create(loc, prod, prime); + if (first_write[idx]) { + result[idx] = reduced_prod; + first_write[idx] = false; + } else { + result[idx] = builder.create(loc, result[idx], reduced_prod); + result[idx] = builder.create(loc, result[idx], prime); + } } - auto sum = builder.create(lhs[i], rhs[i]); - result[i] = builder.create(sum, prime); + auto sum = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, sum, prime); } // Reduce using the monic irred polynomial of the extension field - for (size_t i = 2 * deg - 2; i >= deg; k--) { + for (size_t i = 2 * deg - 2; i >= deg; i--) { for (size_t j = 0; j < deg; j++) { - auto prod = builder.create(result[i], monic_irred_poly[j]); - result[i - deg + j] = builder.create(result[i - deg + j], prod); - result[i - deg + j] = builder.create(result[i - deg + j], prime); + auto prod = builder.create(loc, result[i], monic_irred_poly[j]); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prod); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prime); } // No need to zero out result[i], it will just get dropped } @@ -92,12 +100,13 @@ llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { auto deg = lhs.size(); - assert(rhs.size() === deg); + assert(rhs.size() == deg); llvm::SmallVector result(deg); for (size_t i = 0; i < deg; i++) { - auto diff = builder.create(lhs[i], rhs[i]); - result[i] = builder.create(diff[i], prime); + // auto diff = builder.create(loc, lhs[i], rhs[i]); + auto diff = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, diff, prime); } return result; } From 04ed10733a77221899451a0f5376a7633cb5fdb2 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Mon, 2 Dec 2024 16:48:56 -0800 Subject: [PATCH 04/34] WIP Add modadd --- zirgen/circuit/bigint/BUILD.bazel | 1 + zirgen/circuit/bigint/bigint2c.cpp | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index ddafe269..6432ddb0 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -57,6 +57,7 @@ BLOBS = [ "modpow_65537", "ec_double", "ec_add", + "modadd", "modmul", ] diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 9d3dc59b..17517802 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -43,6 +43,7 @@ enum class Program { ModPow_65537, EC_Double, EC_Add, + ModAdd, ModMul, }; } // namespace @@ -53,6 +54,7 @@ static cl::opt cl::values(clEnumValN(Program::ModPow_65537, "modpow_65537", "ModPow_65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), + clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModMul, "modmul", "ModMul")), // TODO: Don't hardcode bitwidth cl::Required); @@ -589,6 +591,10 @@ int main(int argc, char* argv[]) { break; case Program::ModMul: genModMul(loc, builder, 256); // TODO: Selectable bitwidth + break; + case Program::ModAdd: + genModAdd(loc, builder, 256); // TODO: Selectable bitwidth + break; } builder.create(loc); From ac6c5ed73f45f91bbf29ff94c10546c31f4ba84b Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 3 Dec 2024 11:36:03 -0800 Subject: [PATCH 05/34] Build modadd program --- zirgen/circuit/bigint/BUILD.bazel | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 6432ddb0..388d343e 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -82,6 +82,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add > $(OUTS)" ) +genrule( + name = "modadd", + outs = ["modadd.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modadd > $(OUTS)" +) + genrule( name = "modmul", outs = ["modmul.blob"], From 6b7dd1fded6c5eade30abe2362a9091eef26cce0 Mon Sep 17 00:00:00 2001 From: Mars Saxman Date: Tue, 3 Dec 2024 11:51:55 -0800 Subject: [PATCH 06/34] build the bigint blobs when building //zirgen/circuit --- zirgen/circuit/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/zirgen/circuit/BUILD.bazel b/zirgen/circuit/BUILD.bazel index c26ba568..a81f8f2d 100644 --- a/zirgen/circuit/BUILD.bazel +++ b/zirgen/circuit/BUILD.bazel @@ -5,6 +5,7 @@ package( filegroup( name = "circuit", srcs = [ + "//zirgen/circuit/bigint:bigint_blob", "//zirgen/circuit/fib", "//zirgen/circuit/keccak", "//zirgen/circuit/predicates:keccak_zkr", From b7aa908d4d954f311eafd683a1b64ab1ff4ab1bf Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 3 Dec 2024 12:02:14 -0800 Subject: [PATCH 07/34] Export more bigint programs --- zirgen/circuit/bigint/BUILD.bazel | 35 ++++++++++++++++++++++++++++++ zirgen/circuit/bigint/bigint2c.cpp | 31 +++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 388d343e..fad38bd1 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -82,6 +82,27 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add > $(OUTS)" ) +genrule( + name = "extfieldadd", + outs = ["extfieldadd.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd > $(OUTS)" +) + +genrule( + name = "extfieldmul", + outs = ["extfieldmul.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul > $(OUTS)" +) + +genrule( + name = "extfieldsub", + outs = ["extfieldsub.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub > $(OUTS)" +) + genrule( name = "modadd", outs = ["modadd.blob"], @@ -89,6 +110,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modadd > $(OUTS)" ) +genrule( + name = "modinv", + outs = ["modinv.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modinv > $(OUTS)" +) + genrule( name = "modmul", outs = ["modmul.blob"], @@ -96,6 +124,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modmul > $(OUTS)" ) +genrule( + name = "modsub", + outs = ["modsub.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modsub > $(OUTS)" +) + pkg_zip( name = "bigint_blob", srcs = [x + ".blob" for x in BLOBS], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 17517802..a6845215 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -43,8 +43,13 @@ enum class Program { ModPow_65537, EC_Double, EC_Add, + ExtFieldAdd, + ExtFieldMul, + ExtFieldSub, ModAdd, + ModInv, ModMul, + ModSub, }; } // namespace @@ -54,8 +59,13 @@ static cl::opt cl::values(clEnumValN(Program::ModPow_65537, "modpow_65537", "ModPow_65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), + clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), + clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), + clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), clEnumValN(Program::ModAdd, "modadd", "ModAdd"), - clEnumValN(Program::ModMul, "modmul", "ModMul")), // TODO: Don't hardcode bitwidth + clEnumValN(Program::ModInv, "modinv", "ModInv"), + clEnumValN(Program::ModMul, "modmul", "ModMul"), + clEnumValN(Program::ModSub, "modsub", "ModSub")), // TODO: Don't hardcode bitwidth cl::Required); const APInt secp256k1_prime = APInt::getAllOnes(256) - APInt::getOneBitSet(256, 32) - @@ -589,12 +599,27 @@ int main(int argc, char* argv[]) { case Program::EC_Add: genECAdd(loc, builder, 256); // TODO: Selectable bitwidth break; - case Program::ModMul: - genModMul(loc, builder, 256); // TODO: Selectable bitwidth + case Program::ExtFieldAdd: // TODO: Naming for degree 2 + genExtFieldAdd(loc, builder, 256, 2); // TODO: Selectable bitwidth + break; + case Program::ExtFieldMul: + genExtFieldMul(loc, builder, 256, 2); // TODO: Selectable bitwidth + break; + case Program::ExtFieldSub: + genExtFieldSub(loc, builder, 256, 2); // TODO: Selectable bitwidth break; case Program::ModAdd: genModAdd(loc, builder, 256); // TODO: Selectable bitwidth break; + case Program::ModInv: + genModInv(loc, builder, 256); // TODO: Selectable bitwidth + break; + case Program::ModMul: + genModMul(loc, builder, 256); // TODO: Selectable bitwidth + break; + case Program::ModSub: + genModSub(loc, builder, 256); // TODO: Selectable bitwidth + break; } builder.create(loc); From 6c1e663a9a2e59fa506c423a2f70dc2be2a24ce8 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 3 Dec 2024 13:37:51 -0800 Subject: [PATCH 08/34] Complete new program exports --- zirgen/bootstrap/src/main.rs | 6 ++++++ zirgen/circuit/bigint/BUILD.bazel | 5 +++++ zirgen/circuit/bigint/bigint2c.cpp | 1 + zirgen/circuit/bigint/test/bibc.cpp | 2 +- 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 7e726a08..283454ce 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -593,7 +593,13 @@ impl Bootstrap { let field_path = risc0_root.join("bigint2/src/field"); let rsa_path = risc0_root.join("bigint2/src/rsa"); + self.copy_file(&src_path, &field_path, "extfieldadd.blob"); + self.copy_file(&src_path, &field_path, "extfieldmul.blob"); + self.copy_file(&src_path, &field_path, "extfieldsub.blob"); + self.copy_file(&src_path, &field_path, "modadd.blob"); + self.copy_file(&src_path, &field_path, "modinv.blob"); self.copy_file(&src_path, &field_path, "modmul.blob"); + self.copy_file(&src_path, &field_path, "modsub.blob"); self.copy_file(&src_path, &rsa_path, "modpow_65537.blob"); self.copy( &src_path.join("ec_double.blob"), diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index fad38bd1..af92ada1 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -57,8 +57,13 @@ BLOBS = [ "modpow_65537", "ec_double", "ec_add", + "extfieldadd", + "extfieldmul", + "extfieldsub", "modadd", + "modinv", "modmul", + "modsub", ] genrule( diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index a6845215..406ce5c6 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -628,6 +628,7 @@ int main(int argc, char* argv[]) { PassManager pm(&ctx); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); + pm.addPass(BigInt::createLowerInvPass()); pm.addPass(BigInt::createLowerReducePass()); pm.addPass(createCSEPass()); if (failed(pm.run(module))) { diff --git a/zirgen/circuit/bigint/test/bibc.cpp b/zirgen/circuit/bigint/test/bibc.cpp index 4045c9a1..9c4433f3 100644 --- a/zirgen/circuit/bigint/test/bibc.cpp +++ b/zirgen/circuit/bigint/test/bibc.cpp @@ -66,8 +66,8 @@ void BibcTest::lower() { // Lower the inverse and reduce ops to simpler, executable ops mlir::PassManager pm(ctx); pm.enableVerifier(true); - pm.addPass(zirgen::BigInt::createLowerReducePass()); pm.addPass(zirgen::BigInt::createLowerInvPass()); + pm.addPass(zirgen::BigInt::createLowerReducePass()); if (failed(pm.run(module))) { llvm::errs() << "an internal validation error occurred:\n"; module.print(llvm::errs()); From bbaa189f82aad710a45a4a99e254d154e67cc537 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 3 Dec 2024 15:58:07 -0800 Subject: [PATCH 09/34] Move program gen to program files --- zirgen/circuit/bigint/bigint2c.cpp | 59 ++---------------------- zirgen/circuit/bigint/elliptic_curve.cpp | 36 +++++++++++++++ zirgen/circuit/bigint/elliptic_curve.h | 4 ++ zirgen/circuit/bigint/rsa.cpp | 20 ++++++++ zirgen/circuit/bigint/rsa.h | 3 +- 5 files changed, 66 insertions(+), 56 deletions(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 4c2ae0b2..9648934d 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -30,6 +30,7 @@ #include "zirgen/Dialect/BigInt/IR/BigInt.h" #include "zirgen/Dialect/BigInt/Transforms/Passes.h" #include "zirgen/circuit/bigint/elliptic_curve.h" +#include "zirgen/circuit/bigint/rsa.h" using namespace zirgen; namespace cl = llvm::cl; @@ -399,58 +400,6 @@ std::vector polySplit(mlir::func::FuncOp func) { return flat; } -void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder) { - const size_t bits = 4096; - // Check if (S^e = M (mod N)), where e = 65537 - auto S = builder.create(loc, bits, 11, 0); - auto N = builder.create(loc, bits, 12, 0); - // We square S 16 times to get S^65536 - Value x = S; - for (size_t i = 0; i < 16; i++) { - auto xm = builder.create(loc, x, x); - x = builder.create(loc, xm, N); - } - // Multiply in one more copy of S + reduce - auto xm = builder.create(loc, x, S); - x = builder.create(loc, xm, N); - // this is our result - builder.create(loc, x, 13, 0); -} - -void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - - auto pt_x = builder.create(loc, bitwidth, 11, 0); - auto pt_y = builder.create(loc, bitwidth, 11, chunkwidth); - auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); - auto a = builder.create(loc, bitwidth, 12, chunkwidth); - auto b = builder.create(loc, bitwidth, 12, 2 * chunkwidth); - auto curve = std::make_shared(prime, a, b); - auto pt = BigInt::EC::AffinePt(pt_x, pt_y, curve); - auto doubled = BigInt::EC::doub(builder, loc, pt); - builder.create(loc, doubled.x(), 13, 0); - builder.create(loc, doubled.y(), 13, chunkwidth); -} - -void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - auto p_x = builder.create(loc, bitwidth, 11, 0); - auto p_y = builder.create(loc, bitwidth, 11, chunkwidth); - auto q_x = builder.create(loc, bitwidth, 12, 0); - auto q_y = builder.create(loc, bitwidth, 12, chunkwidth); - auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); - auto a = builder.create(loc, bitwidth, 13, chunkwidth); - auto b = builder.create(loc, bitwidth, 13, 2 * chunkwidth); - auto curve = std::make_shared(prime, a, b); - auto lhs = BigInt::EC::AffinePt(p_x, p_y, curve); - auto rhs = BigInt::EC::AffinePt(q_x, q_y, curve); - auto result = BigInt::EC::add(builder, loc, lhs, rhs); - builder.create(loc, result.x(), 14, 0); - builder.create(loc, result.y(), 14, chunkwidth); -} - int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); mlir::registerAsmPrinterCLOptions(); @@ -474,13 +423,13 @@ int main(int argc, char* argv[]) { switch (program) { case Program::ModPow_65537: - genModPow65537(loc, builder); + zirgen::BigInt::genModPow65537(loc, builder, 4096); // TODO: Selectable bitwidth break; case Program::EC_Double: - genECDouble(loc, builder, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECDouble(loc, builder, 256); // TODO: Selectable bitwidth break; case Program::EC_Add: - genECAdd(loc, builder, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECAdd(loc, builder, 256); // TODO: Selectable bitwidth break; } diff --git a/zirgen/circuit/bigint/elliptic_curve.cpp b/zirgen/circuit/bigint/elliptic_curve.cpp index 75903f34..ab4c1924 100644 --- a/zirgen/circuit/bigint/elliptic_curve.cpp +++ b/zirgen/circuit/bigint/elliptic_curve.cpp @@ -350,6 +350,42 @@ AffinePt sub(OpBuilder builder, Location loc, const AffinePt& lhs, const AffineP return add(builder, loc, lhs, neg_rhs); } +// Full programs, including I/O + +void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + + auto pt_x = builder.create(loc, bitwidth, 11, 0); + auto pt_y = builder.create(loc, bitwidth, 11, chunkwidth); + auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); + auto a = builder.create(loc, bitwidth, 12, chunkwidth); + auto b = builder.create(loc, bitwidth, 12, 2 * chunkwidth); + auto curve = std::make_shared(prime, a, b); + auto pt = BigInt::EC::AffinePt(pt_x, pt_y, curve); + auto doubled = BigInt::EC::doub(builder, loc, pt); + builder.create(loc, doubled.x(), 13, 0); + builder.create(loc, doubled.y(), 13, chunkwidth); +} + +void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + auto p_x = builder.create(loc, bitwidth, 11, 0); + auto p_y = builder.create(loc, bitwidth, 11, chunkwidth); + auto q_x = builder.create(loc, bitwidth, 12, 0); + auto q_y = builder.create(loc, bitwidth, 12, chunkwidth); + auto prime = builder.create(loc, bitwidth, 13, 0, bitwidth - 1); + auto a = builder.create(loc, bitwidth, 13, chunkwidth); + auto b = builder.create(loc, bitwidth, 13, 2 * chunkwidth); + auto curve = std::make_shared(prime, a, b); + auto lhs = BigInt::EC::AffinePt(p_x, p_y, curve); + auto rhs = BigInt::EC::AffinePt(q_x, q_y, curve); + auto result = BigInt::EC::add(builder, loc, lhs, rhs); + builder.create(loc, result.x(), 14, 0); + builder.create(loc, result.y(), 14, chunkwidth); +} + // Test functions void makeECAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { diff --git a/zirgen/circuit/bigint/elliptic_curve.h b/zirgen/circuit/bigint/elliptic_curve.h index a20805bf..b30bd632 100644 --- a/zirgen/circuit/bigint/elliptic_curve.h +++ b/zirgen/circuit/bigint/elliptic_curve.h @@ -65,6 +65,10 @@ AffinePt mul(OpBuilder builder, Location loc, Value scalar, const AffinePt& pt); AffinePt neg(OpBuilder builder, Location loc, const AffinePt& pt); AffinePt sub(OpBuilder builder, Location loc, const AffinePt& lhs, const AffinePt& rhs); +// Full Programs +void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); +void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); + // Test functions void makeECAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeECDoubleTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); diff --git a/zirgen/circuit/bigint/rsa.cpp b/zirgen/circuit/bigint/rsa.cpp index 986e3a7e..5dc84cda 100644 --- a/zirgen/circuit/bigint/rsa.cpp +++ b/zirgen/circuit/bigint/rsa.cpp @@ -19,6 +19,25 @@ using namespace mlir; namespace zirgen::BigInt { +// TODO: Why don't we have consistent builder/loc order? +void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { + // Check if (S^e = M (mod N)), where e = 65537 + auto S = builder.create(loc, bitwidth, 11, 0); + auto N = builder.create(loc, bitwidth, 12, 0); + // We square S 16 times to get S^65536 + Value x = S; + for (size_t i = 0; i < 16; i++) { + auto xm = builder.create(loc, x, x); + x = builder.create(loc, xm, N); + } + // Multiply in one more copy of S + reduce + auto xm = builder.create(loc, x, S); + x = builder.create(loc, xm, N); + // this is our result + builder.create(loc, x, 13, 0); +} + +// Used for testing, this RSA code uses `Def` instead of `Load`/`Store` void makeRSA(OpBuilder builder, Location loc, size_t bits) { // Check if (S^e = M (mod N)), where e = 65537 auto N = builder.create(loc, bits, 0, true, bits - 1); @@ -38,6 +57,7 @@ void makeRSA(OpBuilder builder, Location loc, size_t bits) { builder.create(loc, diff); } +// Used for testing, to compute expected outputs. // I verified this by comparing against: // pow(S, 65537, N) in python APInt RSA(APInt N, APInt S) { diff --git a/zirgen/circuit/bigint/rsa.h b/zirgen/circuit/bigint/rsa.h index fd33531e..f1212681 100644 --- a/zirgen/circuit/bigint/rsa.h +++ b/zirgen/circuit/bigint/rsa.h @@ -19,8 +19,9 @@ namespace zirgen::BigInt { +void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); +// TODO: Clarify this name void makeRSA(mlir::OpBuilder builder, mlir::Location loc, size_t bits); - llvm::APInt RSA(llvm::APInt N, llvm::APInt S); } // namespace zirgen::BigInt From 2c1024ec311aef3894f65d79fcf76d2e57fc411c Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Wed, 4 Dec 2024 16:37:09 -0800 Subject: [PATCH 10/34] Make builder, loc order more consistent --- zirgen/circuit/bigint/bigint2c.cpp | 6 ++--- zirgen/circuit/bigint/elliptic_curve.cpp | 34 ++++++++++++------------ zirgen/circuit/bigint/elliptic_curve.h | 4 +-- zirgen/circuit/bigint/rsa.cpp | 2 +- zirgen/circuit/bigint/rsa.h | 2 +- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 9648934d..be4fc8d6 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -423,13 +423,13 @@ int main(int argc, char* argv[]) { switch (program) { case Program::ModPow_65537: - zirgen::BigInt::genModPow65537(loc, builder, 4096); // TODO: Selectable bitwidth + zirgen::BigInt::genModPow65537(builder, loc, 4096); // TODO: Selectable bitwidth break; case Program::EC_Double: - zirgen::BigInt::EC::genECDouble(loc, builder, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECDouble(builder, loc, 256); // TODO: Selectable bitwidth break; case Program::EC_Add: - zirgen::BigInt::EC::genECAdd(loc, builder, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECAdd(builder, loc, 256); // TODO: Selectable bitwidth break; } diff --git a/zirgen/circuit/bigint/elliptic_curve.cpp b/zirgen/circuit/bigint/elliptic_curve.cpp index ab4c1924..54fec74b 100644 --- a/zirgen/circuit/bigint/elliptic_curve.cpp +++ b/zirgen/circuit/bigint/elliptic_curve.cpp @@ -352,23 +352,7 @@ AffinePt sub(OpBuilder builder, Location loc, const AffinePt& lhs, const AffineP // Full programs, including I/O -void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - - auto pt_x = builder.create(loc, bitwidth, 11, 0); - auto pt_y = builder.create(loc, bitwidth, 11, chunkwidth); - auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); - auto a = builder.create(loc, bitwidth, 12, chunkwidth); - auto b = builder.create(loc, bitwidth, 12, 2 * chunkwidth); - auto curve = std::make_shared(prime, a, b); - auto pt = BigInt::EC::AffinePt(pt_x, pt_y, curve); - auto doubled = BigInt::EC::doub(builder, loc, pt); - builder.create(loc, doubled.x(), 13, 0); - builder.create(loc, doubled.y(), 13, chunkwidth); -} - -void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { +void genECAdd(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; auto p_x = builder.create(loc, bitwidth, 11, 0); @@ -386,6 +370,22 @@ void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { builder.create(loc, result.y(), 14, chunkwidth); } +void genECDouble(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth) { + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + + auto pt_x = builder.create(loc, bitwidth, 11, 0); + auto pt_y = builder.create(loc, bitwidth, 11, chunkwidth); + auto prime = builder.create(loc, bitwidth, 12, 0, bitwidth - 1); + auto a = builder.create(loc, bitwidth, 12, chunkwidth); + auto b = builder.create(loc, bitwidth, 12, 2 * chunkwidth); + auto curve = std::make_shared(prime, a, b); + auto pt = BigInt::EC::AffinePt(pt_x, pt_y, curve); + auto doubled = BigInt::EC::doub(builder, loc, pt); + builder.create(loc, doubled.x(), 13, 0); + builder.create(loc, doubled.y(), 13, chunkwidth); +} + // Test functions void makeECAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { diff --git a/zirgen/circuit/bigint/elliptic_curve.h b/zirgen/circuit/bigint/elliptic_curve.h index b30bd632..0a4f309e 100644 --- a/zirgen/circuit/bigint/elliptic_curve.h +++ b/zirgen/circuit/bigint/elliptic_curve.h @@ -66,8 +66,8 @@ AffinePt neg(OpBuilder builder, Location loc, const AffinePt& pt); AffinePt sub(OpBuilder builder, Location loc, const AffinePt& lhs, const AffinePt& rhs); // Full Programs -void genECAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); -void genECDouble(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); +void genECAdd(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth); +void genECDouble(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth); // Test functions void makeECAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); diff --git a/zirgen/circuit/bigint/rsa.cpp b/zirgen/circuit/bigint/rsa.cpp index 5dc84cda..ffe6c91c 100644 --- a/zirgen/circuit/bigint/rsa.cpp +++ b/zirgen/circuit/bigint/rsa.cpp @@ -20,7 +20,7 @@ using namespace mlir; namespace zirgen::BigInt { // TODO: Why don't we have consistent builder/loc order? -void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { +void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth) { // Check if (S^e = M (mod N)), where e = 65537 auto S = builder.create(loc, bitwidth, 11, 0); auto N = builder.create(loc, bitwidth, 12, 0); diff --git a/zirgen/circuit/bigint/rsa.h b/zirgen/circuit/bigint/rsa.h index f1212681..937215e3 100644 --- a/zirgen/circuit/bigint/rsa.h +++ b/zirgen/circuit/bigint/rsa.h @@ -19,7 +19,7 @@ namespace zirgen::BigInt { -void genModPow65537(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth); +void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth); // TODO: Clarify this name void makeRSA(mlir::OpBuilder builder, mlir::Location loc, size_t bits); llvm::APInt RSA(llvm::APInt N, llvm::APInt S); From 25c1c9dc5708772d9a4b58610b24df71537b54b3 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 11:51:29 -0800 Subject: [PATCH 11/34] Parameterize bitwidth on circuit generation --- zirgen/bootstrap/src/main.rs | 9 +++------ zirgen/circuit/bigint/BUILD.bazel | 24 ++++++++++++------------ zirgen/circuit/bigint/bigint2c.cpp | 18 ++++++++++++------ 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 8940fddf..1dc9898d 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -592,12 +592,9 @@ impl Bootstrap { let rsa_path = risc0_root.join("bigint2/src/rsa"); let ec_path = risc0_root.join("bigint2/src/ec"); - self.copy_file(&src_path, &rsa_path, "modpow_65537.blob"); - self.copy( - &src_path.join("ec_double.blob"), - &ec_path.join("double.blob"), - ); - self.copy(&src_path.join("ec_add.blob"), &ec_path.join("add.blob")); + self.copy_file(&src_path, &rsa_path, "modpow65537_4096.blob"); + self.copy_file(&src_path, &ec_path, "ec_add_256.blob"); + self.copy_file(&src_path, &ec_path, "ec_double_256.blob"); } } diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 6da0d27d..840072a3 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -52,30 +52,30 @@ cc_binary( ) BLOBS = [ - "modpow_65537", - "ec_double", - "ec_add", + "modpow65537_4096", + "ec_double_256", + "ec_add_256", ] genrule( - name = "modpow_65537", - outs = ["modpow_65537.blob"], + name = "modpow65537_4096", + outs = ["modpow65537_4096.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modpow_65537 > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modpow65537 --bitwidth 4096 > $(OUTS)" ) genrule( - name = "ec_double", - outs = ["ec_double.blob"], + name = "ec_double_256", + outs = ["ec_double_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_double > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_double --bitwidth 256 > $(OUTS)" ) genrule( - name = "ec_add", - outs = ["ec_add.blob"], + name = "ec_add_256", + outs = ["ec_add_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add --bitwidth 256 > $(OUTS)" ) pkg_zip( diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index be4fc8d6..afaba099 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -40,7 +40,7 @@ cl::opt namespace { enum class Program { - ModPow_65537, + ModPow65537, EC_Double, EC_Add, }; @@ -49,11 +49,17 @@ enum class Program { static cl::opt program("program", cl::desc("The program to compile"), - cl::values(clEnumValN(Program::ModPow_65537, "modpow_65537", "ModPow_65537"), + cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add")), cl::Required); +static cl::opt + bitwidth("bitwidth", + cl::desc("The bitwidth of program parameters"), + cl::value_desc("bitwidth"), + cl::Required); + const APInt secp256k1_prime = APInt::getAllOnes(256) - APInt::getOneBitSet(256, 32) - APInt::getOneBitSet(256, 9) - APInt::getOneBitSet(256, 8) - APInt::getOneBitSet(256, 7) - APInt::getOneBitSet(256, 6) - @@ -422,14 +428,14 @@ int main(int argc, char* argv[]) { builder.setInsertionPointToStart(func.addEntryBlock()); switch (program) { - case Program::ModPow_65537: - zirgen::BigInt::genModPow65537(builder, loc, 4096); // TODO: Selectable bitwidth + case Program::ModPow65537: + zirgen::BigInt::genModPow65537(builder, loc, bitwidth); break; case Program::EC_Double: - zirgen::BigInt::EC::genECDouble(builder, loc, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECDouble(builder, loc, bitwidth); break; case Program::EC_Add: - zirgen::BigInt::EC::genECAdd(builder, loc, 256); // TODO: Selectable bitwidth + zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; } From 6fe1310c6e24f8a0ce60ca5d8292ffc7345ba415 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 12:07:11 -0800 Subject: [PATCH 12/34] Remove TODO comment for cleaned code --- zirgen/circuit/bigint/rsa.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/zirgen/circuit/bigint/rsa.cpp b/zirgen/circuit/bigint/rsa.cpp index ffe6c91c..fc6c39b9 100644 --- a/zirgen/circuit/bigint/rsa.cpp +++ b/zirgen/circuit/bigint/rsa.cpp @@ -19,7 +19,6 @@ using namespace mlir; namespace zirgen::BigInt { -// TODO: Why don't we have consistent builder/loc order? void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth) { // Check if (S^e = M (mod N)), where e = 65537 auto S = builder.create(loc, bitwidth, 11, 0); From 443dd192562615b013d7e3808f441e6dcb72df8e Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 12:49:45 -0800 Subject: [PATCH 13/34] Naming --- zirgen/Dialect/BigInt/IR/test/test.cpp | 2 +- zirgen/circuit/bigint/rsa.cpp | 2 +- zirgen/circuit/bigint/rsa.h | 4 ++-- zirgen/circuit/bigint/test/rsa.cpp | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/zirgen/Dialect/BigInt/IR/test/test.cpp b/zirgen/Dialect/BigInt/IR/test/test.cpp index 9b51afaa..e8b9dfc6 100644 --- a/zirgen/Dialect/BigInt/IR/test/test.cpp +++ b/zirgen/Dialect/BigInt/IR/test/test.cpp @@ -152,7 +152,7 @@ int main(int argc, const char** argv) { builder.setInsertionPointToEnd(&inModule.getBodyRegion().front()); auto inFunc = builder.create(loc, "main", FunctionType::get(&context, {}, {})); builder.setInsertionPointToEnd(inFunc.addEntryBlock()); - makeRSA(builder, loc, numBits); + makeRSAChecker(builder, loc, numBits); builder.create(loc); PassManager pm(&context); diff --git a/zirgen/circuit/bigint/rsa.cpp b/zirgen/circuit/bigint/rsa.cpp index fc6c39b9..5ac9bab3 100644 --- a/zirgen/circuit/bigint/rsa.cpp +++ b/zirgen/circuit/bigint/rsa.cpp @@ -37,7 +37,7 @@ void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidt } // Used for testing, this RSA code uses `Def` instead of `Load`/`Store` -void makeRSA(OpBuilder builder, Location loc, size_t bits) { +void makeRSAChecker(OpBuilder builder, Location loc, size_t bits) { // Check if (S^e = M (mod N)), where e = 65537 auto N = builder.create(loc, bits, 0, true, bits - 1); auto S = builder.create(loc, bits, 1, true); diff --git a/zirgen/circuit/bigint/rsa.h b/zirgen/circuit/bigint/rsa.h index 937215e3..06887596 100644 --- a/zirgen/circuit/bigint/rsa.h +++ b/zirgen/circuit/bigint/rsa.h @@ -20,8 +20,8 @@ namespace zirgen::BigInt { void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth); -// TODO: Clarify this name -void makeRSA(mlir::OpBuilder builder, mlir::Location loc, size_t bits); +// TODO: Unify our tests so we don't need separate codepaths for the RSA versions with & without Loads & Stores +void makeRSAChecker(mlir::OpBuilder builder, mlir::Location loc, size_t bits); llvm::APInt RSA(llvm::APInt N, llvm::APInt S); } // namespace zirgen::BigInt diff --git a/zirgen/circuit/bigint/test/rsa.cpp b/zirgen/circuit/bigint/test/rsa.cpp index 41af8ccc..719bbe4b 100644 --- a/zirgen/circuit/bigint/test/rsa.cpp +++ b/zirgen/circuit/bigint/test/rsa.cpp @@ -23,7 +23,7 @@ using namespace zirgen::BigInt::test; TEST_F(BibcTest, RSA256) { mlir::OpBuilder builder(ctx); auto func = makeFunc("rsa_256", builder); - BigInt::makeRSA(builder, func.getLoc(), 256); + BigInt::makeRSAChecker(builder, func.getLoc(), 256); lower(); llvm::APInt N(64, 101); @@ -40,7 +40,7 @@ TEST_F(BibcTest, RSA256) { TEST_F(BibcTest, RSA3072) { mlir::OpBuilder builder(ctx); auto func = makeFunc("rsa_3072", builder); - BigInt::makeRSA(builder, func.getLoc(), 3072); + BigInt::makeRSAChecker(builder, func.getLoc(), 3072); lower(); llvm::APInt N(64, 22764235167642101); From cddf63dd7501672583f05b9051ed918afbd1b0b3 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 12:53:56 -0800 Subject: [PATCH 14/34] Format --- zirgen/circuit/bigint/bigint2c.cpp | 9 ++++----- zirgen/circuit/bigint/rsa.h | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index afaba099..384a86f3 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -54,11 +54,10 @@ static cl::opt clEnumValN(Program::EC_Add, "ec_add", "EC_Add")), cl::Required); -static cl::opt - bitwidth("bitwidth", - cl::desc("The bitwidth of program parameters"), - cl::value_desc("bitwidth"), - cl::Required); +static cl::opt bitwidth("bitwidth", + cl::desc("The bitwidth of program parameters"), + cl::value_desc("bitwidth"), + cl::Required); const APInt secp256k1_prime = APInt::getAllOnes(256) - APInt::getOneBitSet(256, 32) - APInt::getOneBitSet(256, 9) - APInt::getOneBitSet(256, 8) - diff --git a/zirgen/circuit/bigint/rsa.h b/zirgen/circuit/bigint/rsa.h index 06887596..3728ac35 100644 --- a/zirgen/circuit/bigint/rsa.h +++ b/zirgen/circuit/bigint/rsa.h @@ -20,7 +20,8 @@ namespace zirgen::BigInt { void genModPow65537(mlir::OpBuilder& builder, mlir::Location loc, size_t bitwidth); -// TODO: Unify our tests so we don't need separate codepaths for the RSA versions with & without Loads & Stores +// TODO: Unify our tests so we don't need separate codepaths for the RSA versions with & without +// Loads & Stores void makeRSAChecker(mlir::OpBuilder builder, mlir::Location loc, size_t bits); llvm::APInt RSA(llvm::APInt N, llvm::APInt S); From a9637b63a60f98cb9dc9f077812db3b2c3776335 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 14:17:47 -0800 Subject: [PATCH 15/34] Parameterize bitwidths on the new ops --- zirgen/bootstrap/src/main.rs | 14 ++++---- zirgen/circuit/bigint/BUILD.bazel | 56 +++++++++++++++--------------- zirgen/circuit/bigint/bigint2c.cpp | 14 ++++---- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 6f587a99..7396506a 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -574,13 +574,13 @@ impl Bootstrap { let rsa_path = risc0_root.join("bigint2/src/rsa"); // TODO: Bitwidths on field ops - self.copy_file(&src_path, &field_path, "extfieldadd.blob"); - self.copy_file(&src_path, &field_path, "extfieldmul.blob"); - self.copy_file(&src_path, &field_path, "extfieldsub.blob"); - self.copy_file(&src_path, &field_path, "modadd.blob"); - self.copy_file(&src_path, &field_path, "modinv.blob"); - self.copy_file(&src_path, &field_path, "modmul.blob"); - self.copy_file(&src_path, &field_path, "modsub.blob"); + self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); + self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); + self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); + self.copy_file(&src_path, &field_path, "modadd_256.blob"); + self.copy_file(&src_path, &field_path, "modinv_256.blob"); + self.copy_file(&src_path, &field_path, "modmul_256.blob"); + self.copy_file(&src_path, &field_path, "modsub_256.blob"); self.copy_file(&src_path, &rsa_path, "modpow65537_4096.blob"); self.copy_file(&src_path, &ec_path, "ec_add_256.blob"); self.copy_file(&src_path, &ec_path, "ec_double_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 550906d9..2e2be417 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -58,13 +58,13 @@ BLOBS = [ "modpow65537_4096", "ec_double_256", "ec_add_256", - "extfieldadd", - "extfieldmul", - "extfieldsub", - "modadd", - "modinv", - "modmul", - "modsub", + "extfieldadd_256", + "extfieldmul_256", + "extfieldsub_256", + "modadd_256", + "modinv_256", + "modmul_256", + "modsub_256", ] genrule( @@ -89,52 +89,52 @@ genrule( ) genrule( - name = "extfieldadd", - outs = ["extfieldadd.blob"], + name = "extfieldadd_256", + outs = ["extfieldadd_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd --bitwidth 256 > $(OUTS)" ) genrule( - name = "extfieldmul", - outs = ["extfieldmul.blob"], + name = "extfieldmul_256", + outs = ["extfieldmul_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul --bitwidth 256 > $(OUTS)" ) genrule( - name = "extfieldsub", - outs = ["extfieldsub.blob"], + name = "extfieldsub_256", + outs = ["extfieldsub_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub --bitwidth 256 > $(OUTS)" ) genrule( - name = "modadd", - outs = ["modadd.blob"], + name = "modadd_256", + outs = ["modadd_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modadd > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modadd --bitwidth 256 > $(OUTS)" ) genrule( - name = "modinv", - outs = ["modinv.blob"], + name = "modinv_256", + outs = ["modinv_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modinv > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modinv --bitwidth 256 > $(OUTS)" ) genrule( - name = "modmul", - outs = ["modmul.blob"], + name = "modmul_256", + outs = ["modmul_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modmul > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modmul --bitwidth 256 > $(OUTS)" ) genrule( - name = "modsub", - outs = ["modsub.blob"], + name = "modsub_256", + outs = ["modsub_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modsub > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=modsub --bitwidth 256 > $(OUTS)" ) pkg_zip( diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 0c2be96a..90c32a98 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -554,25 +554,25 @@ int main(int argc, char* argv[]) { zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; case Program::ExtFieldAdd: // TODO: Naming for degree 2 - genExtFieldAdd(loc, builder, 256, 2); // TODO: Selectable bitwidth + genExtFieldAdd(loc, builder, bitwidth, 2); break; case Program::ExtFieldMul: - genExtFieldMul(loc, builder, 256, 2); // TODO: Selectable bitwidth + genExtFieldMul(loc, builder, bitwidth, 2); break; case Program::ExtFieldSub: - genExtFieldSub(loc, builder, 256, 2); // TODO: Selectable bitwidth + genExtFieldSub(loc, builder, bitwidth, 2); break; case Program::ModAdd: - genModAdd(loc, builder, 256); // TODO: Selectable bitwidth + genModAdd(loc, builder, bitwidth); break; case Program::ModInv: - genModInv(loc, builder, 256); // TODO: Selectable bitwidth + genModInv(loc, builder, bitwidth); break; case Program::ModMul: - genModMul(loc, builder, 256); // TODO: Selectable bitwidth + genModMul(loc, builder, bitwidth); break; case Program::ModSub: - genModSub(loc, builder, 256); // TODO: Selectable bitwidth + genModSub(loc, builder, bitwidth); break; } From df85f0ebb13f7e9ce9115e8ba697eab61f65701a Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 14:33:53 -0800 Subject: [PATCH 16/34] Move field gen ops into their file --- zirgen/circuit/bigint/bigint2c.cpp | 116 ++--------------------------- zirgen/circuit/bigint/field.cpp | 102 +++++++++++++++++++++++++ zirgen/circuit/bigint/field.h | 23 ++++++ 3 files changed, 132 insertions(+), 109 deletions(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 90c32a98..7a059dab 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -420,108 +420,6 @@ std::vector polySplit(mlir::func::FuncOp func) { return flat; } -// TODO: Move & cleanup finite field changes - -// Finite Field arithmetic -// -// These functions accelerate finite field arithmetic -// - The `Mod` versions are for prime order fields -// - The `FieldExt` versions are for simple extensions -// - Every finite extension of a finite field is simple, so in a sense this covers every finite -// field, but to use these functions you must represent the extension as the adjunction of a -// primitive element to a prime order field, which is not always convenient (i.e. when you have -// a tower of extensions) -// -// We do not use integer quotients in these functions, so minBits does not give us performance gains -// and we therefore do not require the prime to be full bitwidth, enabling simpler generalization -// (i.e., there's no need to make sure the bitwidth is minimal for your use case) - -void genModAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - auto lhs = builder.create(loc, bitwidth, 11, 0); - auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::modAdd(builder, loc, lhs, rhs, prime); - builder.create(loc, result, 14, 0); -} - -void genModInv(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - auto inp = builder.create(loc, bitwidth, 11, 0); - auto prime = builder.create(loc, bitwidth, 12, 0); - auto result = BigInt::field::modInv(builder, loc, inp, prime); - builder.create(loc, result, 13, 0); -} - -void genModMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - auto lhs = builder.create(loc, bitwidth, 11, 0); - auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::modMul(builder, loc, lhs, rhs, prime); - builder.create(loc, result, 14, 0); -} - -void genModSub(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth) { - auto lhs = builder.create(loc, bitwidth, 11, 0); - auto rhs = builder.create(loc, bitwidth, 12, 0); - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); - builder.create(loc, result, 14, 0); -} - -// Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? -void genExtFieldAdd(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); - } -} - -void genExtFieldMul(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - llvm::SmallVector monic_irred_poly(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 14, 0); - auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 15, i * chunkwidth); - } -} - -void genExtFieldSub(mlir::Location loc, mlir::OpBuilder& builder, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); - } -} - int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); mlir::registerAsmPrinterCLOptions(); @@ -554,25 +452,25 @@ int main(int argc, char* argv[]) { zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; case Program::ExtFieldAdd: // TODO: Naming for degree 2 - genExtFieldAdd(loc, builder, bitwidth, 2); + zirgen::BigInt::field::genExtFieldAdd(builder, loc, bitwidth, 2); break; case Program::ExtFieldMul: - genExtFieldMul(loc, builder, bitwidth, 2); + zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); break; case Program::ExtFieldSub: - genExtFieldSub(loc, builder, bitwidth, 2); + zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2); break; case Program::ModAdd: - genModAdd(loc, builder, bitwidth); + zirgen::BigInt::field::genModAdd(builder, loc, bitwidth); break; case Program::ModInv: - genModInv(loc, builder, bitwidth); + zirgen::BigInt::field::genModInv(builder, loc, bitwidth); break; case Program::ModMul: - genModMul(loc, builder, bitwidth); + zirgen::BigInt::field::genModMul(builder, loc, bitwidth); break; case Program::ModSub: - genModSub(loc, builder, bitwidth); + zirgen::BigInt::field::genModSub(builder, loc, bitwidth); break; } diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 732cea12..b1d84220 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -111,6 +111,108 @@ llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, return result; } +// Full programs, including I/O +// TODO: Move & cleanup finite field changes + +// Finite Field arithmetic +// +// These functions accelerate finite field arithmetic +// - The `Mod` versions are for prime order fields +// - The `FieldExt` versions are for simple extensions +// - Every finite extension of a finite field is simple, so in a sense this covers every finite +// field, but to use these functions you must represent the extension as the adjunction of a +// primitive element to a prime order field, which is not always convenient (i.e. when you have +// a tower of extensions) +// +// We do not use integer quotients in these functions, so minBits does not give us performance gains +// and we therefore do not require the prime to be full bitwidth, enabling simpler generalization +// (i.e., there's no need to make sure the bitwidth is minimal for your use case) + +void genModAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::modAdd(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +void genModInv(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { + auto inp = builder.create(loc, bitwidth, 11, 0); + auto prime = builder.create(loc, bitwidth, 12, 0); + auto result = BigInt::field::modInv(builder, loc, inp, prime); + builder.create(loc, result, 13, 0); +} + +void genModMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::modMul(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { + auto lhs = builder.create(loc, bitwidth, 11, 0); + auto rhs = builder.create(loc, bitwidth, 12, 0); + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::modSub(builder, loc, lhs, rhs, prime); + builder.create(loc, result, 14, 0); +} + +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? +void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} + +void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + llvm::SmallVector monic_irred_poly(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 14, 0); + auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 15, i * chunkwidth); + } +} + +void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} } // namespace zirgen::BigInt::field diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 0c52f447..1eb76805 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -20,6 +20,29 @@ using namespace mlir; namespace zirgen::BigInt::field { +// Finite Field arithmetic +// +// These functions accelerate finite field arithmetic +// - The `Mod` versions are for prime order fields +// - The `FieldExt` versions are for simple extensions +// - Every finite extension of a finite field is simple, so in a sense this covers every finite +// field, but to use these functions you must represent the extension as the adjunction of a +// primitive element to a prime order field, which is not always convenient (i.e. when you have +// a tower of extensions) +// +// We do not use integer quotients in these functions, so minBits does not give us performance gains +// and we therefore do not require the prime to be full bitwidth, enabling simpler generalization +// (i.e., there's no need to make sure the bitwidth is minimal for your use case) + +// Full programs, including I/O +void genModAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); +void genModInv(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); +void genModMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); +void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); +void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); +void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); +void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); + // Prime field arithmetic (aka modular arithmetic) Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime); From 1a9fe6eb37857fe3828291d138253ab9ff33e341 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Thu, 5 Dec 2024 15:55:26 -0800 Subject: [PATCH 17/34] Fix modsub --- zirgen/circuit/bigint/field.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index b1d84220..57afc7d4 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -36,7 +36,10 @@ Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { auto diff = builder.create(loc, lhs, rhs); - auto result = builder.create(loc, diff, prime); + // True statements can fail to prove if a ReduceOp is given negative inputs; thus, add `prime` + // to ensure all normalized inputs can produce an answer + auto diff_aug = builder.create(loc, diff, prime); + auto result = builder.create(loc, diff_aug, prime); return result; } From ef3e665836b3dd0a693cf1065df6accc23859d85 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 13 Dec 2024 13:38:23 -0800 Subject: [PATCH 18/34] Drop extension field code (for now) --- zirgen/bootstrap/src/main.rs | 4 - zirgen/circuit/bigint/BUILD.bazel | 24 ------ zirgen/circuit/bigint/bigint2c.cpp | 17 +--- zirgen/circuit/bigint/field.cpp | 132 +---------------------------- zirgen/circuit/bigint/field.h | 16 +--- 5 files changed, 3 insertions(+), 190 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index d19f082f..ec10461e 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -610,10 +610,6 @@ impl Bootstrap { let field_path = risc0_root.join("bigint2/src/field"); let rsa_path = risc0_root.join("bigint2/src/rsa"); - // TODO: Bitwidths on field ops - self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); - self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); - self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); self.copy_file(&src_path, &field_path, "modadd_256.blob"); self.copy_file(&src_path, &field_path, "modinv_256.blob"); self.copy_file(&src_path, &field_path, "modmul_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 2e2be417..718eea7e 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -58,9 +58,6 @@ BLOBS = [ "modpow65537_4096", "ec_double_256", "ec_add_256", - "extfieldadd_256", - "extfieldmul_256", - "extfieldsub_256", "modadd_256", "modinv_256", "modmul_256", @@ -88,27 +85,6 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add --bitwidth 256 > $(OUTS)" ) -genrule( - name = "extfieldadd_256", - outs = ["extfieldadd_256.blob"], - exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd --bitwidth 256 > $(OUTS)" -) - -genrule( - name = "extfieldmul_256", - outs = ["extfieldmul_256.blob"], - exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul --bitwidth 256 > $(OUTS)" -) - -genrule( - name = "extfieldsub_256", - outs = ["extfieldsub_256.blob"], - exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub --bitwidth 256 > $(OUTS)" -) - genrule( name = "modadd_256", outs = ["modadd_256.blob"], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 2ea2389a..13970333 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -44,9 +44,6 @@ enum class Program { ModPow65537, EC_Double, EC_Add, - ExtFieldAdd, - ExtFieldMul, - ExtFieldSub, ModAdd, ModInv, ModMul, @@ -60,13 +57,10 @@ static cl::opt cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), - clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), - clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), - clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModInv, "modinv", "ModInv"), clEnumValN(Program::ModMul, "modmul", "ModMul"), - clEnumValN(Program::ModSub, "modsub", "ModSub")), // TODO: Don't hardcode bitwidth + clEnumValN(Program::ModSub, "modsub", "ModSub")), cl::Required); static cl::opt bitwidth("bitwidth", @@ -448,15 +442,6 @@ int main(int argc, char* argv[]) { case Program::EC_Add: zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; - case Program::ExtFieldAdd: // TODO: Naming for degree 2 - zirgen::BigInt::field::genExtFieldAdd(builder, loc, bitwidth, 2); - break; - case Program::ExtFieldMul: - zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); - break; - case Program::ExtFieldSub: - zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2); - break; case Program::ModAdd: zirgen::BigInt::field::genModAdd(builder, loc, bitwidth); break; diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 57afc7d4..640e2a73 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -43,77 +43,6 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, return result; } -// Extension field operations - -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { - auto deg = lhs.size(); - assert(rhs.size() == deg); - llvm::SmallVector result(deg); - - for (size_t i = 0; i < deg; i++) { - auto sum = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, sum, prime); - } - return result; -} - -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly) { - // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... - // TODO: We could have a simplified version for nth roots x^n - a - // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 - auto deg = lhs.size(); - // Note: The field is not an extension field if deg <= 1 - assert(deg > 1); - assert(rhs.size() == deg); - assert(monic_irred_poly.size() == deg); - llvm::SmallVector result(2 * deg - 1); - llvm::SmallVector first_write(2 * deg - 1, true); - - // Compute product of polynomials - for (size_t i = 0; i < deg; i++) { - for (size_t j = 0; j < deg; j++) { - size_t idx = i + j; - auto prod = builder.create(loc, lhs[i], rhs[j]); - auto reduced_prod = builder.create(loc, prod, prime); - if (first_write[idx]) { - result[idx] = reduced_prod; - first_write[idx] = false; - } else { - result[idx] = builder.create(loc, result[idx], reduced_prod); - result[idx] = builder.create(loc, result[idx], prime); - } - } - auto sum = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, sum, prime); - } - // Reduce using the monic irred polynomial of the extension field - for (size_t i = 2 * deg - 2; i >= deg; i--) { - for (size_t j = 0; j < deg; j++) { - auto prod = builder.create(loc, result[i], monic_irred_poly[j]); - result[i - deg + j] = builder.create(loc, result[i - deg + j], prod); - result[i - deg + j] = builder.create(loc, result[i - deg + j], prime); - } - // No need to zero out result[i], it will just get dropped - } - // Result's degree is just `deg`, drop the coefficients beyond that - result.truncate(deg); - - return result; -} - -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { - auto deg = lhs.size(); - assert(rhs.size() == deg); - llvm::SmallVector result(deg); - - for (size_t i = 0; i < deg; i++) { - // auto diff = builder.create(loc, lhs[i], rhs[i]); - auto diff = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, diff, prime); - } - return result; -} - // Full programs, including I/O // TODO: Move & cleanup finite field changes @@ -122,11 +51,7 @@ llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, // // These functions accelerate finite field arithmetic // - The `Mod` versions are for prime order fields -// - The `FieldExt` versions are for simple extensions -// - Every finite extension of a finite field is simple, so in a sense this covers every finite -// field, but to use these functions you must represent the extension as the adjunction of a -// primitive element to a prime order field, which is not always convenient (i.e. when you have -// a tower of extensions) +// - Versions for finite extensions of prime fields are planned as future work // // We do not use integer quotients in these functions, so minBits does not give us performance gains // and we therefore do not require the prime to be full bitwidth, enabling simpler generalization @@ -163,59 +88,4 @@ void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { builder.create(loc, result, 14, 0); } -// Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? -void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); - } -} - -void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - llvm::SmallVector monic_irred_poly(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 14, 0); - auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 15, i * chunkwidth); - } -} - -void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks - assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks - size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - for (size_t i = 0; i < degree; i++) { - lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); - rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); - } - auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); - } -} - } // namespace zirgen::BigInt::field diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 1eb76805..2597a3a4 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -24,11 +24,7 @@ namespace zirgen::BigInt::field { // // These functions accelerate finite field arithmetic // - The `Mod` versions are for prime order fields -// - The `FieldExt` versions are for simple extensions -// - Every finite extension of a finite field is simple, so in a sense this covers every finite -// field, but to use these functions you must represent the extension as the adjunction of a -// primitive element to a prime order field, which is not always convenient (i.e. when you have -// a tower of extensions) +// - Versions for finite extensions of prime fields are planned as future work // // We do not use integer quotients in these functions, so minBits does not give us performance gains // and we therefore do not require the prime to be full bitwidth, enabling simpler generalization @@ -39,9 +35,6 @@ void genModAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModInv(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); -void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); -void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); -void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); // Prime field arithmetic (aka modular arithmetic) Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); @@ -49,11 +42,4 @@ Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); -// Extension field arithmetic -// Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); - } // namespace zirgen::BigInt::field From 57f2f90cb31b07d5236f77f2624ccb48bcc02d0c Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 13 Dec 2024 14:18:13 -0800 Subject: [PATCH 19/34] Clear completed TODOs --- zirgen/circuit/bigint/BUILD.bazel | 1 - zirgen/circuit/bigint/field.cpp | 2 -- 2 files changed, 3 deletions(-) diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 718eea7e..e21c5be4 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -53,7 +53,6 @@ cc_binary( ], ) -# TODO: Bitwidths everywhere BLOBS = [ "modpow65537_4096", "ec_double_256", diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 640e2a73..231830de 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -45,8 +45,6 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Full programs, including I/O -// TODO: Move & cleanup finite field changes - // Finite Field arithmetic // // These functions accelerate finite field arithmetic From dcd455eba8ea5fb0efbf1a5a74af0f8a9906ac7f Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 13 Dec 2024 14:33:47 -0800 Subject: [PATCH 20/34] Format --- zirgen/circuit/bigint/field.cpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 231830de..00ff1e0f 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -19,28 +19,28 @@ namespace zirgen::BigInt::field { // Prime field operations Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { - auto sum = builder.create(loc, lhs, rhs); - auto result = builder.create(loc, sum, prime); - return result; + auto sum = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, sum, prime); + return result; } Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime) { - return builder.create(loc, inp, prime); + return builder.create(loc, inp, prime); } Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { - auto prod = builder.create(loc, lhs, rhs); - auto result = builder.create(loc, prod, prime); - return result; + auto prod = builder.create(loc, lhs, rhs); + auto result = builder.create(loc, prod, prime); + return result; } Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime) { - auto diff = builder.create(loc, lhs, rhs); - // True statements can fail to prove if a ReduceOp is given negative inputs; thus, add `prime` - // to ensure all normalized inputs can produce an answer - auto diff_aug = builder.create(loc, diff, prime); - auto result = builder.create(loc, diff_aug, prime); - return result; + auto diff = builder.create(loc, lhs, rhs); + // True statements can fail to prove if a ReduceOp is given negative inputs; thus, add `prime` + // to ensure all normalized inputs can produce an answer + auto diff_aug = builder.create(loc, diff, prime); + auto result = builder.create(loc, diff_aug, prime); + return result; } // Full programs, including I/O From 01ed946df29d756647d56ea0ea3818801f15373a Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Mon, 16 Dec 2024 10:01:45 -0800 Subject: [PATCH 21/34] Revert "Drop extension field code (for now)" This reverts commit ef3e665836b3dd0a693cf1065df6accc23859d85. --- zirgen/bootstrap/src/main.rs | 4 + zirgen/circuit/bigint/BUILD.bazel | 24 ++++++ zirgen/circuit/bigint/bigint2c.cpp | 17 +++- zirgen/circuit/bigint/field.cpp | 132 ++++++++++++++++++++++++++++- zirgen/circuit/bigint/field.h | 16 +++- 5 files changed, 190 insertions(+), 3 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index ec10461e..d19f082f 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -610,6 +610,10 @@ impl Bootstrap { let field_path = risc0_root.join("bigint2/src/field"); let rsa_path = risc0_root.join("bigint2/src/rsa"); + // TODO: Bitwidths on field ops + self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); + self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); + self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); self.copy_file(&src_path, &field_path, "modadd_256.blob"); self.copy_file(&src_path, &field_path, "modinv_256.blob"); self.copy_file(&src_path, &field_path, "modmul_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index e21c5be4..84bafe24 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -57,6 +57,9 @@ BLOBS = [ "modpow65537_4096", "ec_double_256", "ec_add_256", + "extfieldadd_256", + "extfieldmul_256", + "extfieldsub_256", "modadd_256", "modinv_256", "modmul_256", @@ -84,6 +87,27 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add --bitwidth 256 > $(OUTS)" ) +genrule( + name = "extfieldadd_256", + outs = ["extfieldadd_256.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd --bitwidth 256 > $(OUTS)" +) + +genrule( + name = "extfieldmul_256", + outs = ["extfieldmul_256.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul --bitwidth 256 > $(OUTS)" +) + +genrule( + name = "extfieldsub_256", + outs = ["extfieldsub_256.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub --bitwidth 256 > $(OUTS)" +) + genrule( name = "modadd_256", outs = ["modadd_256.blob"], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 13970333..2ea2389a 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -44,6 +44,9 @@ enum class Program { ModPow65537, EC_Double, EC_Add, + ExtFieldAdd, + ExtFieldMul, + ExtFieldSub, ModAdd, ModInv, ModMul, @@ -57,10 +60,13 @@ static cl::opt cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), + clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), + clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), + clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModInv, "modinv", "ModInv"), clEnumValN(Program::ModMul, "modmul", "ModMul"), - clEnumValN(Program::ModSub, "modsub", "ModSub")), + clEnumValN(Program::ModSub, "modsub", "ModSub")), // TODO: Don't hardcode bitwidth cl::Required); static cl::opt bitwidth("bitwidth", @@ -442,6 +448,15 @@ int main(int argc, char* argv[]) { case Program::EC_Add: zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; + case Program::ExtFieldAdd: // TODO: Naming for degree 2 + zirgen::BigInt::field::genExtFieldAdd(builder, loc, bitwidth, 2); + break; + case Program::ExtFieldMul: + zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); + break; + case Program::ExtFieldSub: + zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2); + break; case Program::ModAdd: zirgen::BigInt::field::genModAdd(builder, loc, bitwidth); break; diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 00ff1e0f..1b5d6b3d 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -43,13 +43,88 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, return result; } +// Extension field operations + +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { + auto deg = lhs.size(); + assert(rhs.size() == deg); + llvm::SmallVector result(deg); + + for (size_t i = 0; i < deg; i++) { + auto sum = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, sum, prime); + } + return result; +} + +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly) { + // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... + // TODO: We could have a simplified version for nth roots x^n - a + // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 + auto deg = lhs.size(); + // Note: The field is not an extension field if deg <= 1 + assert(deg > 1); + assert(rhs.size() == deg); + assert(monic_irred_poly.size() == deg); + llvm::SmallVector result(2 * deg - 1); + llvm::SmallVector first_write(2 * deg - 1, true); + + // Compute product of polynomials + for (size_t i = 0; i < deg; i++) { + for (size_t j = 0; j < deg; j++) { + size_t idx = i + j; + auto prod = builder.create(loc, lhs[i], rhs[j]); + auto reduced_prod = builder.create(loc, prod, prime); + if (first_write[idx]) { + result[idx] = reduced_prod; + first_write[idx] = false; + } else { + result[idx] = builder.create(loc, result[idx], reduced_prod); + result[idx] = builder.create(loc, result[idx], prime); + } + } + auto sum = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, sum, prime); + } + // Reduce using the monic irred polynomial of the extension field + for (size_t i = 2 * deg - 2; i >= deg; i--) { + for (size_t j = 0; j < deg; j++) { + auto prod = builder.create(loc, result[i], monic_irred_poly[j]); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prod); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prime); + } + // No need to zero out result[i], it will just get dropped + } + // Result's degree is just `deg`, drop the coefficients beyond that + result.truncate(deg); + + return result; +} + +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { + auto deg = lhs.size(); + assert(rhs.size() == deg); + llvm::SmallVector result(deg); + + for (size_t i = 0; i < deg; i++) { + // auto diff = builder.create(loc, lhs[i], rhs[i]); + auto diff = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, diff, prime); + } + return result; +} + // Full programs, including I/O // Finite Field arithmetic // // These functions accelerate finite field arithmetic // - The `Mod` versions are for prime order fields -// - Versions for finite extensions of prime fields are planned as future work +// - The `FieldExt` versions are for simple extensions +// - Every finite extension of a finite field is simple, so in a sense this covers every finite +// field, but to use these functions you must represent the extension as the adjunction of a +// primitive element to a prime order field, which is not always convenient (i.e. when you have +// a tower of extensions) // // We do not use integer quotients in these functions, so minBits does not give us performance gains // and we therefore do not require the prime to be full bitwidth, enabling simpler generalization @@ -86,4 +161,59 @@ void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { builder.create(loc, result, 14, 0); } +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? +void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} + +void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + llvm::SmallVector monic_irred_poly(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 14, 0); + auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 15, i * chunkwidth); + } +} + +void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} + } // namespace zirgen::BigInt::field diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 2597a3a4..1eb76805 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -24,7 +24,11 @@ namespace zirgen::BigInt::field { // // These functions accelerate finite field arithmetic // - The `Mod` versions are for prime order fields -// - Versions for finite extensions of prime fields are planned as future work +// - The `FieldExt` versions are for simple extensions +// - Every finite extension of a finite field is simple, so in a sense this covers every finite +// field, but to use these functions you must represent the extension as the adjunction of a +// primitive element to a prime order field, which is not always convenient (i.e. when you have +// a tower of extensions) // // We do not use integer quotients in these functions, so minBits does not give us performance gains // and we therefore do not require the prime to be full bitwidth, enabling simpler generalization @@ -35,6 +39,9 @@ void genModAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModInv(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); +void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); +void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); +void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); // Prime field arithmetic (aka modular arithmetic) Value modAdd(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); @@ -42,4 +49,11 @@ Value modInv(mlir::OpBuilder builder, mlir::Location loc, Value inp, Value prime Value modMul(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, Value prime); +// Extension field arithmetic +// Extension fields we use are most commonly degree 2 +// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); + } // namespace zirgen::BigInt::field From 3bed4078359ed783d72a4c95cac802ce0c214b91 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 17 Dec 2024 18:06:00 +0000 Subject: [PATCH 22/34] Drop extraneous op from ext field mul --- zirgen/circuit/bigint/field.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 1b5d6b3d..91e8b10b 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -83,8 +83,6 @@ llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, result[idx] = builder.create(loc, result[idx], prime); } } - auto sum = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, sum, prime); } // Reduce using the monic irred polynomial of the extension field for (size_t i = 2 * deg - 2; i >= deg; i--) { From ba0aeae8939ffa29dd4d76d21e9bc9a50531d128 Mon Sep 17 00:00:00 2001 From: iddo Date: Tue, 17 Dec 2024 16:01:52 -0800 Subject: [PATCH 23/34] extSub fix, same as modSub --- zirgen/circuit/bigint/field.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 91e8b10b..85466d6e 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -107,7 +107,9 @@ llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, for (size_t i = 0; i < deg; i++) { // auto diff = builder.create(loc, lhs[i], rhs[i]); auto diff = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, diff, prime); + //Add `prime` due to the same reason as in modSub + auto diff_aug = builder.create(loc, diff, prime); + result[i] = builder.create(loc, diff_aug, prime); } return result; } From 0646a03a67d0ddcf0369e23b76c4f12102298548 Mon Sep 17 00:00:00 2001 From: iddo Date: Wed, 18 Dec 2024 13:57:31 -0800 Subject: [PATCH 24/34] extfield mult params ordering fix --- zirgen/circuit/bigint/field.cpp | 4 ++-- zirgen/circuit/bigint/field.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 85466d6e..cb6b82cb 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -57,7 +57,7 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, return result; } -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly) { +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... // TODO: We could have a simplified version for nth roots x^n - a // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 @@ -193,7 +193,7 @@ void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth monic_irred_poly[i] = builder.create(loc, bitwidth, 13, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 14, 0); - auto result = BigInt::field::extMul(builder, loc, lhs, rhs, prime, monic_irred_poly); + auto result = BigInt::field::extMul(builder, loc, lhs, rhs, monic_irred_poly, prime); for (size_t i = 0; i < degree; i++) { builder.create(loc, result[i], 15, i * chunkwidth); } diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 1eb76805..393b68dc 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -53,7 +53,7 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension fields we use are most commonly degree 2 // TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, llvm::SmallVector monic_irred_poly); +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); } // namespace zirgen::BigInt::field From 298b894aa9b352cb3bbc2b65dcc2d75ac20b180f Mon Sep 17 00:00:00 2001 From: iddo Date: Tue, 7 Jan 2025 23:10:08 -0800 Subject: [PATCH 25/34] deg4 extmul --- zirgen/bootstrap/src/main.rs | 1 + zirgen/circuit/bigint/BUILD.bazel | 8 ++++++++ zirgen/circuit/bigint/bigint2c.cpp | 6 ++++++ 3 files changed, 15 insertions(+) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index d19f082f..594a50a2 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -613,6 +613,7 @@ impl Bootstrap { // TODO: Bitwidths on field ops self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_deg4_mul_256.blob"); self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); self.copy_file(&src_path, &field_path, "modadd_256.blob"); self.copy_file(&src_path, &field_path, "modinv_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 84bafe24..b868315d 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -59,6 +59,7 @@ BLOBS = [ "ec_add_256", "extfieldadd_256", "extfieldmul_256", + "extfield_deg4_mul_256", "extfieldsub_256", "modadd_256", "modinv_256", @@ -94,6 +95,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd --bitwidth 256 > $(OUTS)" ) +genrule( + name = "extfield_deg4_mul_256", + outs = ["extfield_deg4_mul_256.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg4_mul --bitwidth 256 > $(OUTS)" +) + genrule( name = "extfieldmul_256", outs = ["extfieldmul_256.blob"], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 2ea2389a..f551dc5a 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -46,6 +46,7 @@ enum class Program { EC_Add, ExtFieldAdd, ExtFieldMul, + ExtField_Deg4_Mul, ExtFieldSub, ModAdd, ModInv, @@ -62,6 +63,8 @@ static cl::opt clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), + clEnumValN(Program::ExtField_Deg4_Mul, + "extfield_deg4_mul", "ExtField_Deg4_Mul"), clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModInv, "modinv", "ModInv"), @@ -454,6 +457,9 @@ int main(int argc, char* argv[]) { case Program::ExtFieldMul: zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); break; + case Program::ExtField_Deg4_Mul: + zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 4); + break; case Program::ExtFieldSub: zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2); break; From 047656681d2975a1e7b2b311ab1bae83df74a4ab Mon Sep 17 00:00:00 2001 From: iddo Date: Thu, 9 Jan 2025 11:08:29 -0800 Subject: [PATCH 26/34] extmulxxone for specific xx+1 irreducible poly --- zirgen/bootstrap/src/main.rs | 1 + zirgen/circuit/bigint/BUILD.bazel | 8 +++++++ zirgen/circuit/bigint/bigint2c.cpp | 6 +++++ zirgen/circuit/bigint/field.cpp | 38 ++++++++++++++++++++++++++++++ zirgen/circuit/bigint/field.h | 2 ++ 5 files changed, 55 insertions(+) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 594a50a2..6aef47a3 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -613,6 +613,7 @@ impl Bootstrap { // TODO: Bitwidths on field ops self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_xxone_mul_256.blob"); self.copy_file(&src_path, &field_path, "extfield_deg4_mul_256.blob"); self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); self.copy_file(&src_path, &field_path, "modadd_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index b868315d..83eace54 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -59,6 +59,7 @@ BLOBS = [ "ec_add_256", "extfieldadd_256", "extfieldmul_256", + "extfield_xxone_mul_256", "extfield_deg4_mul_256", "extfieldsub_256", "modadd_256", @@ -102,6 +103,13 @@ genrule( cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg4_mul --bitwidth 256 > $(OUTS)" ) +genrule( + name = "extfield_xxone_mul_256", + outs = ["extfield_xxone_mul_256.blob"], + exec_tools = [":bigint2c"], + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_xxone_mul --bitwidth 256 > $(OUTS)" +) + genrule( name = "extfieldmul_256", outs = ["extfieldmul_256.blob"], diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index f551dc5a..34e33a11 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -47,6 +47,7 @@ enum class Program { ExtFieldAdd, ExtFieldMul, ExtField_Deg4_Mul, + ExtField_XXOne_Mul, ExtFieldSub, ModAdd, ModInv, @@ -63,6 +64,8 @@ static cl::opt clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), + clEnumValN(Program::ExtField_XXOne_Mul, + "extfield_xxone_mul", "ExtField_XXOne_Mul"), clEnumValN(Program::ExtField_Deg4_Mul, "extfield_deg4_mul", "ExtField_Deg4_Mul"), clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), @@ -457,6 +460,9 @@ int main(int argc, char* argv[]) { case Program::ExtFieldMul: zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); break; + case Program::ExtField_XXOne_Mul: + zirgen::BigInt::field::genExtFieldXXOneMul(builder, loc, bitwidth); + break; case Program::ExtField_Deg4_Mul: zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 4); break; diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index cb6b82cb..ab09324a 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -57,6 +57,27 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, return result; } +// Deg2 extfield mul with irreducible polynomial x^2+1 +// (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac +llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { + assert(lhs.size() == 2); + assert(rhs.size() == 2); + llvm::SmallVector result(2); + + auto ad = builder.create(loc, lhs[1], rhs[0]); + auto bc = builder.create(loc, lhs[0], rhs[1]); + result[1] = builder.create(loc, ad, bc); + result[1] = builder.create(loc, result[1], prime); + + auto bd = builder.create(loc, lhs[0], rhs[0]); + auto ac = builder.create(loc, lhs[1], rhs[1]); + result[0] = builder.create(loc, bd, ac); + result[0] = builder.create(loc, result[0], prime); + result[0] = builder.create(loc, result[0], prime); + + return result; +} + llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... // TODO: We could have a simplified version for nth roots x^n - a @@ -199,6 +220,23 @@ void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth } } +void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { + // TODO: will need to handle bitwidth slightly smaller than data chunks + assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks + size_t chunkwidth = bitwidth / 128; + llvm::SmallVector lhs(2); + llvm::SmallVector rhs(2); + for (size_t i = 0; i < 2; i++) { + lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); + rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); + } + auto prime = builder.create(loc, bitwidth, 13, 0); + auto result = BigInt::field::extMulXXONE(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < 2; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); + } +} + void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 393b68dc..452a58ef 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -41,6 +41,7 @@ void genModMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); +void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth); void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree); // Prime field arithmetic (aka modular arithmetic) @@ -54,6 +55,7 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); +llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); } // namespace zirgen::BigInt::field From bbb7ad8f76aca5560458bc422085e0838e1676a6 Mon Sep 17 00:00:00 2001 From: iddo Date: Thu, 9 Jan 2025 13:25:55 -0800 Subject: [PATCH 27/34] fix primesqr --- zirgen/circuit/bigint/field.cpp | 9 +++++---- zirgen/circuit/bigint/field.h | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index ab09324a..1e9d7ddf 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -59,7 +59,7 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, // Deg2 extfield mul with irreducible polynomial x^2+1 // (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac -llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { +llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { assert(lhs.size() == 2); assert(rhs.size() == 2); llvm::SmallVector result(2); @@ -72,7 +72,7 @@ llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location auto bd = builder.create(loc, lhs[0], rhs[0]); auto ac = builder.create(loc, lhs[1], rhs[1]); result[0] = builder.create(loc, bd, ac); - result[0] = builder.create(loc, result[0], prime); + result[0] = builder.create(loc, result[0], primesqr); result[0] = builder.create(loc, result[0], prime); return result; @@ -231,9 +231,10 @@ void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bit rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extMulXXONE(builder, loc, lhs, rhs, prime); + auto primesqr = builder.create(loc, bitwidth, 14, 0); + auto result = BigInt::field::extMulXXONE(builder, loc, lhs, rhs, prime, primesqr); for (size_t i = 0; i < 2; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); + builder.create(loc, result[i], 15, i * chunkwidth); } } diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 452a58ef..9c2d3750 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -55,7 +55,7 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); -llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); } // namespace zirgen::BigInt::field From 9ac30a12f1cde6388717a8765c982c8712cdcacf Mon Sep 17 00:00:00 2001 From: iddo Date: Thu, 9 Jan 2025 13:54:40 -0800 Subject: [PATCH 28/34] fix: double bitwidth for primesqr --- zirgen/circuit/bigint/field.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 1e9d7ddf..3f78cc81 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -231,7 +231,7 @@ void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bit rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 13, 0); - auto primesqr = builder.create(loc, bitwidth, 14, 0); + auto primesqr = builder.create(loc, 2*bitwidth, 14, 0); auto result = BigInt::field::extMulXXONE(builder, loc, lhs, rhs, prime, primesqr); for (size_t i = 0; i < 2; i++) { builder.create(loc, result[i], 15, i * chunkwidth); From 516754bf2ddeda0198eda7c055873e34f360385f Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Mon, 13 Jan 2025 23:42:21 +0000 Subject: [PATCH 29/34] Clean up extension field naming --- zirgen/bootstrap/src/main.rs | 8 ++++---- zirgen/circuit/bigint/BUILD.bazel | 32 +++++++++++++++--------------- zirgen/circuit/bigint/bigint2c.cpp | 28 +++++++++++++------------- zirgen/circuit/bigint/field.cpp | 32 +++++++++++++++--------------- zirgen/circuit/bigint/field.h | 4 +++- 5 files changed, 53 insertions(+), 51 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 6aef47a3..99285db5 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -611,11 +611,11 @@ impl Bootstrap { let rsa_path = risc0_root.join("bigint2/src/rsa"); // TODO: Bitwidths on field ops - self.copy_file(&src_path, &field_path, "extfieldadd_256.blob"); - self.copy_file(&src_path, &field_path, "extfieldmul_256.blob"); - self.copy_file(&src_path, &field_path, "extfield_xxone_mul_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_deg2_add_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_deg2_mul_256.blob"); self.copy_file(&src_path, &field_path, "extfield_deg4_mul_256.blob"); - self.copy_file(&src_path, &field_path, "extfieldsub_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_deg2_sub_256.blob"); + self.copy_file(&src_path, &field_path, "extfield_xxone_mul_256.blob"); self.copy_file(&src_path, &field_path, "modadd_256.blob"); self.copy_file(&src_path, &field_path, "modinv_256.blob"); self.copy_file(&src_path, &field_path, "modmul_256.blob"); diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index 83eace54..c0ff6551 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -57,11 +57,11 @@ BLOBS = [ "modpow65537_4096", "ec_double_256", "ec_add_256", - "extfieldadd_256", - "extfieldmul_256", - "extfield_xxone_mul_256", + "extfield_deg2_add_256", + "extfield_deg2_mul_256", "extfield_deg4_mul_256", - "extfieldsub_256", + "extfield_deg2_sub_256", + "extfield_xxone_mul_256", "modadd_256", "modinv_256", "modmul_256", @@ -90,10 +90,10 @@ genrule( ) genrule( - name = "extfieldadd_256", - outs = ["extfieldadd_256.blob"], + name = "extfield_deg2_add_256", + outs = ["extfield_deg2_add_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldadd --bitwidth 256 > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_add --bitwidth 256 > $(OUTS)" ) genrule( @@ -104,24 +104,24 @@ genrule( ) genrule( - name = "extfield_xxone_mul_256", - outs = ["extfield_xxone_mul_256.blob"], + name = "extfield_deg2_mul_256", + outs = ["extfield_deg2_mul_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_xxone_mul --bitwidth 256 > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_mul --bitwidth 256 > $(OUTS)" ) genrule( - name = "extfieldmul_256", - outs = ["extfieldmul_256.blob"], + name = "extfield_deg2_sub_256", + outs = ["extfield_deg2_sub_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldmul --bitwidth 256 > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_sub --bitwidth 256 > $(OUTS)" ) genrule( - name = "extfieldsub_256", - outs = ["extfieldsub_256.blob"], + name = "extfield_xxone_mul_256", + outs = ["extfield_xxone_mul_256.blob"], exec_tools = [":bigint2c"], - cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfieldsub --bitwidth 256 > $(OUTS)" + cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_xxone_mul --bitwidth 256 > $(OUTS)" ) genrule( diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 34e33a11..92e8b691 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -44,11 +44,11 @@ enum class Program { ModPow65537, EC_Double, EC_Add, - ExtFieldAdd, - ExtFieldMul, + ExtField_Deg2_Add, + ExtField_Deg2_Mul, ExtField_Deg4_Mul, + ExtField_Deg2_Sub, ExtField_XXOne_Mul, - ExtFieldSub, ModAdd, ModInv, ModMul, @@ -62,13 +62,13 @@ static cl::opt cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), - clEnumValN(Program::ExtFieldAdd, "extfieldadd", "ExtFieldAdd"), - clEnumValN(Program::ExtFieldMul, "extfieldmul", "ExtFieldMul"), - clEnumValN(Program::ExtField_XXOne_Mul, - "extfield_xxone_mul", "ExtField_XXOne_Mul"), + clEnumValN(Program::ExtField_Deg2_Add, "extfield_deg2_add", "ExtField_Deg2_Add"), + clEnumValN(Program::ExtField_Deg2_Mul, "extfield_deg2_mul", "ExtField_Deg2_Mul"), clEnumValN(Program::ExtField_Deg4_Mul, "extfield_deg4_mul", "ExtField_Deg4_Mul"), - clEnumValN(Program::ExtFieldSub, "extfieldsub", "ExtFieldSub"), + clEnumValN(Program::ExtField_Deg2_Sub, "extfield_deg2_sub", "ExtField_Deg2_Sub"), + clEnumValN(Program::ExtField_XXOne_Mul, + "extfield_xxone_mul", "ExtField_XXOne_Mul"), clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModInv, "modinv", "ModInv"), clEnumValN(Program::ModMul, "modmul", "ModMul"), @@ -454,21 +454,21 @@ int main(int argc, char* argv[]) { case Program::EC_Add: zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth); break; - case Program::ExtFieldAdd: // TODO: Naming for degree 2 + case Program::ExtField_Deg2_Add: zirgen::BigInt::field::genExtFieldAdd(builder, loc, bitwidth, 2); break; - case Program::ExtFieldMul: + case Program::ExtField_Deg2_Mul: zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2); break; - case Program::ExtField_XXOne_Mul: - zirgen::BigInt::field::genExtFieldXXOneMul(builder, loc, bitwidth); - break; case Program::ExtField_Deg4_Mul: zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 4); break; - case Program::ExtFieldSub: + case Program::ExtField_Deg2_Sub: zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2); break; + case Program::ExtField_XXOne_Mul: + zirgen::BigInt::field::genExtFieldXXOneMul(builder, loc, bitwidth); + break; case Program::ModAdd: zirgen::BigInt::field::genModAdd(builder, loc, bitwidth); break; diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 3f78cc81..36be19b7 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -59,7 +59,7 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, // Deg2 extfield mul with irreducible polynomial x^2+1 // (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac -llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { assert(lhs.size() == 2); assert(rhs.size() == 2); llvm::SmallVector result(2); @@ -220,38 +220,38 @@ void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth } } -void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { +void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(2); - llvm::SmallVector rhs(2); - for (size_t i = 0; i < 2; i++) { + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + for (size_t i = 0; i < degree; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 13, 0); - auto primesqr = builder.create(loc, 2*bitwidth, 14, 0); - auto result = BigInt::field::extMulXXONE(builder, loc, lhs, rhs, prime, primesqr); - for (size_t i = 0; i < 2; i++) { - builder.create(loc, result[i], 15, i * chunkwidth); + auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); + for (size_t i = 0; i < degree; i++) { + builder.create(loc, result[i], 14, i * chunkwidth); } } -void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { +void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - for (size_t i = 0; i < degree; i++) { + llvm::SmallVector lhs(2); + llvm::SmallVector rhs(2); + for (size_t i = 0; i < 2; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 13, 0); - auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime); - for (size_t i = 0; i < degree; i++) { - builder.create(loc, result[i], 14, i * chunkwidth); + auto primesqr = builder.create(loc, 2*bitwidth, 14, 0); + auto result = BigInt::field::extXXOneMul(builder, loc, lhs, rhs, prime, primesqr); + for (size_t i = 0; i < 2; i++) { + builder.create(loc, result[i], 15, i * chunkwidth); } } diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 9c2d3750..d4fd13a5 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -29,6 +29,8 @@ namespace zirgen::BigInt::field { // field, but to use these functions you must represent the extension as the adjunction of a // primitive element to a prime order field, which is not always convenient (i.e. when you have // a tower of extensions) +// - The `ExtFieldXXOne` version of multiply is for specifically the field extension with +// irreducible polynomial `x^2 + 1` (i.e., extension by the square root of negative one) // // We do not use integer quotients in these functions, so minBits does not give us performance gains // and we therefore do not require the prime to be full bitwidth, enabling simpler generalization @@ -55,7 +57,7 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); -llvm::SmallVector extMulXXONE(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); } // namespace zirgen::BigInt::field From 199f596ec777d1375b5b157759003d2fe9a61abe Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Mon, 13 Jan 2025 23:56:47 +0000 Subject: [PATCH 30/34] Remove handled TODOs --- zirgen/bootstrap/src/main.rs | 1 - zirgen/circuit/bigint/bigint2c.cpp | 2 +- zirgen/circuit/bigint/field.cpp | 6 ------ zirgen/circuit/bigint/field.h | 1 - 4 files changed, 1 insertion(+), 9 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 99285db5..67831698 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -610,7 +610,6 @@ impl Bootstrap { let field_path = risc0_root.join("bigint2/src/field"); let rsa_path = risc0_root.join("bigint2/src/rsa"); - // TODO: Bitwidths on field ops self.copy_file(&src_path, &field_path, "extfield_deg2_add_256.blob"); self.copy_file(&src_path, &field_path, "extfield_deg2_mul_256.blob"); self.copy_file(&src_path, &field_path, "extfield_deg4_mul_256.blob"); diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 92e8b691..ee37e589 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -72,7 +72,7 @@ static cl::opt clEnumValN(Program::ModAdd, "modadd", "ModAdd"), clEnumValN(Program::ModInv, "modinv", "ModInv"), clEnumValN(Program::ModMul, "modmul", "ModMul"), - clEnumValN(Program::ModSub, "modsub", "ModSub")), // TODO: Don't hardcode bitwidth + clEnumValN(Program::ModSub, "modsub", "ModSub")), cl::Required); static cl::opt bitwidth("bitwidth", diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 36be19b7..1dc5f942 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -80,7 +80,6 @@ llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... - // TODO: We could have a simplified version for nth roots x^n - a // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 auto deg = lhs.size(); // Note: The field is not an extension field if deg <= 1 @@ -183,9 +182,7 @@ void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { } // Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; llvm::SmallVector lhs(degree); @@ -202,7 +199,6 @@ void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth } void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; llvm::SmallVector lhs(degree); @@ -221,7 +217,6 @@ void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth } void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { - // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; llvm::SmallVector lhs(degree); @@ -238,7 +233,6 @@ void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth } void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { - // TODO: will need to handle bitwidth slightly smaller than data chunks assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; llvm::SmallVector lhs(2); diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index d4fd13a5..6857a0f8 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -54,7 +54,6 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension field arithmetic // Extension fields we use are most commonly degree 2 -// TODO: ^ Hence the use of 2 in the SmallVectors ... but is this true? llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); From 2fa549deaf208693daff7c0389853a9b6c47e7d6 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 14 Jan 2025 00:02:17 +0000 Subject: [PATCH 31/34] Update SmallVector sizes --- zirgen/circuit/bigint/field.cpp | 35 ++++++++++++++++----------------- zirgen/circuit/bigint/field.h | 8 ++++---- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 1dc5f942..39566567 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -45,10 +45,10 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension field operations -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { auto deg = lhs.size(); assert(rhs.size() == deg); - llvm::SmallVector result(deg); + llvm::SmallVector result(deg); for (size_t i = 0; i < deg; i++) { auto sum = builder.create(loc, lhs[i], rhs[i]); @@ -59,10 +59,10 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, // Deg2 extfield mul with irreducible polynomial x^2+1 // (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac -llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { assert(lhs.size() == 2); assert(rhs.size() == 2); - llvm::SmallVector result(2); + llvm::SmallVector result(2); auto ad = builder.create(loc, lhs[1], rhs[0]); auto bc = builder.create(loc, lhs[0], rhs[1]); @@ -78,15 +78,14 @@ llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location return result; } -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { - // TODO: Annoying to have a SmallVector output that needs to be deg - 1 bigger than the inputs; I think that means all should be 3... +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 auto deg = lhs.size(); // Note: The field is not an extension field if deg <= 1 assert(deg > 1); assert(rhs.size() == deg); assert(monic_irred_poly.size() == deg); - llvm::SmallVector result(2 * deg - 1); + llvm::SmallVector result(2 * deg - 1); llvm::SmallVector first_write(2 * deg - 1, true); // Compute product of polynomials @@ -119,10 +118,10 @@ llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, return result; } -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { auto deg = lhs.size(); assert(rhs.size() == deg); - llvm::SmallVector result(deg); + llvm::SmallVector result(deg); for (size_t i = 0; i < deg; i++) { // auto diff = builder.create(loc, lhs[i], rhs[i]); @@ -185,8 +184,8 @@ void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); for (size_t i = 0; i < degree; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); @@ -201,9 +200,9 @@ void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); - llvm::SmallVector monic_irred_poly(degree); + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); + llvm::SmallVector monic_irred_poly(degree); for (size_t i = 0; i < degree; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); @@ -219,8 +218,8 @@ void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(degree); - llvm::SmallVector rhs(degree); + llvm::SmallVector lhs(degree); + llvm::SmallVector rhs(degree); for (size_t i = 0; i < degree; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); @@ -235,8 +234,8 @@ void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) { assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks size_t chunkwidth = bitwidth / 128; - llvm::SmallVector lhs(2); - llvm::SmallVector rhs(2); + llvm::SmallVector lhs(2); + llvm::SmallVector rhs(2); for (size_t i = 0; i < 2; i++) { lhs[i] = builder.create(loc, bitwidth, 11, i * chunkwidth); rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index 6857a0f8..ab7a44bb 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -54,9 +54,9 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension field arithmetic // Extension fields we use are most commonly degree 2 -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); -llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); +llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); } // namespace zirgen::BigInt::field From 43ff2b679215daef6b99b3561950aaaf5de439b7 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 14 Jan 2025 05:05:13 +0000 Subject: [PATCH 32/34] Add comment --- zirgen/circuit/bigint/field.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 39566567..6e8f6692 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -59,6 +59,8 @@ llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, // Deg2 extfield mul with irreducible polynomial x^2+1 // (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac +// This is a more optimized algorithm specialized to the x^2+1 polynomial; +// you could also use the degree 2 extMul code for this, but it is generally slower llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { assert(lhs.size() == 2); assert(rhs.size() == 2); From 3e0a53b16447d6a9e68bfed29d52515c56963449 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 14 Jan 2025 05:13:50 +0000 Subject: [PATCH 33/34] Clang format --- zirgen/circuit/bigint/bigint2c.cpp | 34 +++--- zirgen/circuit/bigint/field.cpp | 172 ++++++++++++++++------------- zirgen/circuit/bigint/field.h | 26 ++++- 3 files changed, 133 insertions(+), 99 deletions(-) diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index ee37e589..22c004b0 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -56,24 +56,22 @@ enum class Program { }; } // namespace -static cl::opt - program("program", - cl::desc("The program to compile"), - cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), - clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), - clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), - clEnumValN(Program::ExtField_Deg2_Add, "extfield_deg2_add", "ExtField_Deg2_Add"), - clEnumValN(Program::ExtField_Deg2_Mul, "extfield_deg2_mul", "ExtField_Deg2_Mul"), - clEnumValN(Program::ExtField_Deg4_Mul, - "extfield_deg4_mul", "ExtField_Deg4_Mul"), - clEnumValN(Program::ExtField_Deg2_Sub, "extfield_deg2_sub", "ExtField_Deg2_Sub"), - clEnumValN(Program::ExtField_XXOne_Mul, - "extfield_xxone_mul", "ExtField_XXOne_Mul"), - clEnumValN(Program::ModAdd, "modadd", "ModAdd"), - clEnumValN(Program::ModInv, "modinv", "ModInv"), - clEnumValN(Program::ModMul, "modmul", "ModMul"), - clEnumValN(Program::ModSub, "modsub", "ModSub")), - cl::Required); +static cl::opt program( + "program", + cl::desc("The program to compile"), + cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"), + clEnumValN(Program::EC_Double, "ec_double", "EC_Double"), + clEnumValN(Program::EC_Add, "ec_add", "EC_Add"), + clEnumValN(Program::ExtField_Deg2_Add, "extfield_deg2_add", "ExtField_Deg2_Add"), + clEnumValN(Program::ExtField_Deg2_Mul, "extfield_deg2_mul", "ExtField_Deg2_Mul"), + clEnumValN(Program::ExtField_Deg4_Mul, "extfield_deg4_mul", "ExtField_Deg4_Mul"), + clEnumValN(Program::ExtField_Deg2_Sub, "extfield_deg2_sub", "ExtField_Deg2_Sub"), + clEnumValN(Program::ExtField_XXOne_Mul, "extfield_xxone_mul", "ExtField_XXOne_Mul"), + clEnumValN(Program::ModAdd, "modadd", "ModAdd"), + clEnumValN(Program::ModInv, "modinv", "ModInv"), + clEnumValN(Program::ModMul, "modmul", "ModMul"), + clEnumValN(Program::ModSub, "modsub", "ModSub")), + cl::Required); static cl::opt bitwidth("bitwidth", cl::desc("The bitwidth of program parameters"), diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 6e8f6692..7bb49f1f 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -45,94 +45,112 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension field operations -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { - auto deg = lhs.size(); - assert(rhs.size() == deg); - llvm::SmallVector result(deg); - - for (size_t i = 0; i < deg; i++) { - auto sum = builder.create(loc, lhs[i], rhs[i]); - result[i] = builder.create(loc, sum, prime); - } - return result; +llvm::SmallVector extAdd(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime) { + auto deg = lhs.size(); + assert(rhs.size() == deg); + llvm::SmallVector result(deg); + + for (size_t i = 0; i < deg; i++) { + auto sum = builder.create(loc, lhs[i], rhs[i]); + result[i] = builder.create(loc, sum, prime); + } + return result; } // Deg2 extfield mul with irreducible polynomial x^2+1 // (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac // This is a more optimized algorithm specialized to the x^2+1 polynomial; // you could also use the degree 2 extMul code for this, but it is generally slower -llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr) { - assert(lhs.size() == 2); - assert(rhs.size() == 2); - llvm::SmallVector result(2); - - auto ad = builder.create(loc, lhs[1], rhs[0]); - auto bc = builder.create(loc, lhs[0], rhs[1]); - result[1] = builder.create(loc, ad, bc); - result[1] = builder.create(loc, result[1], prime); - - auto bd = builder.create(loc, lhs[0], rhs[0]); - auto ac = builder.create(loc, lhs[1], rhs[1]); - result[0] = builder.create(loc, bd, ac); - result[0] = builder.create(loc, result[0], primesqr); - result[0] = builder.create(loc, result[0], prime); - - return result; -} - -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime) { - // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 - auto deg = lhs.size(); - // Note: The field is not an extension field if deg <= 1 - assert(deg > 1); - assert(rhs.size() == deg); - assert(monic_irred_poly.size() == deg); - llvm::SmallVector result(2 * deg - 1); - llvm::SmallVector first_write(2 * deg - 1, true); - - // Compute product of polynomials - for (size_t i = 0; i < deg; i++) { - for (size_t j = 0; j < deg; j++) { - size_t idx = i + j; - auto prod = builder.create(loc, lhs[i], rhs[j]); - auto reduced_prod = builder.create(loc, prod, prime); - if (first_write[idx]) { - result[idx] = reduced_prod; - first_write[idx] = false; - } else { - result[idx] = builder.create(loc, result[idx], reduced_prod); - result[idx] = builder.create(loc, result[idx], prime); - } - } +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime, + Value primesqr) { + assert(lhs.size() == 2); + assert(rhs.size() == 2); + llvm::SmallVector result(2); + + auto ad = builder.create(loc, lhs[1], rhs[0]); + auto bc = builder.create(loc, lhs[0], rhs[1]); + result[1] = builder.create(loc, ad, bc); + result[1] = builder.create(loc, result[1], prime); + + auto bd = builder.create(loc, lhs[0], rhs[0]); + auto ac = builder.create(loc, lhs[1], rhs[1]); + result[0] = builder.create(loc, bd, ac); + result[0] = builder.create(loc, result[0], primesqr); + result[0] = builder.create(loc, result[0], prime); + + return result; +} + +llvm::SmallVector extMul(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + llvm::SmallVector monic_irred_poly, + Value prime) { + // Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0 + auto deg = lhs.size(); + // Note: The field is not an extension field if deg <= 1 + assert(deg > 1); + assert(rhs.size() == deg); + assert(monic_irred_poly.size() == deg); + llvm::SmallVector result(2 * deg - 1); + llvm::SmallVector first_write(2 * deg - 1, true); + + // Compute product of polynomials + for (size_t i = 0; i < deg; i++) { + for (size_t j = 0; j < deg; j++) { + size_t idx = i + j; + auto prod = builder.create(loc, lhs[i], rhs[j]); + auto reduced_prod = builder.create(loc, prod, prime); + if (first_write[idx]) { + result[idx] = reduced_prod; + first_write[idx] = false; + } else { + result[idx] = builder.create(loc, result[idx], reduced_prod); + result[idx] = builder.create(loc, result[idx], prime); + } } - // Reduce using the monic irred polynomial of the extension field - for (size_t i = 2 * deg - 2; i >= deg; i--) { - for (size_t j = 0; j < deg; j++) { - auto prod = builder.create(loc, result[i], monic_irred_poly[j]); - result[i - deg + j] = builder.create(loc, result[i - deg + j], prod); - result[i - deg + j] = builder.create(loc, result[i - deg + j], prime); - } - // No need to zero out result[i], it will just get dropped + } + // Reduce using the monic irred polynomial of the extension field + for (size_t i = 2 * deg - 2; i >= deg; i--) { + for (size_t j = 0; j < deg; j++) { + auto prod = builder.create(loc, result[i], monic_irred_poly[j]); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prod); + result[i - deg + j] = builder.create(loc, result[i - deg + j], prime); } - // Result's degree is just `deg`, drop the coefficients beyond that - result.truncate(deg); + // No need to zero out result[i], it will just get dropped + } + // Result's degree is just `deg`, drop the coefficients beyond that + result.truncate(deg); - return result; + return result; } -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime) { - auto deg = lhs.size(); - assert(rhs.size() == deg); - llvm::SmallVector result(deg); +llvm::SmallVector extSub(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime) { + auto deg = lhs.size(); + assert(rhs.size() == deg); + llvm::SmallVector result(deg); - for (size_t i = 0; i < deg; i++) { - // auto diff = builder.create(loc, lhs[i], rhs[i]); - auto diff = builder.create(loc, lhs[i], rhs[i]); - //Add `prime` due to the same reason as in modSub - auto diff_aug = builder.create(loc, diff, prime); - result[i] = builder.create(loc, diff_aug, prime); - } - return result; + for (size_t i = 0; i < deg; i++) { + // auto diff = builder.create(loc, lhs[i], rhs[i]); + auto diff = builder.create(loc, lhs[i], rhs[i]); + // Add `prime` due to the same reason as in modSub + auto diff_aug = builder.create(loc, diff, prime); + result[i] = builder.create(loc, diff_aug, prime); + } + return result; } // Full programs, including I/O @@ -243,7 +261,7 @@ void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bit rhs[i] = builder.create(loc, bitwidth, 12, i * chunkwidth); } auto prime = builder.create(loc, bitwidth, 13, 0); - auto primesqr = builder.create(loc, 2*bitwidth, 14, 0); + auto primesqr = builder.create(loc, 2 * bitwidth, 14, 0); auto result = BigInt::field::extXXOneMul(builder, loc, lhs, rhs, prime, primesqr); for (size_t i = 0; i < 2; i++) { builder.create(loc, result[i], 15, i * chunkwidth); diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index ab7a44bb..fa49c65e 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -54,9 +54,27 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs, // Extension field arithmetic // Extension fields we use are most commonly degree 2 -llvm::SmallVector extAdd(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); -llvm::SmallVector extMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, llvm::SmallVector monic_irred_poly, Value prime); -llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime, Value primesqr); -llvm::SmallVector extSub(mlir::OpBuilder builder, mlir::Location loc, llvm::SmallVector lhs, llvm::SmallVector rhs, Value prime); +llvm::SmallVector extAdd(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime); +llvm::SmallVector extMul(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + llvm::SmallVector monic_irred_poly, + Value prime); +llvm::SmallVector extXXOneMul(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime, + Value primesqr); +llvm::SmallVector extSub(mlir::OpBuilder builder, + mlir::Location loc, + llvm::SmallVector lhs, + llvm::SmallVector rhs, + Value prime); } // namespace zirgen::BigInt::field From 50623725ccd7b90ed14c9fbf566b906696f895f0 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Tue, 14 Jan 2025 17:39:52 +0000 Subject: [PATCH 34/34] Update license dates --- zirgen/bootstrap/src/main.rs | 2 +- zirgen/circuit/bigint/bigint2c.cpp | 2 +- zirgen/circuit/bigint/field.cpp | 2 +- zirgen/circuit/bigint/field.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/zirgen/bootstrap/src/main.rs b/zirgen/bootstrap/src/main.rs index 67831698..f5485cbc 100644 --- a/zirgen/bootstrap/src/main.rs +++ b/zirgen/bootstrap/src/main.rs @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/zirgen/circuit/bigint/bigint2c.cpp b/zirgen/circuit/bigint/bigint2c.cpp index 22c004b0..b859eb15 100644 --- a/zirgen/circuit/bigint/bigint2c.cpp +++ b/zirgen/circuit/bigint/bigint2c.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/zirgen/circuit/bigint/field.cpp b/zirgen/circuit/bigint/field.cpp index 7bb49f1f..72110161 100644 --- a/zirgen/circuit/bigint/field.cpp +++ b/zirgen/circuit/bigint/field.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/zirgen/circuit/bigint/field.h b/zirgen/circuit/bigint/field.h index fa49c65e..267fa189 100644 --- a/zirgen/circuit/bigint/field.h +++ b/zirgen/circuit/bigint/field.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.