Skip to content

Commit

Permalink
modules/zstd/zstd_dec: Add DSLX tests for ZstdDecoder
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Obłonczek <[email protected]>
  • Loading branch information
koblonczek authored and lpawelcz committed Oct 24, 2024
1 parent 4c887eb commit 4502d8b
Show file tree
Hide file tree
Showing 11 changed files with 765 additions and 320 deletions.
78 changes: 62 additions & 16 deletions xls/modules/zstd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ common_codegen_args = {
"multi_proc": "true",
}

xls_dslx_library(
name = "math_dslx",
srcs = [
"math.x",
],
)

xls_dslx_test(
name = "math_dslx_test",
library = ":math_dslx",
tags = ["manual"],
)

xls_dslx_library(
name = "buffer_dslx",
srcs = [
Expand Down Expand Up @@ -927,32 +940,65 @@ place_and_route(
target_die_utilization_percentage = "10",
)

py_binary(
name = "zstd_test_frames_generator",
main = "zstd_frame_dslx.py",
srcs = ["zstd_frame_dslx.py"],
imports = ["."],
tags = ["manual"],
visibility = ["//xls:xls_users"],
deps = [
requirement("zstandard"),
"//xls/common:runfiles",
"//xls/modules/zstd/cocotb:data_generator",
"@com_google_absl_py//absl:app",
"@com_google_absl_py//absl/flags",
"@com_google_protobuf//:protobuf_python",
],
)

genrule(
name = "zstd_test_frames_generate",
srcs = [],
outs = ["zstd_frame_testcases.x"],
cmd = "$(location :zstd_test_frames_generator) -n 2 --btype RAW RLE -o $@",
tools = [":zstd_test_frames_generator"],
)

zstd_dec_deps = [
":axi_csr_accessor_dslx",
":block_header_dec_dslx",
":block_header_dslx",
":common_dslx",
":csr_config_dslx",
":dec_mux_dslx",
":frame_header_dec_dslx",
":raw_block_dec_dslx",
":repacketizer_dslx",
":rle_block_dec_dslx",
":sequence_executor_dslx",
"//xls/examples:ram_dslx",
"//xls/modules/zstd/memory:mem_reader_dslx",
"//xls/modules/zstd/memory:axi_ram_dslx",
]

xls_dslx_library(
name = "zstd_dec_dslx",
srcs = [
"zstd_dec.x",
],
deps = [
":axi_csr_accessor_dslx",
":block_header_dec_dslx",
":block_header_dslx",
":common_dslx",
":csr_config_dslx",
":dec_mux_dslx",
":frame_header_dec_dslx",
":raw_block_dec_dslx",
":repacketizer_dslx",
":rle_block_dec_dslx",
":sequence_executor_dslx",
"//xls/examples:ram_dslx",
"//xls/modules/zstd/memory:mem_reader_dslx",
],
deps = zstd_dec_deps
)

xls_dslx_test(
name = "zstd_dec_dslx_test",
library = ":zstd_dec_dslx",
tags = ["manual"],
srcs = [
"zstd_dec.x",
"zstd_dec_test.x",
"zstd_frame_testcases.x",
],
deps = zstd_dec_deps,
)

zstd_dec_codegen_args = common_codegen_args | {
Expand Down
13 changes: 7 additions & 6 deletions xls/modules/zstd/axi_csr_accessor.x
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ pub proc AxiCsrAccessor<
let tok_0 = join();
// write to CSR via AXI
let (tok_1_1, axi_aw, axi_aw_valid) = recv_non_blocking(tok_0, axi_aw_r, AxiAw {id: state.w_id, addr: state.w_addr, ..zero!<AxiAw>()});

// validate axi aw
assert!(!(axi_aw_valid && axi_aw.addr as u32 >= REGS_N), "invalid_aw_addr");
assert!(!(axi_aw_valid && axi_aw.addr as u32 >= (REGS_N << LOG2_DATA_W_DIV8)), "invalid_aw_addr");
assert!(!(axi_aw_valid && axi_aw.len != u8:0), "invalid_aw_len");

let (tok_1_2, axi_w, axi_w_valid) = recv_non_blocking(tok_1_1, axi_w_r, zero!<AxiW>());

// Send WriteRequest to CSRs
let data_w = if axi_w_valid {
trace_fmt!("[CSR ACCESSOR] received csr write at {:#x}", axi_w);

let (w_data, _, _) = for (i, (w_data, strb, mask)): (u32, (uN[DATA_W], uN[DATA_W_DIV8], uN[DATA_W])) in range(u32:0, DATA_W_DIV8) {
let w_data = if axi_w.strb as u1 {
w_data | (axi_w.data & mask)
Expand Down Expand Up @@ -133,7 +134,7 @@ pub proc AxiCsrAccessor<
// Send ReadRequest to CSRs
let (tok_3_1, axi_ar, axi_ar_valid) = recv_non_blocking(tok_0, axi_ar_r, AxiAr {id: state.r_id, addr: state.r_addr, ..zero!<AxiAr>()});
// validate ar bundle
assert!(!(axi_ar_valid && axi_ar.addr as u32 >= REGS_N), "invalid_ar_addr");
assert!(!(axi_ar_valid && axi_ar.addr as u32 >= (REGS_N << LOG2_DATA_W_DIV8)), "invalid_ar_addr");
assert!(!(axi_ar_valid && axi_ar.len != u8:0), "invalid_ar_len");
let rd_req = RdReq {
csr: (axi_ar.addr >> LOG2_DATA_W_DIV8) as uN[LOG2_REGS_N],
Expand Down Expand Up @@ -208,7 +209,7 @@ proc AxiCsrAccessorInst {
const TEST_ID_W = u32:4;
const TEST_DATA_W = u32:32;
const TEST_ADDR_W = u32:16;
const TEST_REGS_N = u32:16;
const TEST_REGS_N = u32:4;
const TEST_DATA_W_DIV8 = TEST_DATA_W / u32:8;
const TEST_LOG2_REGS_N = std::clog2(TEST_REGS_N);
const TEST_LOG2_DATA_W_DIV8 = std::clog2(TEST_DATA_W_DIV8);
Expand Down Expand Up @@ -309,7 +310,7 @@ proc AxiCsrAccessorTest {
// write CSR via AXI
let axi_aw = TestAxiAw {
id: i as uN[TEST_ID_W],
addr: (test_data.csr << TEST_LOG2_DATA_W_DIV8) as uN[TEST_ADDR_W],
addr: (test_data.csr as uN[TEST_ADDR_W]) << TEST_LOG2_DATA_W_DIV8,
size: axi::AxiAxSize::MAX_4B_TRANSFER,
len: u8:0,
burst: axi::AxiAxBurst::FIXED,
Expand Down Expand Up @@ -346,7 +347,7 @@ proc AxiCsrAccessorTest {
// read CSRs via AXI
let axi_ar = TestAxiAr {
id: i as uN[TEST_ID_W],
addr: (test_data.csr << TEST_LOG2_DATA_W_DIV8) as uN[TEST_ADDR_W],
addr: (test_data.csr as uN[TEST_ADDR_W]) << TEST_LOG2_DATA_W_DIV8,
len: u8:0,
..zero!<TestAxiAr>()
};
Expand Down
19 changes: 17 additions & 2 deletions xls/modules/zstd/cocotb/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,35 @@

from xls.common import runfiles
import subprocess
import zstandard

class BlockType(Enum):
RAW = 0
RLE = 1
COMPRESSED = 2
RANDOM = 3

def __str__(self):
return self.name

@staticmethod
def from_string(s):
try:
return BlockType[s]
except KeyError as e:
raise ValueError(str(e))

def CallDecodecorpus(args):
decodecorpus = Path(runfiles.get_path("decodecorpus", repository = "zstd"))
cmd = args
cmd.insert(0, str(decodecorpus))
cmd_concat = " ".join(cmd)
subprocess.run(cmd_concat, shell=True, check=True)

def DecompressFrame(data):
dctx = zstandard.ZstdDecompressor()
return dctx.decompress(data)

def GenerateFrame(seed, btype, output_path):
args = []
args.append("-s" + str(seed))
Expand All @@ -39,8 +54,8 @@ def GenerateFrame(seed, btype, output_path):
args.append("--content-size")
# Test payloads up to 16KB
args.append("--max-content-size-log=14")
args.append("-p" + output_path);
args.append("-vvvvvvv");
args.append("-p" + output_path)
args.append("-vvvvvvv")

CallDecodecorpus(args)

88 changes: 88 additions & 0 deletions xls/modules/zstd/math.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2024 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import std;

fn fast_if<N: u32>(cond: bool, arg1: uN[N], arg2: uN[N]) -> uN[N] {
let mask = if cond { !bits[N]:0 } else { bits[N]:0 };
(arg1 & mask) | (arg2 & !mask)
}

#[test]
fn fast_if_test() {
assert_eq(if true { u32:1 } else { u32:5 }, fast_if(true, u32:1, u32:5));
assert_eq(if false { u32:1 } else { u32:5 }, fast_if(false, u32:1, u32:5));
}

// Log-depth shift bits left
pub fn logshiftl<N: u32, R: u32>(n: bits[N], r: bits[R]) -> bits[N] {
for (i, y) in u32:0..R {
fast_if(r[i+:u1], { y << (bits[R]:1 << i) }, { y })
}(n as bits[N])
}

#[test]
fn logshiftl_test() {
// Test varying base
assert_eq(logshiftl(bits[64]:0, bits[6]:3), bits[64]:0 << u32:3);
assert_eq(logshiftl(bits[64]:1, bits[6]:3), bits[64]:1 << u32:3);
assert_eq(logshiftl(bits[64]:2, bits[6]:3), bits[64]:2 << u32:3);
assert_eq(logshiftl(bits[64]:3, bits[6]:3), bits[64]:3 << u32:3);
assert_eq(logshiftl(bits[64]:4, bits[6]:3), bits[64]:4 << u32:3);

// Test varying exponent
assert_eq(logshiftl(bits[64]:50, bits[6]:0), bits[64]:50 << u32:0);
assert_eq(logshiftl(bits[64]:50, bits[6]:1), bits[64]:50 << u32:1);
assert_eq(logshiftl(bits[64]:50, bits[6]:2), bits[64]:50 << u32:2);
assert_eq(logshiftl(bits[64]:50, bits[6]:3), bits[64]:50 << u32:3);
assert_eq(logshiftl(bits[64]:50, bits[6]:4), bits[64]:50 << u32:4);

// Test overflow
let max = std::unsigned_max_value<u32:8>();
assert_eq(logshiftl(max, u4:4), max << u4:4);
assert_eq(logshiftl(max, u4:5), max << u4:5);
assert_eq(logshiftl(max, u4:15), max << u4:15);
assert_eq(logshiftl(bits[24]:0xc0ffee, u8:12), bits[24]:0xfee000);
}

// Log-depth shift bits right
pub fn logshiftr<N: u32, R: u32>(n: bits[N], r: bits[R]) -> bits[N] {
for (i, y) in u32:0..R {
fast_if(r[i+:u1], { y >> (bits[R]:1 << i) }, { y })
}(n as bits[N])
}

#[test]
fn logshiftr_test() {
// Test varying base
assert_eq(logshiftr(bits[64]:0x0fac4e782, bits[6]:3), bits[64]:0x0fac4e782 >> u32:3);
assert_eq(logshiftr(bits[64]:0x1fac4e782, bits[6]:3), bits[64]:0x1fac4e782 >> u32:3);
assert_eq(logshiftr(bits[64]:0x2fac4e782, bits[6]:3), bits[64]:0x2fac4e782 >> u32:3);
assert_eq(logshiftr(bits[64]:0x3fac4e782, bits[6]:3), bits[64]:0x3fac4e782 >> u32:3);
assert_eq(logshiftr(bits[64]:0x4fac4e782, bits[6]:3), bits[64]:0x4fac4e782 >> u32:3);

// Test varying exponent
assert_eq(logshiftr(bits[64]:0x50fac4e782, bits[6]:0), bits[64]:0x50fac4e782 >> u32:0);
assert_eq(logshiftr(bits[64]:0x50fac4e782, bits[6]:1), bits[64]:0x50fac4e782 >> u32:1);
assert_eq(logshiftr(bits[64]:0x50fac4e782, bits[6]:2), bits[64]:0x50fac4e782 >> u32:2);
assert_eq(logshiftr(bits[64]:0x50fac4e782, bits[6]:3), bits[64]:0x50fac4e782 >> u32:3);
assert_eq(logshiftr(bits[64]:0x50fac4e782, bits[6]:4), bits[64]:0x50fac4e782 >> u32:4);

// Test overflow
let max = std::unsigned_max_value<u32:8>();
assert_eq(logshiftr(max, u4:4), max >> u4:4);
assert_eq(logshiftr(max, u4:5), max >> u4:5);
assert_eq(logshiftr(max, u4:15), max >> u4:15);
assert_eq(logshiftr(bits[24]:0xc0ffee, u8:12), bits[24]:0x000c0f);
}
2 changes: 1 addition & 1 deletion xls/modules/zstd/memory/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ xls_dslx_verilog(
library = ":axi_ram_dslx",
opt_ir_args = {
"inline_procs": "true",
"top": "__axi_ram__AxiRamReaderInstWithEmptyWrites__AxiRamReader_0__AxiRamReaderResponder_0__32_32_4_8_8_32768_7_32_4_100_next",
"top": "__axi_ram__AxiRamReaderInstWithEmptyWrites__AxiRamReader_0__AxiRamReaderResponder_0__32_32_4_5_6_8_8_32768_7_32_5_6_4_100_next",
},
tags = ["manual"],
verilog_file = "axi_ram.v",
Expand Down
Loading

0 comments on commit 4502d8b

Please sign in to comment.