From 3ba0936fdfdfed7c5605026b783191c6913d7d28 Mon Sep 17 00:00:00 2001 From: Jeremy Bruestle Date: Thu, 26 Dec 2024 21:40:50 -0800 Subject: [PATCH] Fixed div --- zirgen/circuit/rv32im/v2/dsl/inst_div.zir | 34 +++++++++++-- zirgen/circuit/rv32im/v2/dsl/mult.zir | 2 + zirgen/circuit/rv32im/v2/test/BUILD.bazel | 13 +++++ .../rv32im/v2/test/test_riscv_bins.cpp | 51 +++++++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 zirgen/circuit/rv32im/v2/test/test_riscv_bins.cpp diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir index 5f0cb9a1..f373800d 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir @@ -43,6 +43,17 @@ component DivideReturn(quot: ValU32, rem: ValU32) { extern Divide(numer: ValU32, denom: ValU32, sign_type: Val) : DivideReturn; +/* +component PairU32(a: ValU32, b: ValU32) { + aLow := NondetReg(a.low); + aHigh := NondetReg(a.high); + bLow := NondetReg(b.low); + bHigh := NondetReg(b.high); + public a := ValU32(aLow, aHigh); + public b := ValU32(bLow, bHigh); +} +*/ + component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) { // Guess the answer guess := Divide(numer, denom, signed + 2 * ones_comp); @@ -59,19 +70,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) { settings := MultiplySettings(signed, signed, signed); // Do the accumulate mul := MultiplyAccumulate(quot, denom, rem, settings); - // Check the main result (numer = quot * denom + rem + // Check the main result (numer = quot * denom + rem) AssertEqU32(mul.outLow, numer); // The top bits should all be 0 or all be 1 topBitType := NondetBitReg(1 - Isz(mul.outHigh.low)); AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType)); // Check if denom is zero isZero := IsZero(denom.low + denom.high); + // Get top bit of numerator + topNum := NondetBitReg((numer.high & 0x8000) / 0x8000); + // Verify we got it right + U16Reg((numer.high - 0x8000 * topNum) * 2); + numNeg := topNum * signed; + // Get the absolute value of the denominator + denomNeg := mul.bNeg; + denomAbs := NormalizeU32(DenormedValU32( + denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low, + denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high + )); + // Flip the sign of the remainder if the numerator is negative + remNormal := NormalizeU32(DenormedValU32( + numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low, + numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high + )); + // Decide if we need to swap order of // If non-zero, make sure 0 <= rem < denom if (isZero) { AssertEqU32(rem, numer); } else { - cmp := CmpLessThanUnsigned(rem, denom); - cmp.is_less_than = 1; + lt := CmpLessThanUnsigned(remNormal, denomAbs); + lt.is_less_than = 1; }; DivideReturn(quot, rem) } diff --git a/zirgen/circuit/rv32im/v2/dsl/mult.zir b/zirgen/circuit/rv32im/v2/dsl/mult.zir index eed61631..da6ba6eb 100644 --- a/zirgen/circuit/rv32im/v2/dsl/mult.zir +++ b/zirgen/circuit/rv32im/v2/dsl/mult.zir @@ -151,6 +151,8 @@ component MultiplyAccumulate(a: ValU32, b: ValU32, c: ValU32, settings: Multiply s3Carry := FakeTwitReg((s3Tot - s3Out) / 0x10000); public outLow := ValU32(s0.out, s1.out); public outHigh := ValU32(s2.out, s3Out); + public aNeg := ax.neg; + public bNeg := bx.neg; } component MultiplyTestCase(a: ValU32, b: ValU32, c: ValU32, settings: MultiplySettings, ol: ValU32, oh: ValU32) { diff --git a/zirgen/circuit/rv32im/v2/test/BUILD.bazel b/zirgen/circuit/rv32im/v2/test/BUILD.bazel index fc7faa67..b051bb9a 100644 --- a/zirgen/circuit/rv32im/v2/test/BUILD.bazel +++ b/zirgen/circuit/rv32im/v2/test/BUILD.bazel @@ -37,6 +37,19 @@ cc_test( ], ) +cc_test( + name = "test_riscv_bins", + srcs = [ + "test_riscv_bins.cpp", + ], + data = [ + "//zirgen/circuit/rv32im/shared/test:riscv_test_bins", + ], + deps = [ + "//zirgen/circuit/rv32im/v2/run", + ], +) + risc0_cc_kernel_binary( name = "test_sha_kernel", srcs = [ diff --git a/zirgen/circuit/rv32im/v2/test/test_riscv_bins.cpp b/zirgen/circuit/rv32im/v2/test/test_riscv_bins.cpp new file mode 100644 index 00000000..e441e38b --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_riscv_bins.cpp @@ -0,0 +1,51 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" +#include "zirgen/circuit/rv32im/v2/run/run.h" + +using namespace zirgen::rv32im_v2; + +void runOneTest(const std::string& name) { + std::string kernel = "zirgen/circuit/rv32im/shared/test/" + name; + size_t cycles = 10000; + + TestIoHandler io; + + // Load image + auto image = MemoryImage::fromRawElf(kernel); + // Do executions + auto segments = execute(image, io, cycles, cycles); + // Do 'run' (preflight + expansion) + for (const auto& segment : segments) { + runSegment(segment, cycles); + } +} + +std::vector names = { + "add", "addi", "and", "andi", "auipc", "beq", "bge", "bgeu", "blt", "bltu", + "bne", "div", "divu", "jal", "jalr", "lb", "lbu", "lh", "lhu", "lui", + "lw", "mul", "mulh", "mulhsu", "mulhu", "or", "ori", "rem", "remu", "sb", + "sh", "sll", "slli", "slt", "slti", "sltiu", "sltu", "sra", "srai", "srl", + "srli", "sub", "sw", "test", "xor", "xori"}; + +const std::string kernelName = "zirgen/circuit/rv32im/shared/test/rem"; + +int main() { + for (const std::string& name : names) { + runOneTest(name); + } +}