diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index 65a27a76ad1..bc4e6a63e34 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -1040,6 +1040,69 @@ struct find_unpack_int4_mlir_op
     }
 };
 
+struct find_mlir_reshape_ops
+{
+    auto matcher() const
+    {
+        auto reshapes = reshaper_names();
+        // slice is not supported
+        reshapes.erase("slice");
+        return match::name(reshapes)(match::arg(0)(match::name("gpu::mlir_op")(match::used_once())), match::used_once());
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto mlir_ins = ins->inputs().front();
+
+        auto* mm = mlir_ins->module_inputs().front();
+        module_ref nm = mpm.create_module(mm->name() + ":" + ins->name());
+        nm->set_bypass();
+
+        auto y = nm->fuse(*mm, mlir_ins->inputs());
+        auto ret = nm->add_instruction(ins->get_operator(), y);
+        nm->add_return({ret});
+        mpm.get_module().replace_instruction(ins, mlir_ins->get_operator(), mlir_ins->inputs(), {nm});
+    }
+};
+
+struct find_convolution_reshape
+{
+    auto matcher() const
+    {
+        return match::name("reshape")(match::arg(0)(match::name("convolution").bind("convolution")));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto conv = r.instructions["convolution"];
+        auto out_dims = ins->get_shape().lens();
+        auto conv_dims = conv->get_shape().lens();
+        if(out_dims.size() != 5)
+            return;
+        if(conv_dims.size() != 4)
+            return;
+        auto perm = find_permutation(conv->get_shape());
+        if(perm.back() != 1)
+            return;
+        if(out_dims[0] != conv_dims[0])
+            return;
+        if(not std::equal(conv_dims.begin() + 2, conv_dims.end(), out_dims.begin() + 3, out_dims.end()))
+            return;
+        if(out_dims[2] > 32)
+            return;
+        if(out_dims[1] < 4)
+            return;
+        auto reshape = mpm.get_module().insert_instruction(ins, ins->get_operator(), ins->inputs());
+        // auto t2 = mpm.get_module().insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 3, 4, 2}}}), reshape);
+        auto t1 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape);
+        auto c = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1);
+        auto t2 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c);
+        mpm.get_module().replace_instruction(ins, t2);
+    }
+};
+
 } // namespace
 #endif // MIGRAPHX_MLIR
 
@@ -1061,6 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
         return std::max(m1, m2);
     };
 
+    match::find_matches(mpm, find_convolution_reshape{});
     // Attention offloads; default disabled
     if(mlir_attention_enabled(ctx) or enable_extra)
     {
@@ -1092,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
 
     match::find_matches(mpm, find_pointwise_mlir{});
     match::find_matches(mpm, find_unpack_int4_mlir_op{});
+
+    for(int i = 0; i < 4; i++)
+        match::find_matches(mpm, find_mlir_reshape_ops{});
 #else
     (void)mpm;
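
Not part of the patch: a minimal, self-contained sketch (plain C++, no MIGraphX headers; the apply_perm helper and the dimension values are hypothetical) of the shape bookkeeping behind find_convolution_reshape. It checks that the two inserted transposes, {0, 1, 3, 4, 2} followed by {0, 1, 4, 2, 3}, compose to the identity permutation, so the replacement chain reshape -> transpose -> contiguous -> transpose keeps the original reshape's output shape while the contiguous op materializes the data in the channels-last-like {0, 1, 3, 4, 2} ordering.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <iostream>

// Apply a transpose-style permutation p to a set of dimensions: out[i] = dims[p[i]].
template <std::size_t N>
std::array<std::size_t, N> apply_perm(const std::array<std::size_t, N>& dims,
                                      const std::array<std::size_t, N>& perm)
{
    std::array<std::size_t, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i] = dims[perm[i]];
    return out;
}

int main()
{
    // Hypothetical 5D reshape output {N, C1, C2, H, W}: batch and spatial dims match
    // the 4D convolution output, so C1 and C2 are a split of its channel dimension.
    const std::array<std::size_t, 5> reshaped{1, 8, 16, 28, 28};

    const std::array<std::size_t, 5> t1{0, 1, 3, 4, 2}; // move C2 innermost
    const std::array<std::size_t, 5> t2{0, 1, 4, 2, 3}; // move C2 back after contiguous

    const auto after_t1 = apply_perm(reshaped, t1); // {N, C1, H, W, C2}
    const auto after_t2 = apply_perm(after_t1, t2); // back to {N, C1, C2, H, W}

    // The two transposes compose to the identity, so the replacement chain has the
    // same shape as the original reshape; only the memory layout is changed.
    assert(after_t2 == reshaped);

    for(auto d : after_t2)
        std::cout << d << ' ';
    std::cout << '\n';
}
```

On this reading, the out_dims[1] >= 4 and out_dims[2] <= 32 guards bound the two factors of the split channel dimension; that is an interpretation of the guards, not something the patch itself states.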