From 70336db4f00b6e997139accc90dd23695f0331ea Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 3 Oct 2024 11:01:49 -0500 Subject: [PATCH 01/23] Initial --- src/fp8_ocp_to_fnuz.cpp | 30 +++++++++++++++ src/include/migraphx/fp8_ocp_to_fnuz.hpp | 49 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 src/fp8_ocp_to_fnuz.cpp create mode 100644 src/include/migraphx/fp8_ocp_to_fnuz.hpp diff --git a/src/fp8_ocp_to_fnuz.cpp b/src/fp8_ocp_to_fnuz.cpp new file mode 100644 index 00000000000..c2ff12f9625 --- /dev/null +++ b/src/fp8_ocp_to_fnuz.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/fp8_ocp_to_fnuz.hpp b/src/include/migraphx/fp8_ocp_to_fnuz.hpp new file mode 100644 index 00000000000..8c9e4a58390 --- /dev/null +++ b/src/include/migraphx/fp8_ocp_to_fnuz.hpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_FP8_OCP_TO_FNUZ_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_FP8_OCP_TO_FNUZ_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * Convert fp8 types from OCP format to FNUZ format if on + * hardware that only supports the FNUZ format. + */ +struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz +{ + std::string name() const { return "fp8_ocp_to_fnuz"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif From bdebeb5776d5839812ffc4df8b8897b12a4ec129 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 15 Nov 2024 15:57:53 -0600 Subject: [PATCH 02/23] progress --- src/include/migraphx/fp8_ocp_to_fnuz.hpp | 2 ++ src/targets/gpu/compile_gen.cpp | 1 + .../include/migraphx/kernels/bit_cast.hpp | 2 ++ test/gpu/jit.cpp | 35 +++++++++++++++++++ 4 files changed, 40 insertions(+) diff --git a/src/include/migraphx/fp8_ocp_to_fnuz.hpp b/src/include/migraphx/fp8_ocp_to_fnuz.hpp index 8c9e4a58390..3b1537547d9 100644 --- a/src/include/migraphx/fp8_ocp_to_fnuz.hpp +++ b/src/include/migraphx/fp8_ocp_to_fnuz.hpp @@ -36,6 +36,8 @@ struct module; /** * Convert fp8 types from OCP format to FNUZ format if on * hardware that only supports the FNUZ format. + * + * Handles Uses the same bit represenation */ struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz { diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 82ff3c4a2d0..6c58497314e 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -326,6 +326,7 @@ static void generate_pointwise(cpp_generator& gg, g.add_point_op("less", "migraphx::abs(${0} < ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("not", "migraphx::abs(not ${0})"); + g.add_point_op("bit_cast", "$migraphx::bit_cast<${0}>(${1})"); // Add explict conversions g.fresult( [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index c98395bbe10..c2d74d2401c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -25,6 +25,7 @@ #include namespace migraphx { + template {} and is_trivially_copyable{})> @@ -33,5 +34,6 @@ inline constexpr To bit_cast(From fr) noexcept static_assert(sizeof(To) == sizeof(From)); return __builtin_bit_cast(To, fr); } + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 8b25df0cac8..eeb3f111fa6 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -155,6 +155,24 @@ int main() {} )__migraphx__"; +const std::string bit_cast_kernel = R"__migraphx__( +#include +#include +#include + +namespace migraphx { +extern "C" { +__global__ void kernel(${From} fr) +{ + migraphx::bit_cast<${To}>(fr); +} +} +} + +int main() {} + +)__migraphx__"; + migraphx::src_file make_src_file(const std::string& name, const std::string& content) { return {name, content}; @@ -446,4 +464,21 @@ TEST_CASE(assert_type_min_max) } } +// test bit_cast +TEST_CASE(gpu_bit_cast) +{ + migraphx::shape input{migraphx::shape::fp8e4m3fn_type, {5, 2}}; + migraphx::gpu::hip_compile_options options; + options.global = 1024; + options.local = 1024; + options.inputs = {input}; + options.output = input; + auto src = migraphx::interpolate_string( + bit_cast_kernel, + {{"To", migraphx::shape::cpp_type(migraphx::shape::fp8e4m3fn_type)}, + {"From", migraphx::shape::cpp_type(migraphx::shape::fp8e4m3fnuz_type)}}); + auto co = migraphx::gpu::compile_hip_code_object(src, options); + (void)co; +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From a1fb21eb96d1e1e3fe097ea74c280a4ecd09fc76 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 22 Nov 2024 12:16:29 -0600 Subject: [PATCH 03/23] cleanup --- src/CMakeLists.txt | 1 + src/include/migraphx/op/bit_cast.hpp | 99 ++++++++++++++++++++++++++++ src/targets/gpu/compile_gen.cpp | 1 - test/gpu/jit.cpp | 35 ---------- test/ref/bit_cast.cpp | 75 +++++++++++++++++++++ test/verify/test_bit_cast.cpp | 55 ++++++++++++++++ 6 files changed, 230 insertions(+), 36 deletions(-) create mode 100644 src/include/migraphx/op/bit_cast.hpp create mode 100644 test/ref/bit_cast.cpp create mode 100644 test/verify/test_bit_cast.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d7269547c9a..1bfab6bf790 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -145,6 +145,7 @@ register_migraphx_ops( as_shape atanh atan + bit_cast bitwise_and broadcast broadcast_for_dot diff --git a/src/include/migraphx/op/bit_cast.hpp b/src/include/migraphx/op/bit_cast.hpp new file mode 100644 index 00000000000..6f81a4164be --- /dev/null +++ b/src/include/migraphx/op/bit_cast.hpp @@ -0,0 +1,99 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP +#define MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +/** + * Obtain a value of type `target_type` by reinterpreting + * the object represnetaion of the input. Originally used + * for casting from fp8e4m3fn to fp8e4m3fnuz. + */ +struct bit_cast : unary +{ + shape::type_t target_type; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.target_type, "target_type")); + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1); + auto input = inputs.at(0); + std::size_t target_type_size; + shape::visit(target_type, [&](auto as) { target_type_size = as.size(); }); + if(input.type_size() != target_type_size) + { + MIGRAPHX_THROW("BIT_CAST: target_type has different type_size from input's"); + } + if(input.dynamic()) + { + return {target_type, input.dyn_dims()}; + } + else + { + return {target_type, input.lens(), input.strides()}; + } + } + + std::string point_op() const + { + return "${function:bit_cast}<" + shape::cpp_type(target_type) + ">(${0})"; + } + + argument compute(const dyn_output& dyn_out, std::vector args) const + { + argument result{dyn_out.computed_shape}; + result.visit([&](auto output) { + using otype = typename decltype(output)::value_type; + args[0].visit([&](auto input) { + using itype = typename decltype(input)::value_type; + if constexpr(sizeof(otype) == sizeof(itype)) + { + par_transform(input.begin(), input.end(), output.begin(), [&](auto x) { + return __builtin_bit_cast(otype, x); + }); + } + else + MIGRAPHX_THROW("BIT_CAST: type size mismatch"); + }); + }); + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 6c58497314e..82ff3c4a2d0 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -326,7 +326,6 @@ static void generate_pointwise(cpp_generator& gg, g.add_point_op("less", "migraphx::abs(${0} < ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("not", "migraphx::abs(not ${0})"); - g.add_point_op("bit_cast", "$migraphx::bit_cast<${0}>(${1})"); // Add explict conversions g.fresult( [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index eeb3f111fa6..8b25df0cac8 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -155,24 +155,6 @@ int main() {} )__migraphx__"; -const std::string bit_cast_kernel = R"__migraphx__( -#include -#include -#include - -namespace migraphx { -extern "C" { -__global__ void kernel(${From} fr) -{ - migraphx::bit_cast<${To}>(fr); -} -} -} - -int main() {} - -)__migraphx__"; - migraphx::src_file make_src_file(const std::string& name, const std::string& content) { return {name, content}; @@ -464,21 +446,4 @@ TEST_CASE(assert_type_min_max) } } -// test bit_cast -TEST_CASE(gpu_bit_cast) -{ - migraphx::shape input{migraphx::shape::fp8e4m3fn_type, {5, 2}}; - migraphx::gpu::hip_compile_options options; - options.global = 1024; - options.local = 1024; - options.inputs = {input}; - options.output = input; - auto src = migraphx::interpolate_string( - bit_cast_kernel, - {{"To", migraphx::shape::cpp_type(migraphx::shape::fp8e4m3fn_type)}, - {"From", migraphx::shape::cpp_type(migraphx::shape::fp8e4m3fnuz_type)}}); - auto co = migraphx::gpu::compile_hip_code_object(src, options); - (void)co; -} - int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/ref/bit_cast.cpp b/test/ref/bit_cast.cpp new file mode 100644 index 00000000000..4f9438ef4fd --- /dev/null +++ b/test/ref/bit_cast.cpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(bit_cast_fp8) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::fp8e4m3fn_type, {2, 2}}; + std::vector data; + data.push_back(fp8e4m3fn{26.0f}); + data.push_back(fp8e4m3fn{3.0f}); + data.push_back(fp8e4m3fn{96.0f}); + data.push_back(fp8e4m3fn{-1.25f}); + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold; + gold.push_back(fp8e4m3fnuz{13.0f}); + gold.push_back(fp8e4m3fnuz{1.5f}); + gold.push_back(fp8e4m3fnuz{48.0f}); + gold.push_back(fp8e4m3fnuz{-0.625f}); + EXPECT(results_vector == gold); +} + +TEST_CASE(bit_cast_uint8) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + std::vector data = {23, -3, 0, -1}; + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {23, 253, 0, 255}; + EXPECT(results_vector == gold); +} diff --git a/test/verify/test_bit_cast.cpp b/test/verify/test_bit_cast.cpp new file mode 100644 index 00000000000..e83de86241e --- /dev/null +++ b/test/verify/test_bit_cast.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +template +struct test_bit_cast : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sa{From, {8, 24}}; + migraphx::shape sb{From, {24, 6}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto ia = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa); + auto ib = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb); + mm->add_instruction(migraphx::make_op("dot"), ia, ib); + + return p; + }; + std::string section() const { return "gemm"; } +}; + +template struct test_bit_cast; +template struct test_bit_cast; From b8e2041eb0f28140874a16b4bcbdc8fa054ab64e Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 22 Nov 2024 12:18:26 -0600 Subject: [PATCH 04/23] remove unneeded files --- src/fp8_ocp_to_fnuz.cpp | 30 -------------- src/include/migraphx/fp8_ocp_to_fnuz.hpp | 51 ------------------------ 2 files changed, 81 deletions(-) delete mode 100644 src/fp8_ocp_to_fnuz.cpp delete mode 100644 src/include/migraphx/fp8_ocp_to_fnuz.hpp diff --git a/src/fp8_ocp_to_fnuz.cpp b/src/fp8_ocp_to_fnuz.cpp deleted file mode 100644 index c2ff12f9625..00000000000 --- a/src/fp8_ocp_to_fnuz.cpp +++ /dev/null @@ -1,30 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/include/migraphx/fp8_ocp_to_fnuz.hpp b/src/include/migraphx/fp8_ocp_to_fnuz.hpp deleted file mode 100644 index 3b1537547d9..00000000000 --- a/src/include/migraphx/fp8_ocp_to_fnuz.hpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_MIGRAPHX_FP8_OCP_TO_FNUZ_HPP -#define MIGRAPHX_GUARD_MIGRAPHX_FP8_OCP_TO_FNUZ_HPP - -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -struct module; - -/** - * Convert fp8 types from OCP format to FNUZ format if on - * hardware that only supports the FNUZ format. - * - * Handles Uses the same bit represenation - */ -struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz -{ - std::string name() const { return "fp8_ocp_to_fnuz"; } - void apply(module& m) const; -}; - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif From 83664347c031d240117b89b0c23a5e4870158067 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 22 Nov 2024 15:03:21 -0600 Subject: [PATCH 05/23] Fix bit_cast kernel --- .../include/migraphx/kernels/bit_cast.hpp | 16 +++++++++++++++- test/verify/test_bit_cast.cpp | 9 ++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index c2d74d2401c..4e8fcad3156 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -23,13 +23,27 @@ #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #include +#include namespace migraphx { template ()), MIGRAPHX_REQUIRES(is_trivially_copyable{} and is_trivially_copyable{})> -inline constexpr To bit_cast(From fr) noexcept +inline constexpr auto bit_cast(From fr) noexcept +{ + return vec_transform(fr)([](auto x) -> To { + static_assert(sizeof(To) == sizeof(decltype(x))); + return __builtin_bit_cast(To, x); + }); +} + +template ()), + MIGRAPHX_REQUIRES(is_trivially_copyable{} and is_trivially_copyable{})> +inline constexpr auto bit_cast(From fr) noexcept { static_assert(sizeof(To) == sizeof(From)); return __builtin_bit_cast(To, fr); diff --git a/test/verify/test_bit_cast.cpp b/test/verify/test_bit_cast.cpp index e83de86241e..e99af258d91 100644 --- a/test/verify/test_bit_cast.cpp +++ b/test/verify/test_bit_cast.cpp @@ -40,16 +40,15 @@ struct test_bit_cast : verify_program> migraphx::shape sb{From, {24, 6}}; auto pa = mm->add_parameter("a", sa); auto pb = mm->add_parameter("b", sb); - auto ia = mm->add_instruction( + mm->add_instruction( migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa); - auto ib = mm->add_instruction( + mm->add_instruction( migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb); - mm->add_instruction(migraphx::make_op("dot"), ia, ib); - return p; }; - std::string section() const { return "gemm"; } }; +template struct test_bit_cast; template struct test_bit_cast; template struct test_bit_cast; +template struct test_bit_cast; From 697d459fd44490e179d4e2c88d69fa02835ba079 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 27 Nov 2024 13:00:29 -0600 Subject: [PATCH 06/23] progress --- src/fp8_ocp_to_nanoo.cpp | 61 +++++++++++++++++++++++ src/include/migraphx/fp8_ocp_to_nanoo.hpp | 48 ++++++++++++++++++ src/simplify_qdq.cpp | 1 + src/targets/gpu/target.cpp | 1 + 4 files changed, 111 insertions(+) create mode 100644 src/fp8_ocp_to_nanoo.cpp create mode 100644 src/include/migraphx/fp8_ocp_to_nanoo.hpp diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_nanoo.cpp new file mode 100644 index 00000000000..db0d4fcf6b5 --- /dev/null +++ b/src/fp8_ocp_to_nanoo.cpp @@ -0,0 +1,61 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct match_fp8ocp_dq_convert_to_fp8nanoo +{ + /** + * Match dequantizelinear instructions. + * Bind the scale and zero_point inputs. + */ + static auto dequantizelinear_op(const std::string& scale, const std::string& zp) + { + return match::name("dequantizelinear")( + match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); + } + + auto matcher() const + { + return dequantizelinear_op("scale", "zp"); + } + + void fp8_ocp_to_nanoo::apply(module_pass_manager& mpm) const + { + // Check if input is a quantizelinear instruction. + // Change how the quantizelinear works if it is by changing the last convert + // to where instructions into a bit_cast instruction. + // + // if input is a parameter just add the where instructions and bit_cast. + // + // Multiply the scale of the dequantizelinear by 2. + } + +}; +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/fp8_ocp_to_nanoo.hpp b/src/include/migraphx/fp8_ocp_to_nanoo.hpp new file mode 100644 index 00000000000..97f49878fbc --- /dev/null +++ b/src/include/migraphx/fp8_ocp_to_nanoo.hpp @@ -0,0 +1,48 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_NANOO_HPP +#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_NANOO_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/** + * Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types intrinsically. + * Conversion uses the same bit representation and adjusts scaling factors at the dequantization. + * Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the floating point representation. + * This pass should run before simplify_qdq so that the scales and zero points calculated by simplify_qdq have the correct adjusted scaling factors + */ +struct MIGRAPHX_EXPORT fp8_ocp_to_nanoo +{ + std::string name() const { return "fp8_ocp_to_nanoo"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 5eab6ab392b..36f3754a311 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -118,6 +118,7 @@ struct match_find_quantizable_ops static auto dequantizelinear_op(const std::string& scale, const std::string& zp) { + // TODO: do we need the condition on arg(0)? return match::name("dequantizelinear")( match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())), match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index a0edac5eb17..3b177ed6ddf 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include From 4b6c8c16c38a073de3474260920c1c70d4c7d53f Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 27 Nov 2024 13:20:12 -0600 Subject: [PATCH 07/23] fix template for gpu bit_cast --- .../kernels/include/migraphx/kernels/bit_cast.hpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index e5ee022f167..e559658a004 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -29,9 +29,8 @@ namespace migraphx { template ()), MIGRAPHX_REQUIRES(is_trivially_copyable{} and is_trivially_copyable{})> -inline constexpr To bit_cast(From fr) noexcept +inline constexpr auto bit_cast(From fr) noexcept { return vec_transform(fr)([](auto x) -> To { static_assert(sizeof(To) == sizeof(decltype(x))); @@ -39,15 +38,5 @@ inline constexpr To bit_cast(From fr) noexcept }); } -template ()), - MIGRAPHX_REQUIRES(is_trivially_copyable{} and is_trivially_copyable{})> -inline constexpr auto bit_cast(From fr) noexcept -{ - static_assert(sizeof(To) == sizeof(From)); - return __builtin_bit_cast(To, fr); -} - } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP From 95a3cd73828ef2bef3207ed2366a25274a85b7f2 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 27 Nov 2024 14:50:53 -0600 Subject: [PATCH 08/23] first implementation --- src/CMakeLists.txt | 1 + src/fp8_ocp_to_nanoo.cpp | 62 ++++++++++++++++++++++++++++++++------ src/targets/gpu/target.cpp | 5 +-- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index eda6ea626e4..10d7193b67c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,6 +57,7 @@ add_library(migraphx file_buffer.cpp fileutils.cpp fp_to_double.cpp + fp8_ocp_to_nanoo.cpp fuse_concat.cpp fuse_pointwise.cpp fuse_pointwise_reduce.cpp diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_nanoo.cpp index db0d4fcf6b5..6839356aebe 100644 --- a/src/fp8_ocp_to_nanoo.cpp +++ b/src/fp8_ocp_to_nanoo.cpp @@ -23,6 +23,9 @@ */ #include #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -45,17 +48,58 @@ struct match_fp8ocp_dq_convert_to_fp8nanoo return dequantizelinear_op("scale", "zp"); } - void fp8_ocp_to_nanoo::apply(module_pass_manager& mpm) const + auto apply(module& m, const match::matcher_result& r) const { - // Check if input is a quantizelinear instruction. - // Change how the quantizelinear works if it is by changing the last convert - // to where instructions into a bit_cast instruction. - // - // if input is a parameter just add the where instructions and bit_cast. - // - // Multiply the scale of the dequantizelinear by 2. - } + auto dq = r.result; + auto x = dq->inputs().front(); + shape::type_t x_type = x->get_shape().type(); + if(x_type != shape::fp8e4m3fn_type) + { + return; + } + auto dq_scale = r.instructions["scale"]; + auto dq_zp = r.instructions["zp"]; + + x = m.insert_instruction(dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); + auto x_lens = x->get_shape().lens(); + + // negative zero in fp8e4m3fn to zero in fp8e4m3fnuz + // a == 0x80 ? 0x0 : a + std::vector bits_0x80 = {fp8::fp8e4m3fnuz(0x80, fp8::fp8e4m3fnuz::from_bits())}; + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x80); + bits_0x80_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), bits_0x80_lit); + auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), x, bits_0x80_lit); + std::vector bits_0x00 = {fp8::fp8e4m3fnuz(0x00, fp8::fp8e4m3fnuz::from_bits())}; + auto zero_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x00); + zero_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), zero_lit); + x = m.insert_instruction(dq, make_op("where"), is_neg_zero, zero_lit, x); + + // positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz + //(a & 0x7f) == 0x7f ? 0x80 : a + std::vector positive_nan_fp8ocp = {fp8::fp8e4m3fnuz(0x7f, fp8::fp8e4m3fnuz::from_bits())}; + auto nan_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, positive_nan_fp8ocp); + nan_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), nan_lit); + auto cond = m.insert_instruction(dq, make_op("bitwise_and"), x, nan_lit); + cond = m.insert_instruction(dq, make_op("equal"), cond, nan_lit); + x = m.insert_instruction(dq, make_op("where"), cond, bits_0x80_lit, x); + + // adj_scale = 2 * scale + auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}}); + two_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); + auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); + m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, dq_zp); + return; + } }; + +void fp8_ocp_to_nanoo::apply(module_pass_manager& mpm) const +{ + module_ref mm = &mpm.get_module(); + match::find_matches(*mm, match_fp8ocp_dq_convert_to_fp8nanoo{}); + mpm.run_pass(migraphx::dead_code_elimination{}); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3b177ed6ddf..12bc5293e8c 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -103,8 +103,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::tuple_type); // whiltelist supported Ops for the FP8 types - // different between fp8e4m3fnuz and OCP types because rocBLAS only has - // support for fp8e4m3fnuz + // different between NANOO and OCP types because rocBLAS only has + // support for NANOO std::set unsupported_fp8e4m3fnuz_ops = {}; if(not gpu::rocblas_fp8_available()) { @@ -171,6 +171,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, eliminate_identity{}, dead_code_elimination{}, + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics(), fp8_ocp_to_nanoo{}), simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, From 98d87608374e21f7c51777494a24b113ad150628 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Dec 2024 15:42:28 -0600 Subject: [PATCH 09/23] progress --- src/fp8_ocp_to_nanoo.cpp | 143 ++++++++++++++++++++++++++----------- src/simplify_qdq.cpp | 2 - src/targets/gpu/target.cpp | 1 + test/simplify_qdq_test.cpp | 66 +---------------- 4 files changed, 104 insertions(+), 108 deletions(-) diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_nanoo.cpp index 6839356aebe..b4029595948 100644 --- a/src/fp8_ocp_to_nanoo.cpp +++ b/src/fp8_ocp_to_nanoo.cpp @@ -29,59 +29,99 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { - -struct match_fp8ocp_dq_convert_to_fp8nanoo +namespace { + +using fp8::fp8e4m3fnuz; + +template +auto skip_post_dq_ops(Ms... ms) +{ + return match::skip(match::name( + "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); +} + +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + +struct match_fp8ocp_convert_to_fp8nanoo { - /** - * Match dequantizelinear instructions. - * Bind the scale and zero_point inputs. - */ + // almost the same as matcher from simplify_qdq + // difference is that broadcasts are not skipped static auto dequantizelinear_op(const std::string& scale, const std::string& zp) { - return match::name("dequantizelinear")( - match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), - match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); + return match::name("dequantizelinear")(match::arg(1)(match::is_constant().bind(scale)), + match::arg(2)(match::is_constant().bind(zp))); } - auto matcher() const { - return dequantizelinear_op("scale", "zp"); + auto dq1 = + match::arg(0)(skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))); + auto dq2 = + match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); + return match::name(get_quantizable_op_names())(dq1, dq2); } - auto apply(module& m, const match::matcher_result& r) const + auto bit_cast_and_handle_specials(module& m, + const instruction_ref dq, + const instruction_ref x, + const instruction_ref bits_0x80_lit, + const instruction_ref bits_0x7f_lit, + const instruction_ref bits_0xff_lit, + const instruction_ref bits_0x00_lit) const { - auto dq = r.result; - auto x = dq->inputs().front(); - shape::type_t x_type = x->get_shape().type(); - if(x_type != shape::fp8e4m3fn_type) - { - return; - } - auto dq_scale = r.instructions["scale"]; - auto dq_zp = r.instructions["zp"]; - - x = m.insert_instruction(dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); auto x_lens = x->get_shape().lens(); - + auto cast_input = m.insert_instruction( + dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); + auto mb_bits_0x80_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit); + auto mb_bits_0x7f_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit); + auto mb_bits_0xff_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit); + auto mb_zero_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit); // negative zero in fp8e4m3fn to zero in fp8e4m3fnuz // a == 0x80 ? 0x0 : a - std::vector bits_0x80 = {fp8::fp8e4m3fnuz(0x80, fp8::fp8e4m3fnuz::from_bits())}; - auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x80); - bits_0x80_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), bits_0x80_lit); - auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), x, bits_0x80_lit); - std::vector bits_0x00 = {fp8::fp8e4m3fnuz(0x00, fp8::fp8e4m3fnuz::from_bits())}; - auto zero_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x00); - zero_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), zero_lit); - x = m.insert_instruction(dq, make_op("where"), is_neg_zero, zero_lit, x); + auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit); + auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input); // positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz - //(a & 0x7f) == 0x7f ? 0x80 : a - std::vector positive_nan_fp8ocp = {fp8::fp8e4m3fnuz(0x7f, fp8::fp8e4m3fnuz::from_bits())}; - auto nan_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, positive_nan_fp8ocp); - nan_lit = m.insert_instruction(dq, make_op("multibroadcast", {{"output_lens", x_lens}}), nan_lit); - auto cond = m.insert_instruction(dq, make_op("bitwise_and"), x, nan_lit); - cond = m.insert_instruction(dq, make_op("equal"), cond, nan_lit); - x = m.insert_instruction(dq, make_op("where"), cond, bits_0x80_lit, x); + // (a == 0x7f or a == 0xff) ? 0x80 : a + auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit); + + auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit); + + auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff); + ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret); + return ret; + } + + auto cast_to_nanoo(module& m, + const instruction_ref dq, + const instruction_ref input, + const instruction_ref dq_scale, + const instruction_ref dq_zp) const + { + auto x = input; + std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x80); + + std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; + auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x7f); + + std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; + auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0xff); + + std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; + auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x00); + + x = bit_cast_and_handle_specials( + m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); + auto adj_zp = bit_cast_and_handle_specials( + m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); // adj_scale = 2 * scale auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}}); @@ -89,16 +129,35 @@ struct match_fp8ocp_dq_convert_to_fp8nanoo dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); - m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, dq_zp); + m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_zp); + } + + auto apply(module& m, const match::matcher_result& r) const + { + auto dq1 = r.instructions["dq1"]; + auto dq2 = r.instructions["dq2"]; + auto scale1 = r.instructions["scale1"]; + auto scale2 = r.instructions["scale2"]; + auto zp1 = r.instructions["zp1"]; + auto zp2 = r.instructions["zp2"]; + + std::set supported_types = {migraphx::shape::fp8e4m3fn_type}; + if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or + not contains(supported_types, dq2->inputs().front()->get_shape().type())) + return; + + cast_to_nanoo(m, dq1, dq1->inputs().front(), scale1, zp1); + cast_to_nanoo(m, dq2, dq2->inputs().front(), scale2, zp2); return; } }; +} // namespace + void fp8_ocp_to_nanoo::apply(module_pass_manager& mpm) const { module_ref mm = &mpm.get_module(); - match::find_matches(*mm, match_fp8ocp_dq_convert_to_fp8nanoo{}); - mpm.run_pass(migraphx::dead_code_elimination{}); + match::find_matches(*mm, match_fp8ocp_convert_to_fp8nanoo{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 36f3754a311..190293feb92 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -118,9 +118,7 @@ struct match_find_quantizable_ops static auto dequantizelinear_op(const std::string& scale, const std::string& zp) { - // TODO: do we need the condition on arg(0)? return match::name("dequantizelinear")( - match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())), match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 12bc5293e8c..3663b2c2199 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -172,6 +172,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti eliminate_identity{}, dead_code_elimination{}, enable_pass(not gpu::gfx_has_fp8ocp_intrinsics(), fp8_ocp_to_nanoo{}), + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics(), dead_code_elimination{}), simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index c3c50cb4172..cef500fbfd3 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -45,75 +46,12 @@ void run_pass(migraphx::module& m) { run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); } + void run_cse(migraphx::module& m) { run_passes(m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); } -migraphx::instruction_ref broadcast_scale(migraphx::module& m, - migraphx::instruction_ref scale, - const std::vector& out_lens, - std::size_t axis) -{ - if(scale->get_shape().lens() == out_lens) - return scale; - - migraphx::instruction_ref scale_mb; - auto scale_lens = scale->get_shape().lens(); - if(scale_lens.front() == 1 and scale_lens.size() == 1) - scale_mb = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale); - else - scale_mb = m.add_instruction( - migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale); - return scale_mb; -} - -migraphx::instruction_ref broadcast_shift(migraphx::module& m, - migraphx::instruction_ref shift, - const std::vector& out_lens) -{ - if(shift->get_shape().lens() == out_lens) - return shift; - return m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), shift); -} - -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - migraphx::instruction_ref shift, - std::size_t q_axis = 1) -{ - auto lens = x->get_shape().lens(); - auto scale_mb = broadcast_scale(m, scale, lens, q_axis); - auto shift_mb = broadcast_shift(m, shift, lens); - return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); -} - -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - std::size_t q_axis = 1) -{ - auto lens = x->get_shape().lens(); - auto scale_mb = broadcast_scale(m, scale, lens, q_axis); - return m.add_instruction(migraphx::make_op(name), x, scale_mb); -} - -migraphx::instruction_ref add_scale_mul(migraphx::module& m, - migraphx::instruction_ref scale1, - migraphx::instruction_ref scale2, - std::size_t axis1, - std::size_t axis2, - const std::vector& out_lens) -{ - auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1); - auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2); - return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb); -} - migraphx::instruction_ref init_zero_point(migraphx::module& m, migraphx::instruction_ref q_ins) { auto zp = m.add_literal(migraphx::literal{migraphx::shape{q_ins->get_shape().type()}, {0}}); From e3d84fccd84ef34a23ac23bcdfc85157ad4c4507 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Dec 2024 17:57:48 -0600 Subject: [PATCH 10/23] Fixes and first test works --- src/fp8_ocp_to_nanoo.cpp | 72 ++++++++++++++++------------ src/include/migraphx/qdq_helpers.hpp | 60 +++++++++++++++++++++++ src/simplify_qdq.cpp | 25 ++-------- test/fp8_ocp_to_nanoo_test.cpp | 46 +++++++++++------- 4 files changed, 135 insertions(+), 68 deletions(-) create mode 100644 src/include/migraphx/qdq_helpers.hpp diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_nanoo.cpp index b4029595948..318163570d7 100644 --- a/src/fp8_ocp_to_nanoo.cpp +++ b/src/fp8_ocp_to_nanoo.cpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -33,28 +34,8 @@ namespace { using fp8::fp8e4m3fnuz; -template -auto skip_post_dq_ops(Ms... ms) -{ - return match::skip(match::name( - "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); -} - -std::unordered_set get_quantizable_op_names() -{ - static std::unordered_set s = {"convolution", "dot"}; - return s; -} - struct match_fp8ocp_convert_to_fp8nanoo { - // almost the same as matcher from simplify_qdq - // difference is that broadcasts are not skipped - static auto dequantizelinear_op(const std::string& scale, const std::string& zp) - { - return match::name("dequantizelinear")(match::arg(1)(match::is_constant().bind(scale)), - match::arg(2)(match::is_constant().bind(zp))); - } auto matcher() const { auto dq1 = @@ -63,14 +44,14 @@ struct match_fp8ocp_convert_to_fp8nanoo match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } - - auto bit_cast_and_handle_specials(module& m, + + static auto bit_cast_and_handle_specials(module& m, const instruction_ref dq, const instruction_ref x, const instruction_ref bits_0x80_lit, const instruction_ref bits_0x7f_lit, const instruction_ref bits_0xff_lit, - const instruction_ref bits_0x00_lit) const + const instruction_ref bits_0x00_lit) { auto x_lens = x->get_shape().lens(); auto cast_input = m.insert_instruction( @@ -99,28 +80,55 @@ struct match_fp8ocp_convert_to_fp8nanoo return ret; } - auto cast_to_nanoo(module& m, + // Add the same broadcast instructions after adjusted scales or + // adjusted zero points from after the originals. Similar to + // propagate_quantized_ins in simplify_qdq. + static auto propagate_broadcasts(module& m, + const instruction_ref adj, + const instruction_ref ori, + const instruction_ref start, + const instruction_ref insert_pt + ) + { + auto prev_ins = start; + std::vector ins_inbetween; + // matcher skips continguous, multi/broadcasts and transposes, collect all those + // instructions + while(prev_ins != ori) + { + ins_inbetween.push_back(prev_ins); + prev_ins = prev_ins->inputs().front(); + } + auto ret = adj; + for(auto ins : reverse_iterator_for(ins_inbetween)) + { + ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret}); + } + return ret; + } + + static auto cast_to_nanoo(module& m, const instruction_ref dq, const instruction_ref input, const instruction_ref dq_scale, - const instruction_ref dq_zp) const + const instruction_ref dq_zp) { auto x = input; std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; - auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x80); + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80); std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; - auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x7f); + auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f); std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; - auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0xff); + auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff); std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; - auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x00); + auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00); x = bit_cast_and_handle_specials( m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); - auto adj_zp = bit_cast_and_handle_specials( + auto adj_dq_zp = bit_cast_and_handle_specials( m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); // adj_scale = 2 * scale @@ -129,7 +137,9 @@ struct match_fp8ocp_convert_to_fp8nanoo dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); - m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_zp); + adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq); + adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq); + m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp); } auto apply(module& m, const match::matcher_result& r) const diff --git a/src/include/migraphx/qdq_helpers.hpp b/src/include/migraphx/qdq_helpers.hpp new file mode 100644 index 00000000000..e28e9625930 --- /dev/null +++ b/src/include/migraphx/qdq_helpers.hpp @@ -0,0 +1,60 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_QDQ_HELPERS_HPP +#define MIGRAPHX_GUARD_RTGLIB_QDQ_HELPERS_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace { + +template +auto skip_post_dq_ops(Ms... ms) +{ + return match::skip(match::name( + "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); +} + +static std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + +static auto dequantizelinear_op(const std::string& scale, const std::string& zp) +{ + return match::name("dequantizelinear")( + match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); +} +} // namespace +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 7dbc3d236ce..0444ff0c64d 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -36,24 +36,12 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { -template -auto skip_post_dq_ops(Ms... ms) -{ - return match::skip(match::name( - "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); -} - -std::unordered_set get_quantizable_op_names() -{ - static std::unordered_set s = {"convolution", "dot"}; - return s; -} - struct match_find_quantizable_ops { static bool @@ -117,13 +105,6 @@ struct match_find_quantizable_ops return qinp; } - static auto dequantizelinear_op(const std::string& scale, const std::string& zp) - { - return match::name("dequantizelinear")( - match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), - match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); - } - auto matcher() const { auto dq1 = @@ -230,7 +211,9 @@ struct match_find_quantizable_ops is_valid_qparam(zp1, out_lens, out_lens.size() - 2) and is_valid_qparam(scale2, out_lens, out_lens.size() - 1) and is_valid_qparam(zp2, out_lens, out_lens.size() - 1))) + { return; + } // This implementation supports both arguments being per-axis affine quantized // In practice, inputs are per-tensor affine and weights are per-axis symmetric @@ -246,7 +229,7 @@ struct match_find_quantizable_ops auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}}); out_zp = m.insert_instruction( qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit); - + auto zp1_bc = m.insert_instruction( qop, qparam_broadcast_op(zp1, arg1_lens, arg1_lens.size() - 2), zp1); auto zp2_bc = m.insert_instruction( diff --git a/test/fp8_ocp_to_nanoo_test.cpp b/test/fp8_ocp_to_nanoo_test.cpp index 9615341c478..bf589b49ac9 100644 --- a/test/fp8_ocp_to_nanoo_test.cpp +++ b/test/fp8_ocp_to_nanoo_test.cpp @@ -107,10 +107,10 @@ auto cast_fp8_helper(migraphx::module& m, std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; - auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x80); - auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x7f); - auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0xff); - auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}}, bits_0x00); + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80); + auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f); + auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff); + auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00); auto cast_input = bit_cast_and_handle_specials( m, dq_input, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); @@ -129,14 +129,15 @@ TEST_CASE(fp8_gemm_conversion) { using migraphx::fp8::fp8e4m3fn; using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 3, 8, 8}; migraphx::module m1; { - auto a = m1.add_parameter("a", {migraphx::shape::float_type, {2, 3, 8, 8}}); - auto b = m1.add_parameter("b", {migraphx::shape::float_type, {2, 3, 8, 8}}); + auto a = m1.add_parameter("a", {migraphx::shape::float_type, data_lens}); + auto b = m1.add_parameter("b", {migraphx::shape::float_type, data_lens}); auto scale = m1.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m1.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}}, data); + auto zero = m1.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m1, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m1, "quantizelinear", b, scale, zero); @@ -152,21 +153,27 @@ TEST_CASE(fp8_gemm_conversion) // expected after fp8_ocp_to_nanoo migraphx::module m2; { - - auto a = m2.add_parameter("a", {migraphx::shape::float_type, {2, 3, 8, 8}}); - auto b = m2.add_parameter("b", {migraphx::shape::float_type, {2, 3, 8, 8}}); + auto a = m2.add_parameter("a", {migraphx::shape::float_type, data_lens}); + auto b = m2.add_parameter("b", {migraphx::shape::float_type, data_lens}); auto scale = m2.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m2.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}}, data); + auto zero = m2.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m2, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m2, "quantizelinear", b, scale, zero); - auto outs_a = cast_fp8_helper(m2, qa, qa->inputs().at(1), qa->inputs().at(2)); - auto da = add_quantize_op(m2, "dequantizelinear", outs_a.at(0), outs_a.at(1), outs_a.at(2)); - auto outs_b = cast_fp8_helper(m2, qb, qb->inputs().at(1), qb->inputs().at(2)); - auto db = add_quantize_op(m2, "dequantizelinear", outs_b.at(0), outs_b.at(1), outs_b.at(2)); + auto outs_a = cast_fp8_helper(m2, qa, scale, zero); + auto adj_a = outs_a.at(0); + auto mb_scales_a = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(1)); + auto mb_zp_a = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(2)); + auto da = m2.add_instruction(make_op("dequantizelinear"), adj_a, mb_scales_a, mb_zp_a); + + auto outs_b = cast_fp8_helper(m2, qb, scale, zero); + auto adj_b = outs_b.at(0); + auto mb_scales_b = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(1)); + auto mb_zp_b = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(2)); + auto db = m2.add_instruction(make_op("dequantizelinear"), adj_b, mb_scales_b, mb_zp_b); auto dot = m2.add_instruction(migraphx::make_op("dot"), da, db); m2.add_return({dot}); @@ -182,7 +189,7 @@ TEST_CASE(fp8_gemm_conversion) auto scale = m3.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m3.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}}, data); + auto zero = m3.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m3, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m3, "quantizelinear", b, scale, zero); @@ -202,7 +209,14 @@ TEST_CASE(fp8_gemm_conversion) } run_simplify_qdq(m1); + //running propagate constant to simplify adjustments to literals + //could pass the test without, but a tedious amount of instructions to rearrange + run_propagate_constant(m1); + run_propagate_constant(m3); + run_cse(m1); + run_cse(m3); EXPECT(m1 == m3); + m1.debug_print(); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From dac07c2212767287c2965575d96cc6c99f395226 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Dec 2024 17:58:06 -0600 Subject: [PATCH 11/23] formatting --- src/fp8_ocp_to_nanoo.cpp | 37 +++++++++++++++++----------------- src/simplify_qdq.cpp | 2 +- test/fp8_ocp_to_nanoo_test.cpp | 29 ++++++++++++++++---------- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_nanoo.cpp index 318163570d7..7ece60d333a 100644 --- a/src/fp8_ocp_to_nanoo.cpp +++ b/src/fp8_ocp_to_nanoo.cpp @@ -44,14 +44,14 @@ struct match_fp8ocp_convert_to_fp8nanoo match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } - + static auto bit_cast_and_handle_specials(module& m, - const instruction_ref dq, - const instruction_ref x, - const instruction_ref bits_0x80_lit, - const instruction_ref bits_0x7f_lit, - const instruction_ref bits_0xff_lit, - const instruction_ref bits_0x00_lit) + const instruction_ref dq, + const instruction_ref x, + const instruction_ref bits_0x80_lit, + const instruction_ref bits_0x7f_lit, + const instruction_ref bits_0xff_lit, + const instruction_ref bits_0x00_lit) { auto x_lens = x->get_shape().lens(); auto cast_input = m.insert_instruction( @@ -80,15 +80,14 @@ struct match_fp8ocp_convert_to_fp8nanoo return ret; } - // Add the same broadcast instructions after adjusted scales or - // adjusted zero points from after the originals. Similar to + // Add the same broadcast instructions after adjusted scales or + // adjusted zero points from after the originals. Similar to // propagate_quantized_ins in simplify_qdq. static auto propagate_broadcasts(module& m, - const instruction_ref adj, - const instruction_ref ori, - const instruction_ref start, - const instruction_ref insert_pt - ) + const instruction_ref adj, + const instruction_ref ori, + const instruction_ref start, + const instruction_ref insert_pt) { auto prev_ins = start; std::vector ins_inbetween; @@ -108,10 +107,10 @@ struct match_fp8ocp_convert_to_fp8nanoo } static auto cast_to_nanoo(module& m, - const instruction_ref dq, - const instruction_ref input, - const instruction_ref dq_scale, - const instruction_ref dq_zp) + const instruction_ref dq, + const instruction_ref input, + const instruction_ref dq_scale, + const instruction_ref dq_zp) { auto x = input; std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; @@ -138,7 +137,7 @@ struct match_fp8ocp_convert_to_fp8nanoo auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq); - adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq); + adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq); m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp); } diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 0444ff0c64d..36493684ef1 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -229,7 +229,7 @@ struct match_find_quantizable_ops auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}}); out_zp = m.insert_instruction( qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit); - + auto zp1_bc = m.insert_instruction( qop, qparam_broadcast_op(zp1, arg1_lens, arg1_lens.size() - 2), zp1); auto zp2_bc = m.insert_instruction( diff --git a/test/fp8_ocp_to_nanoo_test.cpp b/test/fp8_ocp_to_nanoo_test.cpp index bf589b49ac9..8db7ff947f4 100644 --- a/test/fp8_ocp_to_nanoo_test.cpp +++ b/test/fp8_ocp_to_nanoo_test.cpp @@ -137,7 +137,8 @@ TEST_CASE(fp8_gemm_conversion) auto scale = m1.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m1.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + auto zero = + m1.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m1, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m1, "quantizelinear", b, scale, zero); @@ -158,21 +159,26 @@ TEST_CASE(fp8_gemm_conversion) auto scale = m2.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m2.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + auto zero = + m2.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m2, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m2, "quantizelinear", b, scale, zero); auto outs_a = cast_fp8_helper(m2, qa, scale, zero); - auto adj_a = outs_a.at(0); - auto mb_scales_a = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(1)); - auto mb_zp_a = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(2)); + auto adj_a = outs_a.at(0); + auto mb_scales_a = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(1)); + auto mb_zp_a = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(2)); auto da = m2.add_instruction(make_op("dequantizelinear"), adj_a, mb_scales_a, mb_zp_a); auto outs_b = cast_fp8_helper(m2, qb, scale, zero); - auto adj_b = outs_b.at(0); - auto mb_scales_b = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(1)); - auto mb_zp_b = m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(2)); + auto adj_b = outs_b.at(0); + auto mb_scales_b = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(1)); + auto mb_zp_b = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(2)); auto db = m2.add_instruction(make_op("dequantizelinear"), adj_b, mb_scales_b, mb_zp_b); auto dot = m2.add_instruction(migraphx::make_op("dot"), da, db); @@ -189,7 +195,8 @@ TEST_CASE(fp8_gemm_conversion) auto scale = m3.add_literal(0.5f); std::vector data; data.push_back(fp8e4m3fn{0.f}); - auto zero = m3.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + auto zero = + m3.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); auto qa = add_quantize_op(m3, "quantizelinear", a, scale, zero); auto qb = add_quantize_op(m3, "quantizelinear", b, scale, zero); @@ -209,8 +216,8 @@ TEST_CASE(fp8_gemm_conversion) } run_simplify_qdq(m1); - //running propagate constant to simplify adjustments to literals - //could pass the test without, but a tedious amount of instructions to rearrange + // running propagate constant to simplify adjustments to literals + // could pass the test without, but a tedious amount of instructions to rearrange run_propagate_constant(m1); run_propagate_constant(m3); run_cse(m1); From 06b94b86cb0fc8d9de711ef237499de51ad8e8c7 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Dec 2024 19:31:21 -0600 Subject: [PATCH 12/23] Added ref tests --- src/targets/gpu/target.cpp | 4 +- test/ref/fp8_ocp_to_nanoo.cpp | 195 +++++++++++++++++++++++++++++++++- 2 files changed, 196 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 0d28f861b15..61eeb4689d5 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -172,8 +172,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, eliminate_identity{}, dead_code_elimination{}, - enable_pass(not gpu::gfx_has_fp8ocp_intrinsics(), fp8_ocp_to_nanoo{}), - enable_pass(not gpu::gfx_has_fp8ocp_intrinsics(), dead_code_elimination{}), + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_nanoo{}), + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}), simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, diff --git a/test/ref/fp8_ocp_to_nanoo.cpp b/test/ref/fp8_ocp_to_nanoo.cpp index 699c35fc79c..d8ca87b786a 100644 --- a/test/ref/fp8_ocp_to_nanoo.cpp +++ b/test/ref/fp8_ocp_to_nanoo.cpp @@ -27,7 +27,200 @@ #include #include #include +#include +#include +#include +#include #include +#include -TEST_CASE(fp8_ocp_to_nanno) {} +/** + * test that before and after the fp8_ocp_to_nanoo pass + * have equivalent results + */ + +void run_fp8_ocp_to_nanoo(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::fp8_ocp_to_nanoo{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(fp8_ocp_to_nanoo_gemm) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + auto a = m1->add_parameter("a", data_shape); + auto b = m1->add_parameter("b", data_shape); + auto scale = m1->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = m1->add_instruction(migraphx::make_op("dot"), da, db); + m1->add_return({dot}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_nanoo(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector a_data = {20, -100, 100, 0.25}; + std::vector b_data = {28, 0.125, 2.5, 0.25}; + params["a"] = migraphx::argument(data_shape, a_data.data()); + params["b"] = migraphx::argument(data_shape, b_data.data()); + + auto result_1 = p1.eval({params}).back(); + auto result_2 = p2.eval({params}).back(); + std::vector results_vector_1(4); + std::vector results_vector_2(4); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} + +TEST_CASE(fp8_ocp_to_nanoo_gemm_multi_scale) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {3, 3}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + migraphx::shape scales_shape{migraphx::shape::float_type, {3}}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + auto a = m1->add_parameter("a", data_shape); + auto b = m1->add_parameter("b", data_shape); + auto scale1 = m1->add_literal(migraphx::generate_literal(scales_shape, 0)); + auto scale2 = m1->add_literal(0.4f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale1, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale2, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = m1->add_instruction(migraphx::make_op("dot"), da, db); + m1->add_return({dot}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_nanoo(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector a_data = {20, -100, 100, 0.25, 0.3, 3.3, 5.0, -8.0, 63.0}; + std::vector b_data = {28, 0.125, 2.5, 0.25, 0.0582, -187, 0.716, 8.12, 1.87}; + params["a"] = migraphx::argument(data_shape, a_data.data()); + params["b"] = migraphx::argument(data_shape, b_data.data()); + + auto result_1 = p1.eval({params}).back(); + auto result_2 = p2.eval({params}).back(); + std::vector results_vector_1(9); + std::vector results_vector_2(9); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} + +TEST_CASE(fp8_ocp_to_nanoo_conv) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + std::vector a_data = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, + 0.80927712, -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, + 0.67726439, -0.65290606, 0.02345525, -0.33579525, 0.38901961, 1.05473483, + -1.31188095, 1.8963089, -0.07265259, 0.947339, 0.41949373, -0.70814759, + 0.25892952, 1.07311416, 1.2571274, -0.62318051, -0.19951548, -0.94232577, + -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, 0.13900366, + 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, + -0.03024297, 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, + 0.86956722, -0.40457946, 0.46691212, 1.29273605, 0.26464137, 0.22073045, + -1.02178168, 0.22163901, -1.84387338, 0.75522131, -0.45775682, -0.42241111, + -1.50944722, 1.07256448, -1.95876884, -0.28106022, 0.3341668, 2.13129425, + -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, -2.06007552, + 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, + -0.68230027, -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector b_data = { + 2.82721668e-02, 6.44195229e-02, 1.53499246e-02, 1.72468081e-01, -6.33238107e-02, + 9.49496776e-02, 1.40258059e-01, -7.92879611e-02, -1.29301161e-01, 3.11307609e-03, + -1.90624535e-01, 1.13238767e-01, -2.80647576e-02, 3.12882811e-02, -3.52091640e-02, + 3.33581865e-02, 6.43158704e-02, 7.40238279e-02, -1.00106120e-01, -9.56912562e-02, + 1.44342467e-01, 9.40258950e-02, 6.36333972e-02, 1.66158378e-03, -8.91554281e-02, + 2.58734226e-02, 1.70919895e-02, 1.78214177e-01, 8.84564668e-02, 8.98126513e-02, + -1.63809001e-01, 1.37802169e-01, 1.66439757e-01, -1.45631135e-02, 1.88469887e-04, + 4.76950556e-02, -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01, + 1.76608220e-01, -1.50728196e-01, 1.99946314e-02, -5.88052124e-02, 1.31612435e-01, + 1.61106288e-02, -1.35080189e-01, 1.49512306e-01, 3.86456847e-02, 1.29330024e-01, + -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto a = m1->add_literal(migraphx::literal{a_shape, a_data}); + + migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto b = m1->add_literal(migraphx::literal{b_shape, b_data}); + auto scale = m1->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto conv_ins = m1->add_instruction(migraphx::make_op("convolution"), da, db); + m1->add_return({conv_ins}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_nanoo(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + auto result_1 = p1.eval({}).back(); + auto result_2 = p2.eval({}).back(); + std::vector results_vector_1(16); + std::vector results_vector_2(16); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} From df0202eae6185cd88f22167707f08de106950c46 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 14:58:06 -0600 Subject: [PATCH 13/23] Cleanup --- src/CMakeLists.txt | 3 +- ...8_ocp_to_nanoo.cpp => fp8_ocp_to_fnuz.cpp} | 42 ++++++------- ...8_ocp_to_nanoo.hpp => fp8_ocp_to_fnuz.hpp} | 19 +++--- src/include/migraphx/match/dq_helpers.hpp | 62 +++++++++++++++++++ src/include/migraphx/qdq_helpers.hpp | 22 +------ src/qdq_helpers.cpp | 37 +++++++++++ src/simplify_qdq.cpp | 9 +-- src/targets/gpu/target.cpp | 4 +- ...anoo_test.cpp => fp8_ocp_to_fnuz_test.cpp} | 12 ++-- test/include/quantize_helpers.hpp | 54 ++++++++-------- ...8_ocp_to_nanoo.cpp => fp8_ocp_to_fnuz.cpp} | 20 +++--- 11 files changed, 182 insertions(+), 102 deletions(-) rename src/{fp8_ocp_to_nanoo.cpp => fp8_ocp_to_fnuz.cpp} (85%) rename src/include/migraphx/{fp8_ocp_to_nanoo.hpp => fp8_ocp_to_fnuz.hpp} (71%) create mode 100644 src/include/migraphx/match/dq_helpers.hpp create mode 100644 src/qdq_helpers.cpp rename test/{fp8_ocp_to_nanoo_test.cpp => fp8_ocp_to_fnuz_test.cpp} (97%) rename test/ref/{fp8_ocp_to_nanoo.cpp => fp8_ocp_to_fnuz.cpp} (95%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 10d7193b67c..a05a1d7ef21 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,7 +57,7 @@ add_library(migraphx file_buffer.cpp fileutils.cpp fp_to_double.cpp - fp8_ocp_to_nanoo.cpp + fp8_ocp_to_fnuz.cpp fuse_concat.cpp fuse_pointwise.cpp fuse_pointwise_reduce.cpp @@ -89,6 +89,7 @@ add_library(migraphx program.cpp propagate_constant.cpp promote_literals.cpp + qdq_helpers.cpp quantization.cpp quantize_int4.cpp quantize_8bits.cpp diff --git a/src/fp8_ocp_to_nanoo.cpp b/src/fp8_ocp_to_fnuz.cpp similarity index 85% rename from src/fp8_ocp_to_nanoo.cpp rename to src/fp8_ocp_to_fnuz.cpp index 7ece60d333a..9f858c3bdb7 100644 --- a/src/fp8_ocp_to_nanoo.cpp +++ b/src/fp8_ocp_to_fnuz.cpp @@ -21,12 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include +#include #include #include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -34,14 +35,14 @@ namespace { using fp8::fp8e4m3fnuz; -struct match_fp8ocp_convert_to_fp8nanoo +struct match_fp8ocp_convert_to_fp8fnuz { auto matcher() const { - auto dq1 = - match::arg(0)(skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))); - auto dq2 = - match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); + auto dq1 = match::arg(0)( + skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1"))); + auto dq2 = match::arg(1)( + skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } @@ -53,7 +54,7 @@ struct match_fp8ocp_convert_to_fp8nanoo const instruction_ref bits_0xff_lit, const instruction_ref bits_0x00_lit) { - auto x_lens = x->get_shape().lens(); + auto x_lens = x->get_shape().lens(); auto cast_input = m.insert_instruction( dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); auto mb_bits_0x80_lit = m.insert_instruction( @@ -90,27 +91,27 @@ struct match_fp8ocp_convert_to_fp8nanoo const instruction_ref insert_pt) { auto prev_ins = start; - std::vector ins_inbetween; + std::vector ins_between; // matcher skips continguous, multi/broadcasts and transposes, collect all those // instructions while(prev_ins != ori) { - ins_inbetween.push_back(prev_ins); + ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } auto ret = adj; - for(auto ins : reverse_iterator_for(ins_inbetween)) + for(auto ins : reverse_iterator_for(ins_between)) { ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret}); } return ret; } - static auto cast_to_nanoo(module& m, - const instruction_ref dq, - const instruction_ref input, - const instruction_ref dq_scale, - const instruction_ref dq_zp) + static auto cast_to_fnuz(module& m, + const instruction_ref dq, + const instruction_ref input, + const instruction_ref dq_scale, + const instruction_ref dq_zp) { auto x = input; std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; @@ -132,7 +133,7 @@ struct match_fp8ocp_convert_to_fp8nanoo // adj_scale = 2 * scale auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}}); - two_lit = m.insert_instruction( + two_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); @@ -155,18 +156,17 @@ struct match_fp8ocp_convert_to_fp8nanoo not contains(supported_types, dq2->inputs().front()->get_shape().type())) return; - cast_to_nanoo(m, dq1, dq1->inputs().front(), scale1, zp1); - cast_to_nanoo(m, dq2, dq2->inputs().front(), scale2, zp2); - return; + cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1); + cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2); } }; } // namespace -void fp8_ocp_to_nanoo::apply(module_pass_manager& mpm) const +void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const { module_ref mm = &mpm.get_module(); - match::find_matches(*mm, match_fp8ocp_convert_to_fp8nanoo{}); + match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/fp8_ocp_to_nanoo.hpp b/src/include/migraphx/fp8_ocp_to_fnuz.hpp similarity index 71% rename from src/include/migraphx/fp8_ocp_to_nanoo.hpp rename to src/include/migraphx/fp8_ocp_to_fnuz.hpp index 97f49878fbc..19e4a1cda02 100644 --- a/src/include/migraphx/fp8_ocp_to_nanoo.hpp +++ b/src/include/migraphx/fp8_ocp_to_fnuz.hpp @@ -21,24 +21,25 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_NANOO_HPP -#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_NANOO_HPP +#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP +#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { - + /** - * Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types intrinsically. - * Conversion uses the same bit representation and adjusts scaling factors at the dequantization. - * Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the floating point representation. - * This pass should run before simplify_qdq so that the scales and zero points calculated by simplify_qdq have the correct adjusted scaling factors + * Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types + * intrinsically. Conversion uses the same bit representation and adjusts scaling factors at the + * dequantization. Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the + * floating point representation. This pass should run before simplify_qdq so that the scales and + * zero points calculated by simplify_qdq have the correct adjusted scaling factors */ -struct MIGRAPHX_EXPORT fp8_ocp_to_nanoo +struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz { - std::string name() const { return "fp8_ocp_to_nanoo"; } + std::string name() const { return "fp8_ocp_to_fnuz"; } void apply(module_pass_manager& mpm) const; }; diff --git a/src/include/migraphx/match/dq_helpers.hpp b/src/include/migraphx/match/dq_helpers.hpp new file mode 100644 index 00000000000..cdb40ae977e --- /dev/null +++ b/src/include/migraphx/match/dq_helpers.hpp @@ -0,0 +1,62 @@ + +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP +#define MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +/** + * Find dequantizelinear (DQ) instruction with constant scale and zero point input + * while skipping broadcast instructions between DQ and scale/zero point. Used + * in simplify_qdq and fp8_ocp_to_fnuz. + */ +inline auto dequantizelinear_op(const std::string& scale, const std::string& zp) +{ + return match::name("dequantizelinear")( + match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); +} + +/** + * Skip certain operators after DQ instruction. + * Used in simplify_qdq and fp8_ocp_to_fnuz. + */ +template +auto skip_post_dq_ops(Ms... ms) +{ + return match::skip(match::name( + "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); +} + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/qdq_helpers.hpp b/src/include/migraphx/qdq_helpers.hpp index e28e9625930..8dce790e42f 100644 --- a/src/include/migraphx/qdq_helpers.hpp +++ b/src/include/migraphx/qdq_helpers.hpp @@ -32,28 +32,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -namespace { +std::unordered_set get_quantizable_op_names(); -template -auto skip_post_dq_ops(Ms... ms) -{ - return match::skip(match::name( - "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); -} - -static std::unordered_set get_quantizable_op_names() -{ - static std::unordered_set s = {"convolution", "dot"}; - return s; -} - -static auto dequantizelinear_op(const std::string& scale, const std::string& zp) -{ - return match::name("dequantizelinear")( - match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), - match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); -} -} // namespace } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/qdq_helpers.cpp b/src/qdq_helpers.cpp new file mode 100644 index 00000000000..af397a4f3e5 --- /dev/null +++ b/src/qdq_helpers.cpp @@ -0,0 +1,37 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 36493684ef1..2877b1e4bfe 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -107,10 +108,10 @@ struct match_find_quantizable_ops auto matcher() const { - auto dq1 = - match::arg(0)(skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))); - auto dq2 = - match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); + auto dq1 = match::arg(0)( + skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1"))); + auto dq2 = match::arg(1)( + skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index c71a348aa30..85033f7557a 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include #include @@ -180,7 +180,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, eliminate_identity{}, dead_code_elimination{}, - enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_nanoo{}), + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}), enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}), simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), diff --git a/test/fp8_ocp_to_nanoo_test.cpp b/test/fp8_ocp_to_fnuz_test.cpp similarity index 97% rename from test/fp8_ocp_to_nanoo_test.cpp rename to test/fp8_ocp_to_fnuz_test.cpp index 8db7ff947f4..58abb18bddc 100644 --- a/test/fp8_ocp_to_nanoo_test.cpp +++ b/test/fp8_ocp_to_fnuz_test.cpp @@ -27,12 +27,12 @@ #include #include #include -#include +#include #include #include #include #include -#include +#include #include #include @@ -41,9 +41,9 @@ using migraphx::make_op; using migraphx::shape; using migraphx::fp8::fp8e4m3fnuz; -void run_fp8_ocp_to_nanoo(migraphx::module& m) +void run_fp8_ocp_to_fnuz(migraphx::module& m) { - migraphx::run_passes(m, {migraphx::fp8_ocp_to_nanoo{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::fp8_ocp_to_fnuz{}, migraphx::dead_code_elimination{}}); } void run_simplify_qdq(migraphx::module& m) @@ -149,9 +149,9 @@ TEST_CASE(fp8_gemm_conversion) auto dot = m1.add_instruction(migraphx::make_op("dot"), da, db); m1.add_return({dot}); } - run_fp8_ocp_to_nanoo(m1); + run_fp8_ocp_to_fnuz(m1); - // expected after fp8_ocp_to_nanoo + // expected after fp8_ocp_to_fnuz migraphx::module m2; { auto a = m2.add_parameter("a", {migraphx::shape::float_type, data_lens}); diff --git a/test/include/quantize_helpers.hpp b/test/include/quantize_helpers.hpp index 7a41484c0e7..43bde67199e 100644 --- a/test/include/quantize_helpers.hpp +++ b/test/include/quantize_helpers.hpp @@ -1,17 +1,15 @@ -#include #include -#include +#include +#include #include -#include - #ifndef MIGRAPHX_GUARD_TEST_INCLUDE_QUANTIZE_HELPERS_HPP #define MIGRAPHX_GUARD_TEST_INCLUDE_QUANTIZE_HELPERS_HPP -migraphx::instruction_ref broadcast_scale(migraphx::module& m, - migraphx::instruction_ref scale, - const std::vector& out_lens, - std::size_t axis) +inline migraphx::instruction_ref broadcast_scale(migraphx::module& m, + migraphx::instruction_ref scale, + const std::vector& out_lens, + std::size_t axis) { if(scale->get_shape().lens() == out_lens) return scale; @@ -27,33 +25,33 @@ migraphx::instruction_ref broadcast_scale(migraphx::module& m, return scale_mb; } -migraphx::instruction_ref broadcast_shift(migraphx::module& m, - migraphx::instruction_ref shift, - const std::vector& out_lens) +inline migraphx::instruction_ref broadcast_shift(migraphx::module& m, + migraphx::instruction_ref shift, + const std::vector& out_lens) { if(shift->get_shape().lens() == out_lens) return shift; return m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), shift); } -migraphx::instruction_ref add_scale_mul(migraphx::module& m, - migraphx::instruction_ref scale1, - migraphx::instruction_ref scale2, - std::size_t axis1, - std::size_t axis2, - const std::vector& out_lens) +inline migraphx::instruction_ref add_scale_mul(migraphx::module& m, + migraphx::instruction_ref scale1, + migraphx::instruction_ref scale2, + std::size_t axis1, + std::size_t axis2, + const std::vector& out_lens) { auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1); auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2); return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb); } -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - migraphx::instruction_ref shift, - std::size_t q_axis = 1) +inline migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + migraphx::instruction_ref shift, + std::size_t q_axis = 1) { auto lens = x->get_shape().lens(); auto scale_mb = broadcast_scale(m, scale, lens, q_axis); @@ -61,11 +59,11 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); } -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - std::size_t q_axis = 1) +inline migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + std::size_t q_axis = 1) { auto lens = x->get_shape().lens(); auto scale_mb = broadcast_scale(m, scale, lens, q_axis); diff --git a/test/ref/fp8_ocp_to_nanoo.cpp b/test/ref/fp8_ocp_to_fnuz.cpp similarity index 95% rename from test/ref/fp8_ocp_to_nanoo.cpp rename to test/ref/fp8_ocp_to_fnuz.cpp index d8ca87b786a..d0a00df9565 100644 --- a/test/ref/fp8_ocp_to_nanoo.cpp +++ b/test/ref/fp8_ocp_to_fnuz.cpp @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -36,16 +36,16 @@ #include /** - * test that before and after the fp8_ocp_to_nanoo pass + * test that before and after the fp8_ocp_to_fnuz pass * have equivalent results */ -void run_fp8_ocp_to_nanoo(migraphx::module& m) +void run_fp8_ocp_to_fnuz(migraphx::module& m) { - migraphx::run_passes(m, {migraphx::fp8_ocp_to_nanoo{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::fp8_ocp_to_fnuz{}, migraphx::dead_code_elimination{}}); } -TEST_CASE(fp8_ocp_to_nanoo_gemm) +TEST_CASE(fp8_ocp_to_fnuz_gemm) { using migraphx::fp8::fp8e4m3fn; using migraphx::fp8::fp8e4m3fnuz; @@ -75,7 +75,7 @@ TEST_CASE(fp8_ocp_to_nanoo_gemm) migraphx::program p2 = p1; migraphx::module* m2 = p2.get_main_module(); - run_fp8_ocp_to_nanoo(*m2); + run_fp8_ocp_to_fnuz(*m2); p1.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref")); @@ -95,7 +95,7 @@ TEST_CASE(fp8_ocp_to_nanoo_gemm) EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); } -TEST_CASE(fp8_ocp_to_nanoo_gemm_multi_scale) +TEST_CASE(fp8_ocp_to_fnuz_gemm_multi_scale) { using migraphx::fp8::fp8e4m3fn; using migraphx::fp8::fp8e4m3fnuz; @@ -127,7 +127,7 @@ TEST_CASE(fp8_ocp_to_nanoo_gemm_multi_scale) migraphx::program p2 = p1; migraphx::module* m2 = p2.get_main_module(); - run_fp8_ocp_to_nanoo(*m2); + run_fp8_ocp_to_fnuz(*m2); p1.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref")); @@ -147,7 +147,7 @@ TEST_CASE(fp8_ocp_to_nanoo_gemm_multi_scale) EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); } -TEST_CASE(fp8_ocp_to_nanoo_conv) +TEST_CASE(fp8_ocp_to_fnuz_conv) { using migraphx::fp8::fp8e4m3fn; using migraphx::fp8::fp8e4m3fnuz; @@ -211,7 +211,7 @@ TEST_CASE(fp8_ocp_to_nanoo_conv) migraphx::program p2 = p1; migraphx::module* m2 = p2.get_main_module(); - run_fp8_ocp_to_nanoo(*m2); + run_fp8_ocp_to_fnuz(*m2); p1.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref")); From 0318f328d70a9c0f0c5b629205cda00de77221b4 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 15:18:23 -0600 Subject: [PATCH 14/23] initial --- src/quantization.cpp | 18 +----------------- .../gpu/include/migraphx/gpu/context.hpp | 1 - test/gpu/context_serialize.cpp | 3 --- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index 7e02ae66685..73dc7022ad7 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -198,23 +198,7 @@ void quantize_fp8(program& prog, const target& t, const std::vectorname()); } } - auto gfx_has_fp8fnuz = [&]() { - if(t.name() == "gpu") - { - auto context_value = t.get_context().to_value(); - auto device_name = context_value["gfx_name"].to(); - return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); - } - return false; - }; - if(gfx_has_fp8fnuz()) - { - quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names); - } - else - { - quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, supported_ins_names); - } + quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, supported_ins_names); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 960e6ed55a0..8f9dedd6798 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -311,7 +311,6 @@ struct context value result; result["events"] = events.size(); result["streams"] = current_device->nstreams(); - result["gfx_name"] = get_current_device().get_gfx_name(); return result; } diff --git a/test/gpu/context_serialize.cpp b/test/gpu/context_serialize.cpp index d0c6072180f..57f7f974ba5 100644 --- a/test/gpu/context_serialize.cpp +++ b/test/gpu/context_serialize.cpp @@ -41,9 +41,6 @@ TEST_CASE(gpu_context_serialize) EXPECT(v.contains("streams")); EXPECT(v.at("streams").without_key().to() == 3); - EXPECT(v.contains("gfx_name")); - EXPECT(not v.at("gfx_name").without_key().to().empty()); - migraphx::gpu::context g_ctx; g_ctx.from_value(v); From 3b482424074c5e0a781e0b2d4c47b59070dbc62d Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 16:17:31 -0600 Subject: [PATCH 15/23] temporary --- src/quantization.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/quantization.cpp b/src/quantization.cpp index 73dc7022ad7..8e19328f3bf 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -159,6 +160,8 @@ void quantize_8bits(program& prog, run_passes(prog, {quantize_8bits_pass{precision, *quant_8bit_params}, + fp8_ocp_to_fnuz{}, + dead_code_elimination{}, simplify_qdq{}, optimize_module{}, dead_code_elimination{}}, From b373d10ec852dc1d656ff8c50222b3605f0e6d4a Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 16:25:24 -0600 Subject: [PATCH 16/23] disable simpilify_qdq in quantization_8bits --- src/quantization.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index 8e19328f3bf..e212b28d393 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -160,9 +159,7 @@ void quantize_8bits(program& prog, run_passes(prog, {quantize_8bits_pass{precision, *quant_8bit_params}, - fp8_ocp_to_fnuz{}, - dead_code_elimination{}, - simplify_qdq{}, + //simplify_qdq{}, optimize_module{}, dead_code_elimination{}}, quant_tracer()); From 28aab5fe000a1796ebb14e7d2ee893def6d47bc3 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 16:46:49 -0600 Subject: [PATCH 17/23] revert --- src/quantization.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index e212b28d393..73dc7022ad7 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -159,7 +159,7 @@ void quantize_8bits(program& prog, run_passes(prog, {quantize_8bits_pass{precision, *quant_8bit_params}, - //simplify_qdq{}, + simplify_qdq{}, optimize_module{}, dead_code_elimination{}}, quant_tracer()); From 7e0142f0526fbd0bd1378cfb90297e2d3fe108df Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Dec 2024 17:29:23 -0600 Subject: [PATCH 18/23] disable extra passes after quantize_8bits --- src/quantization.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index 73dc7022ad7..bb0836d5273 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -158,10 +158,7 @@ void quantize_8bits(program& prog, } run_passes(prog, - {quantize_8bits_pass{precision, *quant_8bit_params}, - simplify_qdq{}, - optimize_module{}, - dead_code_elimination{}}, + {quantize_8bits_pass{precision, *quant_8bit_params}, dead_code_elimination{}}, quant_tracer()); } From 0a4d6bf7c3b61c1943b82eca267de43554120ef0 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 11 Dec 2024 09:15:35 -0600 Subject: [PATCH 19/23] add verify test --- test/verify/test_fp8_ocp_to_fnuz_gemm.cpp | 60 +++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 test/verify/test_fp8_ocp_to_fnuz_gemm.cpp diff --git a/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp b/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp new file mode 100644 index 00000000000..88fc9828034 --- /dev/null +++ b/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp @@ -0,0 +1,60 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_fp8_ocp_to_fnuz_gemm : verify_program +{ + using fp8e4m3fn = migraphx::fp8::fp8e4m3fn; + using fp8e4m3fnuz = migraphx::fp8::fp8e4m3fnuz; + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + auto a = mm->add_parameter("a", data_shape); + auto b = mm->add_parameter("b", data_shape); + auto scale = mm->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + mm->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*mm, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*mm, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*mm, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*mm, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = mm->add_instruction(migraphx::make_op("dot"), da, db); + mm->add_return({dot}); + return p; + } + std::string section() const { return "gemm"; } +}; From c94c52009fca4f8465c25d1e67388aa747d974d8 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 13 Dec 2024 14:06:55 -0600 Subject: [PATCH 20/23] Fix bug with __builtin_nan(string) it needs a string input --- src/cpp_generator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cpp_generator.cpp b/src/cpp_generator.cpp index 433ccaadb5b..292d42e3e09 100644 --- a/src/cpp_generator.cpp +++ b/src/cpp_generator.cpp @@ -220,8 +220,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m, if(x < 0) string_literal = "-__builtin_huge_val()"; } - else if(std::isnan(static_cast(x))) - string_literal = "__builtin_nan()"; + else if(std::isnan(x)) + string_literal = "__builtin_nan(\"0\")"; else string_literal = ins->get_literal().to_string(); }); From 0cddfbfaf37ed45ac872aa641bac237084cbdffa Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 13 Dec 2024 14:11:23 -0600 Subject: [PATCH 21/23] separate quantizable ops --- src/CMakeLists.txt | 1 - src/fp8_ocp_to_fnuz.cpp | 7 ++++- src/include/migraphx/qdq_helpers.hpp | 40 ---------------------------- src/qdq_helpers.cpp | 37 ------------------------- src/simplify_qdq.cpp | 7 ++++- 5 files changed, 12 insertions(+), 80 deletions(-) delete mode 100644 src/include/migraphx/qdq_helpers.hpp delete mode 100644 src/qdq_helpers.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a05a1d7ef21..613b21e662e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,7 +89,6 @@ add_library(migraphx program.cpp propagate_constant.cpp promote_literals.cpp - qdq_helpers.cpp quantization.cpp quantize_int4.cpp quantize_8bits.cpp diff --git a/src/fp8_ocp_to_fnuz.cpp b/src/fp8_ocp_to_fnuz.cpp index 9f858c3bdb7..305ca6058f1 100644 --- a/src/fp8_ocp_to_fnuz.cpp +++ b/src/fp8_ocp_to_fnuz.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include namespace migraphx { @@ -35,6 +34,12 @@ namespace { using fp8::fp8e4m3fnuz; +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + struct match_fp8ocp_convert_to_fp8fnuz { auto matcher() const diff --git a/src/include/migraphx/qdq_helpers.hpp b/src/include/migraphx/qdq_helpers.hpp deleted file mode 100644 index 8dce790e42f..00000000000 --- a/src/include/migraphx/qdq_helpers.hpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#ifndef MIGRAPHX_GUARD_RTGLIB_QDQ_HELPERS_HPP -#define MIGRAPHX_GUARD_RTGLIB_QDQ_HELPERS_HPP - -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -std::unordered_set get_quantizable_op_names(); - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/qdq_helpers.cpp b/src/qdq_helpers.cpp deleted file mode 100644 index af397a4f3e5..00000000000 --- a/src/qdq_helpers.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -std::unordered_set get_quantizable_op_names() -{ - static std::unordered_set s = {"convolution", "dot"}; - return s; -} - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 2877b1e4bfe..bd21564b618 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -36,13 +36,18 @@ #include #include #include -#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + struct match_find_quantizable_ops { static bool From 083a9dabed5703cfe40daadc96c4e8dc9abef8ec Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 20 Jan 2025 14:09:10 -0600 Subject: [PATCH 22/23] Fix the gpu context test --- test/gpu/context_serialize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/context_serialize.cpp b/test/gpu/context_serialize.cpp index 57f7f974ba5..f620d7add36 100644 --- a/test/gpu/context_serialize.cpp +++ b/test/gpu/context_serialize.cpp @@ -33,7 +33,7 @@ TEST_CASE(gpu_context_serialize) migraphx::context ctx = migraphx::gpu::context{0, 3}; auto v = ctx.to_value(); - EXPECT(v.size() == 3); + EXPECT(v.size() == 2); EXPECT(v.contains("events")); EXPECT(v.at("events").without_key().to() == 0); From 9fc2e97eae88599378d6e4cd704877695700bbb3 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 21 Jan 2025 10:25:33 -0600 Subject: [PATCH 23/23] licensing update --- src/quantization.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/context.hpp | 2 +- test/gpu/context_serialize.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index 130682a383e..4c96d233fec 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 8f9dedd6798..7a1a7d34be3 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/gpu/context_serialize.cpp b/test/gpu/context_serialize.cpp index f620d7add36..845f594a8f1 100644 --- a/test/gpu/context_serialize.cpp +++ b/test/gpu/context_serialize.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal