Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZKVM-910: ZIR-325: Add field extension operations #159

Merged
merged 44 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
de6f134
Draft modmul accelerator
tzerrell Nov 22, 2024
59ac726
Draft field & field extension ops
tzerrell Nov 28, 2024
9accf05
Merge branch 'main' into tzerrell/bigint2-modmul
tzerrell Nov 28, 2024
32ad2a8
Update field op draft with codegen
tzerrell Dec 3, 2024
6a20ac4
Merge branch 'main' into tzerrell/bigint2-modmul
tzerrell Dec 3, 2024
04ed107
WIP Add modadd
tzerrell Dec 3, 2024
ac6c5ed
Build modadd program
tzerrell Dec 3, 2024
6b7dd1f
build the bigint blobs when building //zirgen/circuit
mars-risc0 Dec 3, 2024
b7aa908
Export more bigint programs
tzerrell Dec 3, 2024
c84f3ae
Merge branch 'mars/add-target-to-build-circuit' into tzerrell/bigint2…
tzerrell Dec 3, 2024
6c1e663
Complete new program exports
tzerrell Dec 3, 2024
ba358df
Merge branch 'main' into tzerrell/bigint2-modmul
tzerrell Dec 3, 2024
bbaa189
Move program gen to program files
tzerrell Dec 3, 2024
2c1024e
Make builder, loc order more consistent
tzerrell Dec 5, 2024
25c1c9d
Parameterize bitwidth on circuit generation
tzerrell Dec 5, 2024
6fe1310
Remove TODO comment for cleaned code
tzerrell Dec 5, 2024
443dd19
Naming
tzerrell Dec 5, 2024
cddf63d
Format
tzerrell Dec 5, 2024
cb0990b
Merge branch 'main' into tzerrell/bigint-cleanups
tzerrell Dec 5, 2024
b169c27
Merge branch 'tzerrell/bigint-cleanups' into tzerrell/bigint2-modmul
tzerrell Dec 5, 2024
a9637b6
Parameterize bitwidths on the new ops
tzerrell Dec 5, 2024
df85f0e
Move field gen ops into their file
tzerrell Dec 5, 2024
1a9fe6e
Fix modsub
tzerrell Dec 5, 2024
c62daf3
Merge branch 'main' into tzerrell/bigint2-field-ops
tzerrell Dec 12, 2024
afc8b5a
Merge branch 'main' into tzerrell/bigint2-field-ops
tzerrell Dec 13, 2024
0defd41
Merge branch 'main' into tzerrell/bigint2-field-ops
tzerrell Dec 13, 2024
ef3e665
Drop extension field code (for now)
tzerrell Dec 13, 2024
57f2f90
Clear completed TODOs
tzerrell Dec 13, 2024
dcd455e
Format
tzerrell Dec 13, 2024
01ed946
Revert "Drop extension field code (for now)"
tzerrell Dec 16, 2024
3bed407
Drop extraneous op from ext field mul
tzerrell Dec 17, 2024
ba0aeae
extSub fix, same as modSub
iddo-bentov Dec 18, 2024
0646a03
extfield mult params ordering fix
iddo-bentov Dec 18, 2024
298b894
deg4 extmul
iddo-bentov Jan 8, 2025
0476566
extmulxxone for specific xx+1 irreducible poly
iddo-bentov Jan 9, 2025
bbb7ad8
fix primesqr
iddo-bentov Jan 9, 2025
9ac30a1
fix: double bitwidth for primesqr
iddo-bentov Jan 9, 2025
516754b
Clean up extension field naming
tzerrell Jan 13, 2025
199f596
Remove handled TODOs
tzerrell Jan 13, 2025
2fa549d
Update SmallVector sizes
tzerrell Jan 14, 2025
43ff2b6
Add comment
tzerrell Jan 14, 2025
3e0a53b
Clang format
tzerrell Jan 14, 2025
5062372
Update license dates
tzerrell Jan 14, 2025
c23e9f9
Merge branch 'main' into tzerrell/bigint2-ext-field
tzerrell Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion zirgen/bootstrap/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 RISC Zero, Inc.
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -658,6 +658,11 @@ impl Bootstrap {
let field_path = risc0_root.join("bigint2/src/field");
let rsa_path = risc0_root.join("bigint2/src/rsa");

self.copy_file(&src_path, &field_path, "extfield_deg2_add_256.blob");
self.copy_file(&src_path, &field_path, "extfield_deg2_mul_256.blob");
self.copy_file(&src_path, &field_path, "extfield_deg4_mul_256.blob");
self.copy_file(&src_path, &field_path, "extfield_deg2_sub_256.blob");
self.copy_file(&src_path, &field_path, "extfield_xxone_mul_256.blob");
self.copy_file(&src_path, &field_path, "modadd_256.blob");
self.copy_file(&src_path, &field_path, "modinv_256.blob");
self.copy_file(&src_path, &field_path, "modmul_256.blob");
Expand Down
40 changes: 40 additions & 0 deletions zirgen/circuit/bigint/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ BLOBS = [
"modpow65537_4096",
"ec_double_256",
"ec_add_256",
"extfield_deg2_add_256",
"extfield_deg2_mul_256",
"extfield_deg4_mul_256",
"extfield_deg2_sub_256",
"extfield_xxone_mul_256",
"modadd_256",
"modinv_256",
"modmul_256",
Expand Down Expand Up @@ -84,6 +89,41 @@ genrule(
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=ec_add --bitwidth 256 > $(OUTS)"
)

genrule(
name = "extfield_deg2_add_256",
outs = ["extfield_deg2_add_256.blob"],
exec_tools = [":bigint2c"],
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_add --bitwidth 256 > $(OUTS)"
)

genrule(
name = "extfield_deg4_mul_256",
outs = ["extfield_deg4_mul_256.blob"],
exec_tools = [":bigint2c"],
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg4_mul --bitwidth 256 > $(OUTS)"
)

genrule(
name = "extfield_deg2_mul_256",
outs = ["extfield_deg2_mul_256.blob"],
exec_tools = [":bigint2c"],
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_mul --bitwidth 256 > $(OUTS)"
)

genrule(
name = "extfield_deg2_sub_256",
outs = ["extfield_deg2_sub_256.blob"],
exec_tools = [":bigint2c"],
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_deg2_sub --bitwidth 256 > $(OUTS)"
)

genrule(
name = "extfield_xxone_mul_256",
outs = ["extfield_xxone_mul_256.blob"],
exec_tools = [":bigint2c"],
cmd = "$(location //zirgen/circuit/bigint:bigint2c) --program=extfield_xxone_mul --bitwidth 256 > $(OUTS)"
)

genrule(
name = "modadd_256",
outs = ["modadd_256.blob"],
Expand Down
49 changes: 37 additions & 12 deletions zirgen/circuit/bigint/bigint2c.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 RISC Zero, Inc.
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,24 +44,34 @@ enum class Program {
ModPow65537,
EC_Double,
EC_Add,
ExtField_Deg2_Add,
ExtField_Deg2_Mul,
ExtField_Deg4_Mul,
ExtField_Deg2_Sub,
ExtField_XXOne_Mul,
ModAdd,
ModInv,
ModMul,
ModSub,
};
} // namespace

static cl::opt<enum Program>
program("program",
cl::desc("The program to compile"),
cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"),
clEnumValN(Program::EC_Double, "ec_double", "EC_Double"),
clEnumValN(Program::EC_Add, "ec_add", "EC_Add"),
clEnumValN(Program::ModAdd, "modadd", "ModAdd"),
clEnumValN(Program::ModInv, "modinv", "ModInv"),
clEnumValN(Program::ModMul, "modmul", "ModMul"),
clEnumValN(Program::ModSub, "modsub", "ModSub")),
cl::Required);
static cl::opt<enum Program> program(
"program",
cl::desc("The program to compile"),
cl::values(clEnumValN(Program::ModPow65537, "modpow65537", "ModPow65537"),
clEnumValN(Program::EC_Double, "ec_double", "EC_Double"),
clEnumValN(Program::EC_Add, "ec_add", "EC_Add"),
clEnumValN(Program::ExtField_Deg2_Add, "extfield_deg2_add", "ExtField_Deg2_Add"),
clEnumValN(Program::ExtField_Deg2_Mul, "extfield_deg2_mul", "ExtField_Deg2_Mul"),
clEnumValN(Program::ExtField_Deg4_Mul, "extfield_deg4_mul", "ExtField_Deg4_Mul"),
clEnumValN(Program::ExtField_Deg2_Sub, "extfield_deg2_sub", "ExtField_Deg2_Sub"),
clEnumValN(Program::ExtField_XXOne_Mul, "extfield_xxone_mul", "ExtField_XXOne_Mul"),
clEnumValN(Program::ModAdd, "modadd", "ModAdd"),
clEnumValN(Program::ModInv, "modinv", "ModInv"),
clEnumValN(Program::ModMul, "modmul", "ModMul"),
clEnumValN(Program::ModSub, "modsub", "ModSub")),
cl::Required);

static cl::opt<size_t> bitwidth("bitwidth",
cl::desc("The bitwidth of program parameters"),
Expand Down Expand Up @@ -442,6 +452,21 @@ int main(int argc, char* argv[]) {
case Program::EC_Add:
zirgen::BigInt::EC::genECAdd(builder, loc, bitwidth);
break;
case Program::ExtField_Deg2_Add:
zirgen::BigInt::field::genExtFieldAdd(builder, loc, bitwidth, 2);
break;
case Program::ExtField_Deg2_Mul:
zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 2);
break;
case Program::ExtField_Deg4_Mul:
zirgen::BigInt::field::genExtFieldMul(builder, loc, bitwidth, 4);
break;
case Program::ExtField_Deg2_Sub:
zirgen::BigInt::field::genExtFieldSub(builder, loc, bitwidth, 2);
break;
case Program::ExtField_XXOne_Mul:
zirgen::BigInt::field::genExtFieldXXOneMul(builder, loc, bitwidth);
break;
case Program::ModAdd:
zirgen::BigInt::field::genModAdd(builder, loc, bitwidth);
break;
Expand Down
186 changes: 184 additions & 2 deletions zirgen/circuit/bigint/field.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 RISC Zero, Inc.
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,13 +43,127 @@ Value modSub(mlir::OpBuilder builder, mlir::Location loc, Value lhs, Value rhs,
return result;
}

// Extension field operations

llvm::SmallVector<Value, 3> extAdd(mlir::OpBuilder builder,
mlir::Location loc,
llvm::SmallVector<Value, 3> lhs,
llvm::SmallVector<Value, 3> rhs,
Value prime) {
auto deg = lhs.size();
assert(rhs.size() == deg);
llvm::SmallVector<Value, 3> result(deg);

for (size_t i = 0; i < deg; i++) {
auto sum = builder.create<BigInt::AddOp>(loc, lhs[i], rhs[i]);
result[i] = builder.create<BigInt::ReduceOp>(loc, sum, prime);
}
return result;
}

// Deg2 extfield mul with irreducible polynomial x^2+1
// (ax+b)(cx+d) == acxx-ac(xx+1) + (ad+bc)x + bd == (ad+bc)x + bd-ac
// This is a more optimized algorithm specialized to the x^2+1 polynomial;
// you could also use the degree 2 extMul code for this, but it is generally slower
llvm::SmallVector<Value, 3> extXXOneMul(mlir::OpBuilder builder,
mlir::Location loc,
llvm::SmallVector<Value, 3> lhs,
llvm::SmallVector<Value, 3> rhs,
Value prime,
Value primesqr) {
assert(lhs.size() == 2);
assert(rhs.size() == 2);
llvm::SmallVector<Value, 3> result(2);

auto ad = builder.create<BigInt::MulOp>(loc, lhs[1], rhs[0]);
auto bc = builder.create<BigInt::MulOp>(loc, lhs[0], rhs[1]);
result[1] = builder.create<BigInt::AddOp>(loc, ad, bc);
result[1] = builder.create<BigInt::ReduceOp>(loc, result[1], prime);

auto bd = builder.create<BigInt::MulOp>(loc, lhs[0], rhs[0]);
auto ac = builder.create<BigInt::MulOp>(loc, lhs[1], rhs[1]);
result[0] = builder.create<BigInt::SubOp>(loc, bd, ac);
result[0] = builder.create<BigInt::AddOp>(loc, result[0], primesqr);
result[0] = builder.create<BigInt::ReduceOp>(loc, result[0], prime);

return result;
}

llvm::SmallVector<Value, 3> extMul(mlir::OpBuilder builder,
mlir::Location loc,
llvm::SmallVector<Value, 3> lhs,
llvm::SmallVector<Value, 3> rhs,
llvm::SmallVector<Value, 3> monic_irred_poly,
Value prime) {
// Here `monic_irred_poly` is the coefficients a_i such that x^n - sum_i a_i x^i = 0
auto deg = lhs.size();
// Note: The field is not an extension field if deg <= 1
assert(deg > 1);
assert(rhs.size() == deg);
assert(monic_irred_poly.size() == deg);
llvm::SmallVector<Value, 3> result(2 * deg - 1);
llvm::SmallVector<bool, 2> first_write(2 * deg - 1, true);

// Compute product of polynomials
for (size_t i = 0; i < deg; i++) {
for (size_t j = 0; j < deg; j++) {
size_t idx = i + j;
auto prod = builder.create<BigInt::MulOp>(loc, lhs[i], rhs[j]);
auto reduced_prod = builder.create<BigInt::ReduceOp>(loc, prod, prime);
if (first_write[idx]) {
result[idx] = reduced_prod;
first_write[idx] = false;
} else {
result[idx] = builder.create<BigInt::AddOp>(loc, result[idx], reduced_prod);
result[idx] = builder.create<BigInt::ReduceOp>(loc, result[idx], prime);
}
}
}
// Reduce using the monic irred polynomial of the extension field
for (size_t i = 2 * deg - 2; i >= deg; i--) {
for (size_t j = 0; j < deg; j++) {
auto prod = builder.create<BigInt::MulOp>(loc, result[i], monic_irred_poly[j]);
result[i - deg + j] = builder.create<BigInt::AddOp>(loc, result[i - deg + j], prod);
result[i - deg + j] = builder.create<BigInt::ReduceOp>(loc, result[i - deg + j], prime);
}
// No need to zero out result[i], it will just get dropped
}
// Result's degree is just `deg`, drop the coefficients beyond that
result.truncate(deg);

return result;
}

llvm::SmallVector<Value, 3> extSub(mlir::OpBuilder builder,
mlir::Location loc,
llvm::SmallVector<Value, 3> lhs,
llvm::SmallVector<Value, 3> rhs,
Value prime) {
auto deg = lhs.size();
assert(rhs.size() == deg);
llvm::SmallVector<Value, 3> result(deg);

for (size_t i = 0; i < deg; i++) {
// auto diff = builder.create<BigInt::SubOp>(loc, lhs[i], rhs[i]);
auto diff = builder.create<BigInt::SubOp>(loc, lhs[i], rhs[i]);
// Add `prime` due to the same reason as in modSub
auto diff_aug = builder.create<BigInt::AddOp>(loc, diff, prime);
result[i] = builder.create<BigInt::ReduceOp>(loc, diff_aug, prime);
}
return result;
}

// Full programs, including I/O

// Finite Field arithmetic
//
// These functions accelerate finite field arithmetic
// - The `Mod` versions are for prime order fields
// - Versions for finite extensions of prime fields are planned as future work
// - The `FieldExt` versions are for simple extensions
// - Every finite extension of a finite field is simple, so in a sense this covers every finite
// field, but to use these functions you must represent the extension as the adjunction of a
// primitive element to a prime order field, which is not always convenient (i.e. when you have
// a tower of extensions)
//
// We do not use integer quotients in these functions, so minBits does not give us performance gains
// and we therefore do not require the prime to be full bitwidth, enabling simpler generalization
Expand Down Expand Up @@ -86,4 +200,72 @@ void genModSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) {
builder.create<BigInt::StoreOp>(loc, result, 14, 0);
}

// Extension fields we use are most commonly degree 2
void genExtFieldAdd(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) {
assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks
size_t chunkwidth = bitwidth / 128;
llvm::SmallVector<Value, 3> lhs(degree);
llvm::SmallVector<Value, 3> rhs(degree);
for (size_t i = 0; i < degree; i++) {
lhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 11, i * chunkwidth);
rhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 12, i * chunkwidth);
}
auto prime = builder.create<BigInt::LoadOp>(loc, bitwidth, 13, 0);
auto result = BigInt::field::extAdd(builder, loc, lhs, rhs, prime);
for (size_t i = 0; i < degree; i++) {
builder.create<BigInt::StoreOp>(loc, result[i], 14, i * chunkwidth);
}
}

void genExtFieldMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) {
assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks
size_t chunkwidth = bitwidth / 128;
llvm::SmallVector<Value, 3> lhs(degree);
llvm::SmallVector<Value, 3> rhs(degree);
llvm::SmallVector<Value, 3> monic_irred_poly(degree);
for (size_t i = 0; i < degree; i++) {
lhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 11, i * chunkwidth);
rhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 12, i * chunkwidth);
monic_irred_poly[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 13, i * chunkwidth);
}
auto prime = builder.create<BigInt::LoadOp>(loc, bitwidth, 14, 0);
auto result = BigInt::field::extMul(builder, loc, lhs, rhs, monic_irred_poly, prime);
for (size_t i = 0; i < degree; i++) {
builder.create<BigInt::StoreOp>(loc, result[i], 15, i * chunkwidth);
}
}

void genExtFieldSub(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth, size_t degree) {
assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks
size_t chunkwidth = bitwidth / 128;
llvm::SmallVector<Value, 3> lhs(degree);
llvm::SmallVector<Value, 3> rhs(degree);
for (size_t i = 0; i < degree; i++) {
lhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 11, i * chunkwidth);
rhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 12, i * chunkwidth);
}
auto prime = builder.create<BigInt::LoadOp>(loc, bitwidth, 13, 0);
auto result = BigInt::field::extSub(builder, loc, lhs, rhs, prime);
for (size_t i = 0; i < degree; i++) {
builder.create<BigInt::StoreOp>(loc, result[i], 14, i * chunkwidth);
}
}

void genExtFieldXXOneMul(mlir::OpBuilder builder, mlir::Location loc, size_t bitwidth) {
assert(bitwidth % 128 == 0); // Bitwidth must be an even number of 128-bit chunks
size_t chunkwidth = bitwidth / 128;
llvm::SmallVector<Value, 3> lhs(2);
llvm::SmallVector<Value, 3> rhs(2);
for (size_t i = 0; i < 2; i++) {
lhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 11, i * chunkwidth);
rhs[i] = builder.create<BigInt::LoadOp>(loc, bitwidth, 12, i * chunkwidth);
}
auto prime = builder.create<BigInt::LoadOp>(loc, bitwidth, 13, 0);
auto primesqr = builder.create<BigInt::LoadOp>(loc, 2 * bitwidth, 14, 0);
auto result = BigInt::field::extXXOneMul(builder, loc, lhs, rhs, prime, primesqr);
for (size_t i = 0; i < 2; i++) {
builder.create<BigInt::StoreOp>(loc, result[i], 15, i * chunkwidth);
}
}

} // namespace zirgen::BigInt::field
Loading
Loading