Commit: Add bitwise op support (#2043)
* Add e2e support for bitwise ops

Signed-off-by: philass <[email protected]>

* Fix docs

Signed-off-by: philass <[email protected]>

* Add lit tests

Signed-off-by: philass <[email protected]>

---------

Signed-off-by: philass <[email protected]>
philass authored May 3, 2023
1 parent a70c43a commit b238664
Showing 4 changed files with 91 additions and 3 deletions.
6 changes: 3 additions & 3 deletions docs/SupportedONNXOps-cpu.md
@@ -30,10 +30,10 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 18. Limitatio
| **Bernoulli** | |unsupported | |
| **Binarizer** | |unsupported | |
| **BitShift** | |unsupported | |
-| **BitwiseAnd** | |unsupported | |
+| **BitwiseAnd** |18 | | |
| **BitwiseNot** | |unsupported | |
-| **BitwiseOr** | |unsupported | |
-| **BitwiseXor** | |unsupported | |
+| **BitwiseOr** |18 | | |
+| **BitwiseXor** |18 | | |
| **BlackmanWindow** | |unsupported | |
| **Cast** |13 |Cast only between float and double types. | |
| **CastLike** | |unsupported | |
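For reference, BitwiseAnd, BitwiseOr, and BitwiseXor are elementwise integer ops introduced in ONNX opset 18. A minimal NumPy sketch (illustrative only, not part of this change) shows the expected numerics:

```python
import numpy as np

# Reference semantics for the newly supported ops.
a = np.array([0b1100, 0b1010], dtype=np.int32)  # [12, 10]
b = np.array([0b1010, 0b0110], dtype=np.int32)  # [10, 6]

print(np.bitwise_and(a, b))  # [8 2]   -> 0b1000, 0b0010
print(np.bitwise_or(a, b))   # [14 14] -> 0b1110, 0b1110
print(np.bitwise_xor(a, b))  # [6 12]  -> 0b0110, 0b1100
```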
21 changes: 21 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -160,6 +160,24 @@ struct ScalarOp<ONNXXorOp> {
using IOp = arith::XOrIOp;
};
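
// Note: each ScalarOp specialization records which arith-dialect op
// implements the ONNX op for float (FOp) and integer (IOp) element types.
// The bitwise ops below are integer-only, so FOp merely aliases the same
// integer op and is never selected ("Not used").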

template <>
struct ScalarOp<ONNXBitwiseAndOp> {
using FOp = arith::AndIOp; // Not used.
using IOp = arith::AndIOp;
};

template <>
struct ScalarOp<ONNXBitwiseOrOp> {
using FOp = arith::OrIOp; // Not used.
using IOp = arith::OrIOp;
};

template <>
struct ScalarOp<ONNXBitwiseXorOp> {
using FOp = arith::XOrIOp; // Not used.
using IOp = arith::XOrIOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
using FOp = math::ExpOp;
@@ -2294,6 +2312,9 @@ void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXAtanOp>,
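// BitwiseAnd/Or/Xor take exactly two operands in ONNX, hence the binary
// rather than the variadic elementwise lowering.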
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseAndOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseOrOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseXorOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCeilOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
13 changes: 13 additions & 0 deletions test/backend/inference_backend.py
@@ -164,6 +164,19 @@ def get_test_models():

# Bitshift

# ==OP== BitwiseAnd
"test_bitwise_and_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_bitwise_and_i16_3d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== BitwiseOr
"test_bitwise_or_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_bitwise_or_i16_4d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== BitwiseXor
"test_bitwise_xor_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_bitwise_xor_i16_3d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},


# ==OP== Cast
# ==LIM== Cast only between float and double types
"test_cast_FLOAT_to_DOUBLE_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
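The flags above run each backend test with static shapes, dynamic shapes (the {-1: {-1}} entry makes every dimension of every input dynamic), and constant inputs. As a sketch of the kind of model these tests exercise — the names, shapes, and helper-built graph below are illustrative assumptions, not the actual test fixtures:

```python
# Hypothetical stand-in for the model behind test_bitwise_and_i32_2d.
import onnx
from onnx import TensorProto, helper

node = helper.make_node("BitwiseAnd", inputs=["a", "b"], outputs=["c"])
graph = helper.make_graph(
    [node], "bitwise_and_i32_2d",
    inputs=[helper.make_tensor_value_info("a", TensorProto.INT32, [3, 4]),
            helper.make_tensor_value_info("b", TensorProto.INT32, [3, 4])],
    outputs=[helper.make_tensor_value_info("c", TensorProto.INT32, [3, 4])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)  # BitwiseAnd requires opset >= 18
```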
54 changes: 54 additions & 0 deletions test/mlir/onnx/onnx_lowering.mlir
@@ -187,6 +187,60 @@ func.func private @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>)

// -----

func.func private @test_bitwise_and(%arg0 : tensor<10x10xi8>, %arg1 : tensor<10x10xi8>) -> tensor<*xi8> {
%0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<10x10xi8>, tensor<10x10xi8>) -> tensor<*xi8>
"func.return"(%0) : (tensor<*xi8>) -> ()

// CHECK-LABEL: test_bitwise_and
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi8>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
// CHECK: [[AND:%.+]] = arith.andi [[LOAD1]], [[LOAD2]] : i8
// CHECK: krnl.store [[AND]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
// CHECK: return [[RES]] : memref<10x10xi8>
}

// -----

func.func private @test_bitwise_or(%arg0 : tensor<10x10xi16>, %arg1 : tensor<10x10xi16>) -> tensor<*xi16> {
%0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<10x10xi16>, tensor<10x10xi16>) -> tensor<*xi16>
"func.return"(%0) : (tensor<*xi16>) -> ()

// CHECK-LABEL: test_bitwise_or
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi16>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
// CHECK: [[OR:%.+]] = arith.ori [[LOAD1]], [[LOAD2]] : i16
// CHECK: krnl.store [[OR]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
// CHECK: return [[RES]] : memref<10x10xi16>
}

// -----

func.func private @test_bitwise_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"func.return"(%0) : (tensor<*xi32>) -> ()

// CHECK-LABEL: test_bitwise_xor
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
// CHECK: [[XOR:%.+]] = arith.xori [[LOAD1]], [[LOAD2]] : i32
// CHECK: krnl.store [[XOR]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
// CHECK: return [[RES]] : memref<10x10xi32>
}

// -----

func.func private @test_exp(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Exp"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
