From 5c4197271cca4d199ce51bcd3aac0ef89afe60be Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 9 Jan 2025 13:59:19 -0600 Subject: [PATCH 1/2] Relayout convolution --- src/targets/gpu/fuse_mlir.cpp | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 65a27a76ad1..dce1a30cf0d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1040,6 +1040,42 @@ struct find_unpack_int4_mlir_op } }; +struct find_convolution_reshape +{ + auto matcher() const + { + return match::name("reshape")(match::arg(0)(match::name("convolution").bind("convolution"))); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto conv = r.instructions["convolution"]; + auto out_dims = ins->get_shape().lens(); + auto conv_dims = conv->get_shape().lens(); + if(out_dims.size() != 5) + return; + if(conv_dims.size() != 4) + return; + auto perm = find_permutation(conv->get_shape()); + if(perm.back() != 1) + return; + if(out_dims[0] != conv_dims[0]) + return; + if(not std::equal(conv_dims.begin() + 2, conv_dims.end(), out_dims.begin() + 3, out_dims.end())) + return; + if(out_dims[2] > 32) + return; + if(out_dims[1] < 4) + return; + auto reshape = mpm.get_module().insert_instruction(ins,ins->get_operator(), ins->inputs()); + auto t1 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape); + auto c = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1); + auto t2 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c); + mpm.get_module().replace_instruction(ins, t2); + } +}; + } // namespace #endif // MIGRAPHX_MLIR @@ -1061,6 +1097,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const return std::max(m1, m2); }; + mpm.get_module().debug_print(); + match::find_matches(mpm, find_convolution_reshape{}); + mpm.get_module().debug_print(); // Attention offloads; default disabled if(mlir_attention_enabled(ctx) or enable_extra) { From 134979f57fc5be068ce19318b9d6d588db9005c7 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 9 Jan 2025 15:05:55 -0600 Subject: [PATCH 2/2] Fuse output reshapes --- src/targets/gpu/fuse_mlir.cpp | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index dce1a30cf0d..bc4e6a63e34 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1040,6 +1040,32 @@ struct find_unpack_int4_mlir_op } }; +struct find_mlir_reshape_ops +{ + auto matcher() const + { + auto reshapes = reshaper_names(); + // slice is not supported + reshapes.erase("slice"); + return match::name(reshapes)(match::arg(0)(match::name("gpu::mlir_op")(match::used_once())), match::used_once()); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto mlir_ins = ins->inputs().front(); + + auto* mm = mlir_ins->module_inputs().front(); + module_ref nm = mpm.create_module(mm->name() + ":" + ins->name()); + nm->set_bypass(); + + auto y = nm->fuse(*mm, mlir_ins->inputs()); + auto ret = nm->add_instruction(ins->get_operator(), y); + nm->add_return({ret}); + mpm.get_module().replace_instruction(ins, mlir_ins->get_operator(), mlir_ins->inputs(), {nm}); + } +}; + struct find_convolution_reshape { auto matcher() const @@ -1069,6 +1095,7 @@ struct find_convolution_reshape if(out_dims[1] < 4) return; auto reshape = mpm.get_module().insert_instruction(ins,ins->get_operator(), ins->inputs()); + // auto t2 = mpm.get_module().insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 3, 4, 2}}}), reshape); auto t1 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape); auto c = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1); auto t2 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c); @@ -1097,9 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const return std::max(m1, m2); }; - mpm.get_module().debug_print(); match::find_matches(mpm, find_convolution_reshape{}); - mpm.get_module().debug_print(); // Attention offloads; default disabled if(mlir_attention_enabled(ctx) or enable_extra) { @@ -1131,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_pointwise_mlir{}); match::find_matches(mpm, find_unpack_int4_mlir_op{}); + + for(int i=0;i<4;i++) + match::find_matches(mpm, find_mlir_reshape_ops{}); #else (void)mpm;