Skip to content

Commit

Permalink
Fixed div
Browse files Browse the repository at this point in the history
  • Loading branch information
jbruestle committed Dec 27, 2024
1 parent 6552ec9 commit 3ba0936
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 3 deletions.
34 changes: 31 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_div.zir
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ component DivideReturn(quot: ValU32, rem: ValU32) {

extern Divide(numer: ValU32, denom: ValU32, sign_type: Val) : DivideReturn;

/*
component PairU32(a: ValU32, b: ValU32) {
aLow := NondetReg(a.low);
aHigh := NondetReg(a.high);
bLow := NondetReg(b.low);
bHigh := NondetReg(b.high);
public a := ValU32(aLow, aHigh);
public b := ValU32(bLow, bHigh);
}
*/

component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
// Guess the answer
guess := Divide(numer, denom, signed + 2 * ones_comp);
Expand All @@ -59,19 +70,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
settings := MultiplySettings(signed, signed, signed);
// Do the accumulate
mul := MultiplyAccumulate(quot, denom, rem, settings);
// Check the main result (numer = quot * denom + rem
// Check the main result (numer = quot * denom + rem)
AssertEqU32(mul.outLow, numer);
// The top bits should all be 0 or all be 1
topBitType := NondetBitReg(1 - Isz(mul.outHigh.low));
AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType));
// Check if denom is zero
isZero := IsZero(denom.low + denom.high);
// Get top bit of numerator
topNum := NondetBitReg((numer.high & 0x8000) / 0x8000);
// Verify we got it right
U16Reg((numer.high - 0x8000 * topNum) * 2);
numNeg := topNum * signed;
// Get the absolute value of the denominator
denomNeg := mul.bNeg;
denomAbs := NormalizeU32(DenormedValU32(
denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low,
denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high
));
// Flip the sign of the remainder if the numerator is negative
remNormal := NormalizeU32(DenormedValU32(
numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low,
numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high
));
// Decide if we need to swap order of
// If non-zero, make sure 0 <= rem < denom
if (isZero) {
AssertEqU32(rem, numer);
} else {
cmp := CmpLessThanUnsigned(rem, denom);
cmp.is_less_than = 1;
lt := CmpLessThanUnsigned(remNormal, denomAbs);
lt.is_less_than = 1;
};
DivideReturn(quot, rem)
}
Expand Down
2 changes: 2 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/mult.zir
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ component MultiplyAccumulate(a: ValU32, b: ValU32, c: ValU32, settings: Multiply
s3Carry := FakeTwitReg((s3Tot - s3Out) / 0x10000);
public outLow := ValU32(s0.out, s1.out);
public outHigh := ValU32(s2.out, s3Out);
public aNeg := ax.neg;
public bNeg := bx.neg;
}

component MultiplyTestCase(a: ValU32, b: ValU32, c: ValU32, settings: MultiplySettings, ol: ValU32, oh: ValU32) {
Expand Down
13 changes: 13 additions & 0 deletions zirgen/circuit/rv32im/v2/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ cc_test(
],
)

cc_test(
name = "test_riscv_bins",
srcs = [
"test_riscv_bins.cpp",
],
data = [
"//zirgen/circuit/rv32im/shared/test:riscv_test_bins",
],
deps = [
"//zirgen/circuit/rv32im/v2/run",
],
)

risc0_cc_kernel_binary(
name = "test_sha_kernel",
srcs = [
Expand Down
51 changes: 51 additions & 0 deletions zirgen/circuit/rv32im/v2/test/test_riscv_bins.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>

#include "zirgen/circuit/rv32im/v2/platform/constants.h"
#include "zirgen/circuit/rv32im/v2/run/run.h"

using namespace zirgen::rv32im_v2;

void runOneTest(const std::string& name) {
std::string kernel = "zirgen/circuit/rv32im/shared/test/" + name;
size_t cycles = 10000;

TestIoHandler io;

// Load image
auto image = MemoryImage::fromRawElf(kernel);
// Do executions
auto segments = execute(image, io, cycles, cycles);
// Do 'run' (preflight + expansion)
for (const auto& segment : segments) {
runSegment(segment, cycles);
}
}

std::vector<std::string> names = {
"add", "addi", "and", "andi", "auipc", "beq", "bge", "bgeu", "blt", "bltu",
"bne", "div", "divu", "jal", "jalr", "lb", "lbu", "lh", "lhu", "lui",
"lw", "mul", "mulh", "mulhsu", "mulhu", "or", "ori", "rem", "remu", "sb",
"sh", "sll", "slli", "slt", "slti", "sltiu", "sltu", "sra", "srai", "srl",
"srli", "sub", "sw", "test", "xor", "xori"};

const std::string kernelName = "zirgen/circuit/rv32im/shared/test/rem";

int main() {
for (const std::string& name : names) {
runOneTest(name);
}
}

0 comments on commit 3ba0936

Please sign in to comment.