diff --git a/zirgen/Dialect/ZHLT/IR/TypeUtils.cpp b/zirgen/Dialect/ZHLT/IR/TypeUtils.cpp index ef00cb22..6c99b819 100644 --- a/zirgen/Dialect/ZHLT/IR/TypeUtils.cpp +++ b/zirgen/Dialect/ZHLT/IR/TypeUtils.cpp @@ -37,6 +37,8 @@ std::string mangledTypeName(StringRef componentName, llvm::ArrayRef t llvm::interleaveComma(typeArgs, stream, [&](Attribute typeArg) { if (auto strAttr = typeArg.dyn_cast()) { stream << strAttr.getValue(); + } else if (auto intAttr = typeArg.dyn_cast()) { + stream << intAttr; } else if (auto intAttr = typeArg.dyn_cast()) { stream << intAttr[0]; } else { diff --git a/zirgen/dsl/test/binary_adder.zir b/zirgen/dsl/test/binary_adder.zir new file mode 100644 index 00000000..ed05189b --- /dev/null +++ b/zirgen/dsl/test/binary_adder.zir @@ -0,0 +1,63 @@ +// RUN: zirgen --test %s 2>&1 | FileCheck %s + +component BitReg(x: Val) { + r := Reg(x); + r * (r - 1) = 0; + r +} + +component Not(x: BitReg) { + BitReg(1 - x) +} + +component And(x: BitReg, y: BitReg) { + BitReg(x * y) +} + +component Or(x: BitReg, y: BitReg) { + BitReg(x + y - x * y) +} + +component Xor(x: BitReg, y: BitReg) { + BitReg(x + y - 2 * x * y) +} + +component Xor3(x: BitReg, y: BitReg, z: BitReg) { + Xor(Xor(x, y), z) +} + +component HalfAdder(x: BitReg, y: BitReg) { + sum := Xor(x, y); + carry := And(x, y); +} + +component FullAdder(x: BitReg, y: BitReg, c: BitReg) { + sum := Xor(Xor(x, y), c); + carry := [c, 1 - c] -> ( + Or(x, y), + And(x, y) + ); +} + +component Adder(x: Array, y: Array) { + a0 := HalfAdder(x[0], y[0]); + a1 := FullAdder(x[1], y[1], a0.carry); + a2 := FullAdder(x[2], y[2], a1.carry); + a0.sum + 2 * a1.sum + 4 * a2.sum +} + +test { + // CHECK-DAG: 5 + 5 = 2 (mod 8) + one := BitReg(1); + zero := BitReg(0); + adder := Adder([one, zero, one], [one, zero, one]); + Log("5 + 5 = %u (mod 8)", adder); +} + +test { + // CHECK-DAG: 3 + 1 = 4 (mod 8) + one := BitReg(1); + zero := BitReg(0); + adder := Adder([one, one, zero], [one, zero, zero]); + Log("3 + 1 = %u (mod 8)", adder); +} diff --git a/zirgen/dsl/test/matrix.zir b/zirgen/dsl/test/matrix.zir new file mode 100644 index 00000000..2a45d039 --- /dev/null +++ b/zirgen/dsl/test/matrix.zir @@ -0,0 +1,82 @@ +// RUN: zirgen %s --test + +function Matrix(coef: Array, M>) { + coef +} + +component MatAdd(a: Matrix, b: Matrix) { + Matrix( + for m : 0..M { + for n : 0..N { + a[m][n] + b[m][n] + } + } + ) +} + +test GeneralAddition { + a := Matrix<2, 3>([[1, 2, 3], + [4, 5, 6]]); + b := Matrix<2, 3>([[ 7, 8, 9], + [10, 11, 12]]); + c := MatAdd<2, 3>(a, b); + Log("c = [[%u, %u, %u], [%u, %u, %u]]", c[0][0], c[0][1], c[0][2], c[1][0], c[1][1], c[1][2]); + c[0][0] = 8; + c[0][1] = 10; + c[0][2] = 12; + c[1][0] = 14; + c[1][1] = 16; + c[1][2] = 18; +} + +component MatMul( + a: Matrix, + b: Matrix +) { + Matrix( + for m : 0..M { + for p : 0..P { + product := for n : 0..N { a[m][n] * b[n][p] }; + reduce product init 0 with Add + } + } + ) +} + +test IdentitySquared { + // [1 0][1 0] [1 0] + // [0 1][0 1] = [0 1] + a := Matrix<2, 2>([[1, 0], + [0, 1]]); + c := MatMul<2, 2, 2>(a, a); + Log("c = [[%u, %u], [%u, %u]]", c[0][0], c[0][1], c[1][0], c[1][1]); + c[0][0] = 1; + c[0][1] = 0; + c[1][0] = 0; + c[1][1] = 1; +} + +test GeneralMultiplication { + // [0 1] [0 1 2] [ 3 4 5] + // [2 3] * [3 4 5] = [ 9 14 19] + // [4 5] [15 24 33] + a := Matrix<3, 2>([[0, 1], + [2, 3], + [4, 5]]); + b := Matrix<2, 3>([[0, 1, 2], + [3, 4, 5]]); + c := MatMul<3, 2, 3>(a, b); + Log("[[%u, %u, %u], [%u, %u, %u], [%u, %u, %u]]", + c[0][0], c[0][1], c[0][2], + c[1][0], c[1][1], c[1][2], + c[2][0], c[2][1], c[2][2]); + c[0][0] = 3; + c[0][1] = 4; + c[0][2] = 5; + c[1][0] = 9; + c[1][1] = 14; + c[1][2] = 19; + c[2][0] = 15; + c[2][1] = 24; + c[2][2] = 33; +}