From 3112e554df4e94b9b64042f79b6528c3c64e446f Mon Sep 17 00:00:00 2001 From: amberhassaan Date: Wed, 7 Feb 2024 20:20:13 -0500 Subject: [PATCH] Update CK-based 2d/3d convolution solvers to support nchw/ncdhw layout (#2429) --- src/buffer_info.cpp | 20 +- src/conv/invokers/impl_gemm_dynamic.cpp | 10 +- src/include/miopen/batched_transpose_sol.hpp | 61 ++ src/include/miopen/buffer_info.hpp | 8 +- src/include/miopen/conv/tensors.hpp | 2 + src/include/miopen/errors.hpp | 12 + src/include/miopen/handle.hpp | 1 + src/include/miopen/solver.hpp | 24 + .../miopen/solver/implicitgemm_ck_util.hpp | 589 +++++++++++++++++- src/include/miopen/util_sol.hpp | 65 -- .../conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp | 6 +- .../conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp | 2 +- .../conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp | 6 +- .../conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp | 10 +- ...ip_implicit_gemm_3d_grouped_bwd_xdlops.cpp | 128 ++-- ...ip_implicit_gemm_3d_grouped_fwd_xdlops.cpp | 90 ++- ...ip_implicit_gemm_3d_grouped_wrw_xdlops.cpp | 119 ++-- ...conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 13 +- ..._hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp | 6 +- ..._hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp | 6 +- ..._hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp | 6 +- .../conv_hip_implicit_gemm_fwd_xdlops.cpp | 14 +- ...v_hip_implicit_gemm_grouped_bwd_xdlops.cpp | 102 +-- ...v_hip_implicit_gemm_grouped_fwd_xdlops.cpp | 56 +- ...v_hip_implicit_gemm_grouped_wrw_xdlops.cpp | 98 +-- test/gpu_nchw_nhwc_transpose.cpp | 1 - test/gtest/group_conv.hpp | 24 +- 27 files changed, 1084 insertions(+), 395 deletions(-) delete mode 100644 src/include/miopen/util_sol.hpp diff --git a/src/buffer_info.cpp b/src/buffer_info.cpp index fb432e0708..ef567da1a2 100644 --- a/src/buffer_info.cpp +++ b/src/buffer_info.cpp @@ -148,29 +148,31 @@ BuffInfo::BuffInfo(MemLayout_t layout, int nk, int c, int h, int w, int g, int _ } } -MultiBufferWorkspaceTraits::MultiBufferWorkspaceTraits(std::initializer_list v_size_, - size_t alignment_) - : v_size(v_size_), alignment(alignment_) +MultiBufferWorkspaceTraits::MultiBufferWorkspaceTraits(std::initializer_list v_size_) { + + assert(v_size_.size() > 0); size_t each_offset = 0; v_offset.push_back(each_offset); - for(auto each_size : v_size) + for(auto each_size : v_size_) { - size_t padding = (alignment - (each_size % alignment)) % alignment; - each_offset += each_size + padding; + auto padded_size = (each_size + max_padding) & (~max_padding); + each_offset += padded_size; v_offset.push_back(each_offset); } } size_t MultiBufferWorkspaceTraits::GetSize() const { - return (&v_offset.back())[-1] + v_size.back(); + assert(v_offset.size() > 1); + // last location contains the sum of all padded sizes + return v_offset.back(); } size_t MultiBufferWorkspaceTraits::GetOffset(size_t index) const { - if(index >= v_offset.size()) - MIOPEN_THROW("index given overflows"); + // last location contains the sum of all padded sizes + MIOPEN_THROW_IF(index >= (v_offset.size() - 1), "index given overflows"); return v_offset[index]; } diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 1abbf0ce18..80a8cb8bc3 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include namespace miopen { @@ -561,8 +561,6 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( int trans_weight_idx = -1; int trans_output_idx = -1; - constexpr size_t buf_alignment = 256; - if(is_nchw) { 
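// A standalone sketch (not part of this patch) of the round-up arithmetic the
// new MultiBufferWorkspaceTraits constructor in the buffer_info.cpp hunk above
// relies on. With max_padding = 255 (one less than a power of two),
// (size + max_padding) & ~max_padding rounds each sub-buffer size up to the
// next 256-byte boundary, and v_offset gains a trailing sentinel equal to the
// total padded size, which is what GetSize() now returns. The round_up_256
// lambda and the sample sizes below are illustrative, not MIOpen API.

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

int main()
{
    static constexpr std::size_t max_padding = 255; // 1 less than a power of 2
    static_assert((max_padding & (max_padding + 1)) == 0, "must be 2^n - 1");

    auto round_up_256 = [](std::size_t sz) { return (sz + max_padding) & ~max_padding; };

    // Lay out two sub-buffers of 100 and 300 bytes as the new constructor does.
    std::vector<std::size_t> v_offset{0};
    for(std::size_t sz : {std::size_t{100}, std::size_t{300}})
        v_offset.push_back(v_offset.back() + round_up_256(sz));

    assert(v_offset[0] == 0);   // GetOffset(0)
    assert(v_offset[1] == 256); // GetOffset(1): 100 rounded up to 256
    assert(v_offset[2] == 768); // trailing sentinel == GetSize(): 256 + 512
    return 0;
}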
TransposeSolutionDefault2Nhwc trans_input(ctx, problem.GetInDataType(), n, c, hi, wi); @@ -601,7 +599,7 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( const size_t cast_size = need_cast ? miopen::GetTypeSize(miopenFloat) * n * k * ho * wo : 0; MultiBufferWorkspaceTraits wt( - {trans_input_size, trans_weight_size, trans_output_size, cast_size}, buf_alignment); + {trans_input_size, trans_weight_size, trans_output_size, cast_size}); trans_input_offset = wt.GetOffset(0); trans_weight_offset = wt.GetOffset(1); @@ -879,8 +877,6 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( int trans_weight_idx = -1; int trans_output_idx = -1; - constexpr size_t buf_alignment = 256; - if(is_nchw) { TransposeSolutionNhwc2Default trans_input(ctx, problem.GetOutDataType(), n, c, hi, wi); @@ -919,7 +915,7 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( const size_t cast_size = need_cast ? miopen::GetTypeSize(miopenFloat) * n * c * hi * wi : 0; MultiBufferWorkspaceTraits wt( - {trans_input_size, trans_weight_size, trans_output_size, cast_size}, buf_alignment); + {trans_input_size, trans_weight_size, trans_output_size, cast_size}); trans_input_offset = wt.GetOffset(0); trans_weight_offset = wt.GetOffset(1); diff --git a/src/include/miopen/batched_transpose_sol.hpp b/src/include/miopen/batched_transpose_sol.hpp index dedbf4f73e..e117afb808 100644 --- a/src/include/miopen/batched_transpose_sol.hpp +++ b/src/include/miopen/batched_transpose_sol.hpp @@ -27,6 +27,7 @@ #define GUARD_MIOPEN_BATCHED_TRANSPOSE_SOL_HPP #include +#include #include #include #include @@ -66,6 +67,66 @@ struct BatchedTransposeSolution BatchedTransposeParam kernel_param_heuristic; }; +struct TransposeSolutionDefault2Nhwc : public BatchedTransposeSolution +{ + TransposeSolutionDefault2Nhwc(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, c_, h_ * w_) + { + MIOPEN_THROW_IF(size_t(h_ * w_) != (size_t(h_) * size_t(w_)), "integer overflow"); + } +}; + +struct TransposeSolutionNhwc2Default : public BatchedTransposeSolution +{ + TransposeSolutionNhwc2Default(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, h_ * w_, c_) + { + MIOPEN_THROW_IF(size_t(h_ * w_) != (size_t(h_) * size_t(w_)), "integer overflow"); + } +}; + +struct TransposeSolutionDefault2Ndhwc : public BatchedTransposeSolution +{ + TransposeSolutionDefault2Ndhwc(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t d_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, c_, d_ * h_ * w_) + { + MIOPEN_THROW_IF(size_t(d_ * h_ * w_) != (size_t(d_) * size_t(h_) * size_t(w_)), + "integer overflow"); + } +}; + +struct TransposeSolutionNdhwc2Default : public BatchedTransposeSolution +{ + TransposeSolutionNdhwc2Default(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t d_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, d_ * h_ * w_, c_) + { + MIOPEN_THROW_IF(size_t(d_ * h_ * w_) != (size_t(d_) * size_t(h_) * size_t(w_)), + "integer overflow"); + } +}; + } // namespace miopen #endif diff --git a/src/include/miopen/buffer_info.hpp b/src/include/miopen/buffer_info.hpp index 95b98c49c6..33e33339a5 100644 --- 
a/src/include/miopen/buffer_info.hpp +++ b/src/include/miopen/buffer_info.hpp @@ -310,13 +310,15 @@ struct WinogradBufferInfo struct MultiBufferWorkspaceTraits { - MultiBufferWorkspaceTraits(std::initializer_list v_size_, size_t alignment_); + MultiBufferWorkspaceTraits(std::initializer_list v_size_); size_t GetSize() const; size_t GetOffset(size_t index) const; - std::vector v_size; std::vector v_offset; - size_t alignment; + // aligning and padding to 256 byte boundary + constexpr static size_t max_padding = 255ull; + static_assert((max_padding & (max_padding + 1)) == 0, + "max_padding should be 1 less than a power of 2"); }; } // namespace miopen diff --git a/src/include/miopen/conv/tensors.hpp b/src/include/miopen/conv/tensors.hpp index cc96d81f48..ea01c3b0e1 100644 --- a/src/include/miopen/conv/tensors.hpp +++ b/src/include/miopen/conv/tensors.hpp @@ -114,6 +114,8 @@ struct ConvDataTensors out(tensors.dx) { } + + operator ConvTensors() const { return {inDesc, in, wDesc, w, outDesc, out}; } }; struct ConvWrwTensors diff --git a/src/include/miopen/errors.hpp b/src/include/miopen/errors.hpp index 253ad6ec85..a82db5a0f4 100644 --- a/src/include/miopen/errors.hpp +++ b/src/include/miopen/errors.hpp @@ -69,6 +69,18 @@ template miopen::MIOpenThrow(__FILE__, __LINE__, __VA_ARGS__); \ } while(false) +#define MIOPEN_THROW_IF(condition, msg) \ + do \ + { \ + if((condition)) \ + { \ + miopen::MIOpenThrow(__FILE__, \ + __LINE__, \ + miopenStatusInternalError, \ + std::string(msg) + ", failed condition: " #condition); \ + } \ + } while(false) + #define MIOPEN_THROW_CL_STATUS(...) \ MIOPEN_THROW(miopenStatusUnknownError, miopen::OpenCLErrorMessage(__VA_ARGS__)) #define MIOPEN_THROW_HIP_STATUS(...) \ diff --git a/src/include/miopen/handle.hpp b/src/include/miopen/handle.hpp index 39f499b351..5b82e88d3d 100644 --- a/src/include/miopen/handle.hpp +++ b/src/include/miopen/handle.hpp @@ -185,6 +185,7 @@ struct MIOPEN_EXPORT Handle : miopenHandle template Allocator::ManageDataPtr Write(const Container& c) { + assert(!c.empty()); using type = typename Container::value_type; auto buf = this->Create(c.size()); return std::move( diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp index 12e2265e08..fd2fcfb0f9 100644 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -4545,6 +4545,10 @@ struct ConvHipImplicitGemmGroupFwdXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; @@ -4617,6 +4621,10 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; @@ -4694,6 +4702,10 @@ struct ConvHipImplicitGemm3DGroupWrwXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; @@ -4771,6 +4783,10 @@ struct ConvHipImplicitGemm3DGroupBwdXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const 
miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; @@ -4847,6 +4863,10 @@ struct ConvHipImplicitGemmGroupBwdXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; @@ -4923,6 +4943,10 @@ struct ConvHipImplicitGemmGroupWrwXdlops final return 0.02f; }; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + bool MayNeedWorkspace() const override { return true; } + private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index 2bf5848dab..2199d88b01 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -28,6 +28,11 @@ #include #include +#include + +#if MIOPEN_USE_COMPOSABLEKERNEL +#include +#endif // MIOPEN_USE_COMPOSABLEKERNEL namespace miopen { @@ -95,7 +100,9 @@ template -ConvSolution MakeInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id) +ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&, + const ProblemDescriptionType& problem, + const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); @@ -165,5 +172,585 @@ ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem, return result; } +namespace internal { + +enum class ConvOperandTag : int +{ + Input = 0, + Weights, + Output +}; + +enum class TranposeKind : int +{ + NHWC_TO_NCHW = 0, + NCHW_TO_NHWC +}; + +template +struct TransposeOperand +{ + static_assert(ND == 2 || ND == 3, "Num Dimensions must be 2 or 3"); + constexpr static int NDIM = ND; + constexpr static ConvOperandTag CONV_OP_TAG = CONV_OP; + constexpr static TranposeKind TRANSPOSE_KIND = TPOSE_KIND; + + using SolverType = + std::conditional_t, + // NCHW_TO_NHWC + std::conditional_t>; + + template + SolverType MakeTransposeSolver(const miopen::ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const CKArgsType& ck_args) const + { + + if constexpr(CONV_OP_TAG == ConvOperandTag::Input) + { + if constexpr(ND == 3) + { + + return SolverType{ctx, + problem.GetInDataType(), + static_cast(ck_args.N), + static_cast(ck_args.C1), + static_cast(ck_args.Di), + static_cast(ck_args.Hi), + static_cast(ck_args.Wi)}; + } + else + { + return SolverType{ctx, + problem.GetInDataType(), + static_cast(ck_args.N), + static_cast(ck_args.C1), + static_cast(ck_args.Hi), + static_cast(ck_args.Wi)}; + } + } + else if constexpr(CONV_OP_TAG == ConvOperandTag::Weights) + { + if constexpr(ND == 3) + { + return SolverType{ctx, + problem.GetWeightsDataType(), + static_cast(ck_args.K1), + static_cast(ck_args.C), + static_cast(ck_args.Z), + static_cast(ck_args.Y), + static_cast(ck_args.X)}; + } + else + { + return SolverType{ctx, + problem.GetWeightsDataType(), + static_cast(ck_args.K1), + static_cast(ck_args.C), + static_cast(ck_args.Y), + static_cast(ck_args.X)}; + } + } + else + { + static_assert(CONV_OP_TAG == ConvOperandTag::Output); + if constexpr(ND == 3) + { + return SolverType{ctx, + 
problem.GetOutDataType(), + static_cast(ck_args.N), + static_cast(ck_args.K1), + static_cast(ck_args.Do), + static_cast(ck_args.Ho), + static_cast(ck_args.Wo)}; + } + else + { + return SolverType{ctx, + problem.GetOutDataType(), + static_cast(ck_args.N), + static_cast(ck_args.K1), + static_cast(ck_args.Ho), + static_cast(ck_args.Wo)}; + } + } + } +}; + +// Shorthand aliases for CK assuming CK always expects and generates NHWC/NDHWC layouts +template +using CKTransposeInputOp = TransposeOperand; + +template +using CKTransposeOutputOp = TransposeOperand; + +class TransposeInstance +{ + size_t tensor_sz = 0; + std::vector kern_args{}; + size_t kern_idx = std::numeric_limits::max(); + size_t buf_offset = 0; + shared buf_handle{}; + +public: + template + TransposeInstance(const TransSolnType& trans_sol, + size_t k_idx, + const MultiBufferWorkspaceTraits& wt, + size_t wspace_index) + : tensor_sz(trans_sol.GetOutputTensorSize()), + kern_args(trans_sol.GetKernelArg()), + kern_idx(k_idx), + buf_offset(wt.GetOffset(wspace_index)) + { + } + + void AssignBuffer(const Handle& handle, Data_t workSpace) + { + buf_handle = handle.CreateSubBuffer(workSpace, buf_offset, tensor_sz); + assert(buf_handle.get()); + } + + Data_t GetBufferPtr() const { return buf_handle.get(); } + + void ConvertFrom(const Handle& handle, const std::vector& kernels, ConstData_t in_ptr) + { + Run(handle, kernels, buf_handle.get(), in_ptr); + } + + void ConvertTo(const Handle& handle, const std::vector& kernels, Data_t out_ptr) + { + Run(handle, kernels, out_ptr, buf_handle.get()); + } + + void ZeroOutBuffer() + { + [[maybe_unused]] auto status = hipMemset(buf_handle.get(), 0, tensor_sz); + assert(status == hipSuccess); + } + + TransposeInstance() = delete; + TransposeInstance(const TransposeInstance&) = default; + TransposeInstance(TransposeInstance&&) = default; + ~TransposeInstance() = default; + +private: + void Run(const Handle& handle, + const std::vector& kernels, + Data_t out_ptr, + ConstData_t in_ptr) + { + assert(out_ptr); + assert(in_ptr); + assert(kernels.size() > kern_idx); + + kern_args[0] = out_ptr; + kern_args[1] = in_ptr; + + auto save = handle.IsProfilingEnabled() ? 
0.0f : handle.GetKernelTime(); + handle.Run(kernels[kern_idx])(kern_args); + if(handle.IsProfilingEnabled()) + { + handle.AccumKernelTime(save); + } + } +}; + +class TransposeInstanceTagged : public TransposeInstance +{ + + ConvOperandTag conv_op_tag_; + +public: + template + TransposeInstanceTagged(const TransSolnType& sol, + size_t k_idx, + const MultiBufferWorkspaceTraits& wt, + size_t wspace_index, + ConvOperandTag conv_op_tag) + : TransposeInstance(sol, k_idx, wt, wspace_index), conv_op_tag_(conv_op_tag) + { + } + + ConvOperandTag GetConvOperandTag() const noexcept { return conv_op_tag_; } + + std::underlying_type_t GetConvOperandTagAsInt() const noexcept + { + using IntType = std::underlying_type_t; + return static_cast(GetConvOperandTag()); + } + + void ConvertFrom(const Handle& handle, + const std::vector& kernels, + const ConvTensors& tensors) + { + TransposeInstance::ConvertFrom(handle, kernels, pickTensorPtr(tensors)); + } + + void + ConvertTo(const Handle& handle, const std::vector& kernels, const ConvTensors& tensors) + { + TransposeInstance::ConvertTo(handle, kernels, pickTensorPtr(tensors)); + } + + TransposeInstanceTagged() = delete; + TransposeInstanceTagged(const TransposeInstanceTagged&) = default; + TransposeInstanceTagged(TransposeInstanceTagged&&) = default; + ~TransposeInstanceTagged() = default; + +private: + Data_t pickTensorPtr(const ConvTensors& tensors) const + { + std::array data_ptrs = { + const_cast(tensors.x), // NOLINT (cppcoreguidelines-pro-type-const-cast) + const_cast(tensors.w), // NOLINT (cppcoreguidelines-pro-type-const-cast) + const_cast(tensors.y) // NOLINT (cppcoreguidelines-pro-type-const-cast) + }; + + return data_ptrs[GetConvOperandTagAsInt()]; + } +}; + +template +auto MakeTaggedTransposeInstances(ConvSolution& result, + const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const CKArgsType& ck_args, + const Input1TposeOp& input1_op, + const Input2TposeOp& input2_op, + const OutputTposeOp& output_op) +{ + + auto input1_solver = input1_op.MakeTransposeSolver(ctx, problem, ck_args); + auto input2_solver = input2_op.MakeTransposeSolver(ctx, problem, ck_args); + auto output_solver = output_op.MakeTransposeSolver(ctx, problem, ck_args); + + // NOTE: In cases where the convolution updates only a subset of output + // indices, we need to first initialize the workspace buffer for + // output with the real tensor for the output and then apply the convolution. + // This is achieved by creating an input transpose op for the output workspace + // bufffer. 
+ + using OutputInitOp = CKTransposeInputOp; + + auto output_init_solver = OutputInitOp{}.MakeTransposeSolver(ctx, problem, ck_args); + + result.construction_params.insert(result.construction_params.end(), + {input1_solver.GetKernelInfo(), + input2_solver.GetKernelInfo(), + output_solver.GetKernelInfo(), + output_init_solver.GetKernelInfo()}); + + MultiBufferWorkspaceTraits wt({input1_solver.GetOutputTensorSize(), + input2_solver.GetOutputTensorSize(), + output_solver.GetOutputTensorSize()}); + + return std::make_tuple( + TransposeInstanceTagged{input1_solver, 0, wt, 0, Input1TposeOp::CONV_OP_TAG}, + TransposeInstanceTagged{input2_solver, 1, wt, 1, Input2TposeOp::CONV_OP_TAG}, + TransposeInstanceTagged{output_solver, 2, wt, 2, OutputTposeOp::CONV_OP_TAG}, + TransposeInstanceTagged{output_init_solver, 3, wt, 2, OutputTposeOp::CONV_OP_TAG}); +} + +#ifndef NDEBUG // disable for release builds, enable for debug builds + +template +void DebugPrintVec(const char* name, const V& vec) +{ + std::ostringstream oss; + oss << name << " = [ "; + for(const auto& v : vec) + { + oss << v << ", "; + } + oss << "]"; + MIOPEN_LOG_I(oss.str()); +} + +#define DEBUG_PRINT_VEC(x) DebugPrintVec(#x, x); + +template +void DebugPrintCKArgPtrs( + const CKArgsType& ck_args, const ConvPtr& conv_ptr, ConstData_t x, ConstData_t w, ConstData_t y) +{ + + MIOPEN_LOG_I("CK Instance: " << conv_ptr->GetTypeString()); + MIOPEN_LOG_I("in ptr = " << x); + MIOPEN_LOG_I("w ptr = " << w); + MIOPEN_LOG_I("out ptr = " << y); + + DEBUG_PRINT_VEC(ck_args.input); + DEBUG_PRINT_VEC(ck_args.in_strides); + DEBUG_PRINT_VEC(ck_args.weight); + DEBUG_PRINT_VEC(ck_args.wei_strides); + DEBUG_PRINT_VEC(ck_args.output); + DEBUG_PRINT_VEC(ck_args.out_strides); +} + +inline void DebugPrintConvTensors(const ConvTensors& conv_tensors) +{ + MIOPEN_LOG_I("in ptr = " << conv_tensors.x); + MIOPEN_LOG_I("w ptr = " << conv_tensors.w); + MIOPEN_LOG_I("out ptr = " << conv_tensors.y); + + DEBUG_PRINT_VEC(conv_tensors.xDesc.GetLengths()); + DEBUG_PRINT_VEC(conv_tensors.wDesc.GetLengths()); + DEBUG_PRINT_VEC(conv_tensors.yDesc.GetLengths()); +} + +#undef DEBUG_PRINT_VEC + +#endif // NDEBUG +} // end namespace internal + +/// \todo move to a cpp file +inline size_t GetWorkspaceSizeLayoutTransformConv(const miopen::conv::ProblemDescription& problem) +{ + if(problem.IsLayoutNHWC()) + { + return 0; + } + + assert(problem.IsLayoutDefault()); + // packed size in bytes + auto GetPackedSize = [](const TensorDescriptor& td) { + return td.GetElementSize() * GetTypeSize(td.GetType()); + }; + + MultiBufferWorkspaceTraits wt({GetPackedSize(problem.GetIn()), + GetPackedSize(problem.GetWeights()), + GetPackedSize(problem.GetOut())}); + + return wt.GetSize(); +} + +template +ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const std::string& kernel_id, + const Input1TposeOp& input1_op, + const Input2TposeOp& input2_op, + const OutputTposeOp& output_op) +{ + + assert(problem.IsLayoutDefault()); + + ConvSolution result; + auto ck_args = CKArgsType{problem}; + + auto [_input1_tr_inst, _input2_tr_inst, _output_tr_inst, _output_init_tr_inst] = + internal::MakeTaggedTransposeInstances( + result, ctx, problem, ck_args, input1_op, input2_op, output_op); + + auto conv_ptrs = DeviceOpType::GetInstances(); + auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); + + if(ptr_iter == conv_ptrs.end()) + { + MIOPEN_LOG_E("PerformanceConfig kernel '" + kernel_id + "' does not exist."); + return {miopenStatusInvalidValue}; + } + 
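// A self-contained sketch (not MIOpen code) of the operand-tag sort that the
// invoker lambda built below depends on. The three transpose instances are
// constructed in solver order (the bwd-data factory, for example, builds them
// as Output, Weights, Input) and then sorted by ConvOperandTag, so the CK
// argument pointers can always be passed in the canonical (input, weights,
// output) order regardless of convolution direction. The Tagged struct and
// buffer names here are hypothetical stand-ins.

#include <algorithm>
#include <array>
#include <cassert>

enum class ConvOperandTag : int { Input = 0, Weights, Output };

struct Tagged
{
    ConvOperandTag tag;
    const char* ws_buf; // stand-in for a workspace sub-buffer pointer
};

int main()
{
    // bwd-data order: input1 transposes dy (Output), input2 transposes w
    // (Weights), and the "output" of the CK kernel is dx (Input).
    std::array<Tagged, 3> ops{{{ConvOperandTag::Output, "dy_ws"},
                               {ConvOperandTag::Weights, "w_ws"},
                               {ConvOperandTag::Input, "dx_ws"}}};

    std::sort(ops.begin(), ops.end(), [](const Tagged& l, const Tagged& r) {
        return static_cast<int>(l.tag) < static_cast<int>(r.tag);
    });

    // After sorting, ops[0..2] map to (input, weights, output), matching the
    // MakeArgPtr(in, w, out) convention used when the CK kernel is launched.
    assert(ops[0].tag == ConvOperandTag::Input);
    assert(ops[1].tag == ConvOperandTag::Weights);
    assert(ops[2].tag == ConvOperandTag::Output);
    return 0;
}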
+ result.invoker_factory = [ck_args = std::move(ck_args), + sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}, + input1_tr_inst = std::move(_input1_tr_inst), + input2_tr_inst = std::move(_input2_tr_inst), + output_tr_inst = std::move(_output_tr_inst), + output_init_tr_inst = std::move(_output_init_tr_inst)]( + const std::vector& kernels) mutable { + return [kernels, + ck_args = std::move(ck_args), + sh_conv_ptr = std::move(sh_conv_ptr), + input1_tr_inst = std::move(input1_tr_inst), + input2_tr_inst = std::move(input2_tr_inst), + output_tr_inst = std::move(output_tr_inst), + output_init_tr_inst = std::move(output_init_tr_inst)]( + const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { + handle.ResetKernelTime(); + + const auto& data_ctx = primitive_parameters.CastTo(); + + if(!data_ctx.workSpace) + { + MIOPEN_THROW(miopenStatusInvalidValue, "workspace pointer is null"); + } + + input1_tr_inst.AssignBuffer(handle, data_ctx.workSpace); + input2_tr_inst.AssignBuffer(handle, data_ctx.workSpace); + output_tr_inst.AssignBuffer(handle, data_ctx.workSpace); + output_init_tr_inst.AssignBuffer(handle, data_ctx.workSpace); + + // conversion operator applied here to convert to ConvTensors + auto conv_tensors = ConvTensors(data_ctx.tensors); + + /// \todo remove this when DataInvokeParams stops swapping + // "in" and "out" tensors for backward pass + if(output_tr_inst.GetConvOperandTag() == internal::ConvOperandTag::Input) + { + // this is backward pass, swap back input and output + std::swap(conv_tensors.x, conv_tensors.y); + std::swap(conv_tensors.xDesc, conv_tensors.yDesc); + } + + input1_tr_inst.ConvertFrom(handle, kernels, conv_tensors); + + input2_tr_inst.ConvertFrom(handle, kernels, conv_tensors); + + output_init_tr_inst.ConvertFrom(handle, kernels, conv_tensors); + + /// \todo: Fix NHWC Wrw invokers to also issue a zero-out kernel. 
Will + /// need SetTensor() to properly zero out non-packed tensors + if(output_tr_inst.GetConvOperandTag() == internal::ConvOperandTag::Weights) + { + output_tr_inst.ZeroOutBuffer(); + } + + std::array tr_ptrs = { + &input1_tr_inst, &input2_tr_inst, &output_tr_inst}; + + // sort by tag in order: Input, Weights, Output + std::sort(tr_ptrs.begin(), tr_ptrs.end(), [](const auto& left, const auto& right) { + return left->GetConvOperandTagAsInt() < right->GetConvOperandTagAsInt(); + }); + + auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer(); + auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, + tr_ptrs[0]->GetBufferPtr(), + tr_ptrs[1]->GetBufferPtr(), + tr_ptrs[2]->GetBufferPtr()); + float conv_time = 0; + conv_time += invoker_ptr->Run(argument_ptr.get(), + {handle.GetStream(), handle.IsProfilingEnabled()}); + + if(handle.IsProfilingEnabled()) + { + handle.AccumKernelTime(conv_time); + } + + output_tr_inst.ConvertTo(handle, kernels, conv_tensors); + }; + }; + + result.workspace_sz = GetWorkspaceSizeLayoutTransformConv(problem); + + return result; +} + +template +ConvSolution InitInvokerFactoryFwdNCHW(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const std::string& kernel_id) +{ + + static_assert(ND == 2 || ND == 3, "Num Dimensions must be 2 or 3"); + + using Input1 = internal::CKTransposeInputOp; + using Input2 = internal::CKTransposeInputOp; + using Output = internal::CKTransposeOutputOp; + + return InitInvokerFactoryNCHW( + ctx, problem, kernel_id, Input1{}, Input2{}, Output{}); +} + +template +ConvSolution InitInvokerFactoryBwdNCHW(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const std::string& kernel_id) +{ + + static_assert(ND == 2 || ND == 3, "Num Dimensions must be 2 or 3"); + + using Input1 = internal::CKTransposeInputOp; + using Input2 = internal::CKTransposeInputOp; + using Output = internal::CKTransposeOutputOp; + + return InitInvokerFactoryNCHW( + ctx, problem, kernel_id, Input1{}, Input2{}, Output{}); +} + +template +ConvSolution InitInvokerFactoryWrwNCHW(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + const std::string& kernel_id) +{ + static_assert(ND == 2 || ND == 3, "Num Dimensions must be 2 or 3"); + + using Input1 = internal::CKTransposeInputOp; + using Input2 = internal::CKTransposeInputOp; + using Output = internal::CKTransposeOutputOp; + + return InitInvokerFactoryNCHW( + ctx, problem, kernel_id, Input1{}, Input2{}, Output{}); +} + +template +ConvSolution +MakeSolutionGroupConvImplicitGemmXdlops(const miopen::conv::ProblemDescription& problem, + InvokerFactoryMakerNCHW&& invoker_factory_maker_ncdhw, + InvokerFactoryMakerNHWC&& invoker_factory_maker_ndhwc) +{ + +#if MIOPEN_USE_COMPOSABLEKERNEL + if(problem.IsLayoutDefault()) + { + switch(problem.GetInDataType()) + { + case miopenInt8: return invoker_factory_maker_ncdhw(int8_t{}); + case miopenHalf: return invoker_factory_maker_ncdhw(ck::half_t{}); + case miopenFloat: return invoker_factory_maker_ncdhw(float{}); + case miopenInt32: + case miopenBFloat16: + case miopenDouble: + case miopenFloat8: + case miopenBFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, + "3DGroupConvolutionImplicitGemmXdlops operation not implemented for this " + "data type"); + } + } + else if(problem.IsLayoutNHWC()) + { + switch(problem.GetInDataType()) + { + case miopenInt8: return invoker_factory_maker_ndhwc(int8_t{}); + case miopenHalf: return invoker_factory_maker_ndhwc(ck::half_t{}); + case miopenFloat: return 
invoker_factory_maker_ndhwc(float{}); + case miopenInt32: + case miopenBFloat16: + case miopenDouble: + case miopenFloat8: + case miopenBFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, + "3DGroupConvolutionImplicitGemmXdlops operation not implemented for this " + "data type"); + } + } + else + { + MIOPEN_THROW( + miopenStatusInternalError, + "3DGroupConvolutionImplicitGemmXdlops operation not implemented for this data type"); + } +#else + return {}; +#endif +} + } // namespace solver } // namespace miopen diff --git a/src/include/miopen/util_sol.hpp b/src/include/miopen/util_sol.hpp deleted file mode 100644 index c6b5f1e6e7..0000000000 --- a/src/include/miopen/util_sol.hpp +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c_) 202 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef MIOPEN_UTIL_SOL_HPP_ -#define MIOPEN_UTIL_SOL_HPP_ - -#include -#include -#include -#include -#include -#include - -namespace miopen { - -struct TransposeSolutionDefault2Nhwc : public BatchedTransposeSolution -{ - TransposeSolutionDefault2Nhwc(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t n_, - uint32_t c_, - uint32_t h_, - uint32_t w_) - : BatchedTransposeSolution(ctx_, data_type_, n_, c_, h_ * w_) - { - } -}; - -struct TransposeSolutionNhwc2Default : public BatchedTransposeSolution -{ - TransposeSolutionNhwc2Default(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t n_, - uint32_t c_, - uint32_t h_, - uint32_t w_) - : BatchedTransposeSolution(ctx_, data_type_, n_, h_ * w_, c_) - { - } -}; -} // namespace miopen - -#endif // MIOPEN_UTIL_SOL_HPP_ diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index 1588ac0ef8..ddeaf4a5bf 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -1017,8 +1017,6 @@ size_t ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetWorkspaceSize( size_t size_trans_output = 0; size_t size_tensor_cast = 0; - constexpr size_t buf_alignment = 256; - size_t workspace_size = 0; if(is_nchw) { @@ -1047,7 +1045,7 @@ size_t ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetWorkspaceSize( } MultiBufferWorkspaceTraits wt( - {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}, buf_alignment); + {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}); workspace_size = wt.GetSize(); return workspace_size; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index e60f5f4911..bb94610fa9 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_DLOPS_NCHWC) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index d1cdcecd85..a850667212 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -824,8 +824,6 @@ size_t ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetWorkspaceSize( size_t size_trans_output = 0; size_t size_tensor_cast = 0; - constexpr size_t buf_alignment = 256; - if(is_nchw) { @@ -855,7 +853,7 @@ size_t ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetWorkspaceSize( } MultiBufferWorkspaceTraits wt( - {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}, buf_alignment); + {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}); workspace_size = wt.GetSize(); return workspace_size; diff --git 
a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index 806095efb3..ca9981e922 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -33,7 +33,7 @@ #include #include #include -#include +#include MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -993,8 +993,6 @@ size_t ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetWorkspaceSize( size_t size_trans_output = 0; size_t size_tensor_cast = 0; - constexpr size_t buf_alignment = 256; - size_t workspace_size = 0; if(is_nchw) { @@ -1023,7 +1021,7 @@ size_t ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetWorkspaceSize( } MultiBufferWorkspaceTraits wt( - {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}, buf_alignment); + {size_trans_input, size_trans_weight, size_trans_output, size_tensor_cast}); workspace_size = wt.GetSize(); return workspace_size; @@ -1164,8 +1162,6 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( int trans_weight_idx = -1; int trans_output_idx = -1; - constexpr size_t buf_alignment = 256; - if(is_nchw) { TransposeSolutionDefault2Nhwc trans_input(ctx, problem.GetOutDataType(), n, c, hi, wi); @@ -1222,7 +1218,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( need_cast ? miopen::GetTypeSize(miopenFloat) * k * (c / group) * y * x : 0; MultiBufferWorkspaceTraits wt( - {trans_input_size, trans_weight_size, trans_output_size, cast_size}, buf_alignment); + {trans_input_size, trans_weight_size, trans_output_size, cast_size}); trans_input_offset = wt.GetOffset(0); trans_weight_offset = wt.GetOffset(1); diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index 0908d9b850..910e978587 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -90,18 +90,37 @@ struct CKArgs output = {G, N, K, Do, Ho, Wo}; weight = {G, K, C, Z, Y, X}; - // miopen strides to CK strides - // On a backward pass, problem.GetIn() means y(or out), - // and problem.GetOut means x(or in) - auto miopen_in_strides = problem.GetOut().GetStrides(); - auto miopen_out_strides = problem.GetIn().GetStrides(); - auto miopen_wei_strides = problem.GetWeights().GetStrides(); - miopen_in_strides.insert(miopen_in_strides.begin(), C); - miopen_out_strides.insert(miopen_out_strides.begin(), K); - miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]); - std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin()); - std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin()); - std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin()); + // CK strides are in GNCDHW order + if(problem.IsLayoutNHWC()) + { + // first entry reserved for G's stride + auto copy_strides = [](const auto& src, auto& dst) { + assert(dst.size() == (src.size() + 1)); + std::copy(src.begin(), src.end(), dst.begin() + 1); + }; + copy_strides(problem.GetIn().GetStrides(), in_strides); + copy_strides(problem.GetOut().GetStrides(), out_strides); + copy_strides(problem.GetWeights().GetStrides(), wei_strides); + + // On a backward pass, problem.GetIn() means y(or out), + // and problem.GetOut means x(or in) + /// \todo remove this when we stop 
swapping in and out tensors/descriptors + std::swap(in_strides, out_strides); + + // Now compute G's stride + in_strides[0] = C; + out_strides[0] = K; + wei_strides[0] = K * wei_strides[1]; + } + else + { + assert(problem.IsLayoutDefault()); // already checked in IsApplicable + // for default layout, we produce packed strides for NHWC layout + // because we transpose to NHWC layout before calling CK kernel + in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + } strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem), ProblemInterpreter::GetAdjustedConvolutionStrideH(problem), @@ -122,35 +141,12 @@ struct CKArgs CKArgs& operator=(const CKArgs&) = default; template - auto MakeArgPtr(const ConvPtr& conv_ptr, ConstData_t in, ConstData_t w, Data_t out) const + auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t in, ConstData_t w, ConstData_t out) const { - -#if 0 // Leaving for debugging needs - std::cout << "y ptr = " << in << std::endl; - std::cout << "w ptr = " << w << std::endl; - std::cout << "x ptr = " << out << std::endl; - - auto print_vec = [](const char* name, const auto& vec) { - std::cout << name << " = [ "; - for(const auto& v : vec) - { - std::cout << v << ", "; - } - std::cout << "]\n"; - }; -#define PRINT_VEC(x) print_vec(#x, x); - - PRINT_VEC(output); - PRINT_VEC(out_strides); - PRINT_VEC(input); - PRINT_VEC(in_strides); - PRINT_VEC(weight); - PRINT_VEC(wei_strides); -#endif - return conv_ptr->MakeArgumentPointer(in, + return conv_ptr->MakeArgumentPointer(out, w, {}, - out, + in, output, out_strides, weight, @@ -171,7 +167,7 @@ struct CKArgs template auto MakeArgPtr(const ConvPtr& conv_ptr, const ConvDataTensors& tensors) const { - return MakeArgPtr(conv_ptr, tensors.in, tensors.w, tensors.out); + return MakeArgPtr(conv_ptr, tensors.out, tensors.w, tensors.in); } template @@ -319,6 +315,13 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +size_t +ConvHipImplicitGemm3DGroupBwdXdlops::GetWorkspaceSize(const ExecutionContext&, + const ProblemDescription& problem) const +{ + return GetWorkspaceSizeLayoutTransformConv(problem); +} + PerformanceConfigHipImplicitGemm3DGroupBwdXdlops ConvHipImplicitGemm3DGroupBwdXdlops::Search(const ExecutionContext& ctx, const ProblemDescription& problem, @@ -346,7 +349,10 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; - if(!problem.IsLayoutNHWC()) + if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault())) + return false; + // needed because layout transpose kernel does not support non-packed tensors + if(problem.IsLayoutDefault() && problem.HasNonPackedTensors()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) return false; @@ -371,29 +377,27 @@ ConvSolution ConvHipImplicitGemm3DGroupBwdXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemm3DGroupBwdXdlops& config) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - switch(problem.GetInDataType()) - { - case miopenInt8: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); - case miopenHalf: - return MakeInvokerFactory, - CKArgs, - miopen::conv::DataInvokeParams>(problem, config.kernel_id); - case miopenFloat: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - 
problem, config.kernel_id); - case miopenInt32: - case miopenBFloat16: - case miopenDouble: - case miopenFloat8: - case miopenBFloat8: - default: - MIOPEN_THROW(miopenStatusInternalError, - "ConvHipImplicitGemmBwdXdlops operation not implemented for this data type"); - } -#endif + return MakeSolutionGroupConvImplicitGemmXdlops( + problem, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryBwdNCHW<3, + DeviceOpGBwdPtrs, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); + }, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); + }); + +#else return {}; +#endif } } // namespace conv diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index a4fdfd7e52..637a7b2008 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -90,16 +90,32 @@ struct CKArgs output = {G, N, K, Do, Ho, Wo}; weight = {G, K, C, Z, Y, X}; - // miopen strides to CK strides - auto miopen_in_strides = problem.GetIn().GetStrides(); - auto miopen_out_strides = problem.GetOut().GetStrides(); - auto miopen_wei_strides = problem.GetWeights().GetStrides(); - miopen_in_strides.insert(miopen_in_strides.begin(), C); - miopen_out_strides.insert(miopen_out_strides.begin(), K); - miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]); - std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin()); - std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin()); - std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin()); + // CK strides are in GNCDHW order + if(problem.IsLayoutNHWC()) + { + // first entry reserved for G's stride + auto copy_strides = [](const auto& src, auto& dst) { + assert(dst.size() == (src.size() + 1)); + std::copy(src.begin(), src.end(), dst.begin() + 1); + }; + copy_strides(problem.GetIn().GetStrides(), in_strides); + copy_strides(problem.GetOut().GetStrides(), out_strides); + copy_strides(problem.GetWeights().GetStrides(), wei_strides); + + // Now compute G's stride + in_strides[0] = C; + out_strides[0] = K; + wei_strides[0] = K * wei_strides[1]; + } + else + { + assert(problem.IsLayoutDefault()); // already checked in IsApplicable + // for default layout, we produce packed strides for NHWC layout + // because we transpose to NHWC layout before calling CK kernel + in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + } strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem), ProblemInterpreter::GetAdjustedConvolutionStrideH(problem), @@ -295,6 +311,13 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +size_t +ConvHipImplicitGemm3DGroupFwdXdlops::GetWorkspaceSize(const ExecutionContext&, + const ProblemDescription& problem) const +{ + return GetWorkspaceSizeLayoutTransformConv(problem); +} + PerformanceConfigHipImplicitGemm3DGroupFwdXdlops ConvHipImplicitGemm3DGroupFwdXdlops::Search(const ExecutionContext& ctx, const ProblemDescription& problem, @@ -320,7 +343,10 @@ bool 
ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; - if(!problem.IsLayoutNHWC()) + if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault())) + return false; + // needed because layout transpose kernel does not support non-packed tensors + if(problem.IsLayoutDefault() && problem.HasNonPackedTensors()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) return false; @@ -345,29 +371,27 @@ ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops& config) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - switch(problem.GetInDataType()) - { - case miopenInt8: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); - case miopenHalf: - return MakeInvokerFactory, - CKArgs, - miopen::conv::DataInvokeParams>(problem, config.kernel_id); - case miopenFloat: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); - case miopenInt32: - case miopenBFloat16: - case miopenDouble: - case miopenFloat8: - case miopenBFloat8: - default: - MIOPEN_THROW(miopenStatusInternalError, - "ConvHipImplicitGemmFwdXdlops operation not implemented for this data type"); - } -#endif + return MakeSolutionGroupConvImplicitGemmXdlops( + problem, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryFwdNCHW<3, + DeviceOpGFwdPtrs, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); + }, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); + }); + +#else return {}; +#endif } } // namespace conv diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index aae133335e..1865ec82c3 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -88,18 +88,37 @@ struct CKArgs output = {G, N, K, Do, Ho, Wo}; weight = {G, K, C, Z, Y, X}; - // miopen strides to CK strides - // On a backward pass, problem.GetIn() means y(or out), - // and problem.GetOut means x(or in) - auto miopen_in_strides = problem.GetOut().GetStrides(); - auto miopen_out_strides = problem.GetIn().GetStrides(); - auto miopen_wei_strides = problem.GetWeights().GetStrides(); - miopen_in_strides.insert(miopen_in_strides.begin(), C); - miopen_out_strides.insert(miopen_out_strides.begin(), K); - miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]); - std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin()); - std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin()); - std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin()); + // CK strides are in GNCDHW order + if(problem.IsLayoutNHWC()) + { + // first entry reserved for G's stride + auto copy_strides = [](const auto& src, auto& dst) { + assert(dst.size() == (src.size() + 1)); + std::copy(src.begin(), src.end(), dst.begin() + 1); + }; + copy_strides(problem.GetIn().GetStrides(), in_strides); + copy_strides(problem.GetOut().GetStrides(), out_strides); + copy_strides(problem.GetWeights().GetStrides(), wei_strides); + + // On a backward pass, problem.GetIn() means y(or out), + // and problem.GetOut means x(or in) + /// \todo 
remove this when we stop swapping in and out tensors/descriptors + std::swap(in_strides, out_strides); + + // Now compute G's stride + in_strides[0] = C; + out_strides[0] = K; + wei_strides[0] = K * wei_strides[1]; + } + else + { + assert(problem.IsLayoutDefault()); // already checked in IsApplicable + // for default layout, we produce packed strides for NHWC layout + // because we transpose to NHWC layout before calling CK kernel + in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + } strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem), ProblemInterpreter::GetAdjustedConvolutionStrideH(problem), @@ -121,28 +140,6 @@ struct CKArgs template auto MakeArgPtr(const ConvPtr& conv_ptr, ConstData_t x, Data_t dw, ConstData_t dy) const { -#if 0 // Leaving for debugging needs - std::cout << "y ptr = " << dy << std::endl; - std::cout << "w ptr = " << dw << std::endl; - std::cout << "x ptr = " << x << std::endl; - - auto print_vec = [](const char* name, const auto& vec) { - std::cout << name << " = [ "; - for(const auto& v : vec) - { - std::cout << v << ", "; - } - std::cout << "]\n"; - }; -#define PRINT_VEC(x) print_vec(#x, x); - - PRINT_VEC(output); - PRINT_VEC(out_strides); - PRINT_VEC(input); - PRINT_VEC(in_strides); - PRINT_VEC(weight); - PRINT_VEC(wei_strides); -#endif return conv_ptr->MakeArgumentPointer(x, dw, dy, @@ -314,6 +311,13 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +size_t +ConvHipImplicitGemm3DGroupWrwXdlops::GetWorkspaceSize(const ExecutionContext&, + const ProblemDescription& problem) const +{ + return GetWorkspaceSizeLayoutTransformConv(problem); +} + PerformanceConfigHipImplicitGemm3DGroupWrwXdlops ConvHipImplicitGemm3DGroupWrwXdlops::Search(const ExecutionContext& ctx, const ProblemDescription& problem, @@ -339,7 +343,10 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; - if(!problem.IsLayoutNHWC()) + if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault())) + return false; + // needed because layout transpose kernel does not support non-packed tensors + if(problem.IsLayoutDefault() && problem.HasNonPackedTensors()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) return false; @@ -364,29 +371,27 @@ ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - switch(problem.GetInDataType()) - { - case miopenInt8: - return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>( - problem, config.kernel_id); - case miopenHalf: - return MakeInvokerFactory, - CKArgs, - miopen::conv::WrWInvokeParams>(problem, config.kernel_id); - case miopenFloat: - return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>( - problem, config.kernel_id); - case miopenInt32: - case miopenBFloat16: - case miopenFloat8: - case miopenBFloat8: - case miopenDouble: - default: - MIOPEN_THROW(miopenStatusInternalError, - "ConvHipImplicitGemmWrwXdlops operation not implemented for this data type"); - } -#endif + return MakeSolutionGroupConvImplicitGemmXdlops( + problem, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryWrwNCHW<3, + DeviceOpGWrwPtrs, + CKArgs, + 
miopen::conv::WrWInvokeParams>( + ctx, problem, config.kernel_id); + }, + [&](auto data_type_val) { + using T = decltype(data_type_val); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::WrWInvokeParams>( + ctx, problem, config.kernel_id); + }); + +#else return {}; +#endif } } // namespace conv diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 5b8f534ed8..f0b0ff266f 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -312,12 +312,15 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenHalf: - return MakeInvokerFactory, - CKArgs, - miopen::conv::DataInvokeParams>(problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenInt8: case miopenInt32: case miopenBFloat16: diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp index 976931f793..1098acc37e 100644 --- a/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp @@ -330,9 +330,9 @@ ConvSolution ConvHipImplicitGemmF16F8F16BwdXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops& config) const { #if MIOPEN_USE_COMPOSABLEKERNEL - return MakeInvokerFactory, - CKArgs, - miopen::conv::DataInvokeParams>(problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>(ctx, problem, config.kernel_id); #else return {}; #endif diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp index 8b1491a24c..34380cfd72 100644 --- a/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp @@ -327,9 +327,9 @@ ConvSolution ConvHipImplicitGemmF16F8F16FwdXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops& config) const { #if MIOPEN_USE_COMPOSABLEKERNEL - return MakeInvokerFactory, - CKArgs, - miopen::conv::DataInvokeParams>(problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>(ctx, problem, config.kernel_id); #else return {}; #endif diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp index 065267dbcc..721bdf9d2e 100644 --- a/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp @@ -327,9 +327,9 @@ ConvSolution ConvHipImplicitGemmF16F8F16WrwXdlops::GetSolution( [[maybe_unused]] const PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops& config) const { #if MIOPEN_USE_COMPOSABLEKERNEL - return MakeInvokerFactory, - CKArgs, - miopen::conv::WrWInvokeParams>(problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::WrWInvokeParams>(ctx, problem, config.kernel_id); #else return {}; #endif diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 
21aecc3234..45cecc99da 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -313,14 +313,16 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenInt8: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); + return InitInvokerFactoryNHWC, CKArgs, miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenHalf: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( - problem, config.kernel_id); + return InitInvokerFactoryNHWC, CKArgs, miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenInt32: case miopenBFloat16: case miopenDouble: diff --git a/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp index ca6c86a537..34e19d75f7 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp @@ -87,18 +87,37 @@ struct CKArgs output = {G, N, K, Ho, Wo}; weight = {G, K, C, Y, X}; - // miopen strides to CK strides - // On a backward pass, problem.GetIn() means y(or out), - // and problem.GetOut means x(or in) - auto miopen_in_strides = problem.GetOut().GetStrides(); - auto miopen_out_strides = problem.GetIn().GetStrides(); - auto miopen_wei_strides = problem.GetWeights().GetStrides(); - miopen_in_strides.insert(miopen_in_strides.begin(), C); - miopen_out_strides.insert(miopen_out_strides.begin(), K); - miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]); - std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin()); - std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin()); - std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin()); + // CK strides are in GNCDHW order + if(problem.IsLayoutNHWC()) + { + // first entry reserved for G's stride + auto copy_strides = [](const auto& src, auto& dst) { + assert(dst.size() == (src.size() + 1)); + std::copy(src.begin(), src.end(), dst.begin() + 1); + }; + copy_strides(problem.GetIn().GetStrides(), in_strides); + copy_strides(problem.GetOut().GetStrides(), out_strides); + copy_strides(problem.GetWeights().GetStrides(), wei_strides); + + // On a backward pass, problem.GetIn() means y(or out), + // and problem.GetOut means x(or in) + /// \todo remove this when we stop swapping in and out tensors/descriptors + std::swap(in_strides, out_strides); + + // Now compute G's stride + in_strides[0] = C; + out_strides[0] = K; + wei_strides[0] = K * wei_strides[1]; + } + else + { + assert(problem.IsLayoutDefault()); // already checked in IsApplicable + // for default layout, we produce packed strides for NHWC layout + // because we transpose to NHWC layout before calling CK kernel + in_strides = {C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + out_strides = {K, Ho * Wo * G * K, 1, Wo * G * K, G * K}; + wei_strides = {K * Y * X * C, Y * X * C, 1, X * C, C}; + } strides = {ProblemInterpreter::GetAdjustedConvolutionStrideH(problem), ProblemInterpreter::GetAdjustedConvolutionStrideW(problem)}; @@ -115,7 +134,7 @@ struct CKArgs CKArgs& operator=(const CKArgs&) = default; template - auto 
@@ -115,7 +134,7 @@ struct CKArgs
     CKArgs& operator=(const CKArgs&) = default;

     template <typename ConvPtr>
-    auto MakeArgPtr(const ConvPtr& conv_ptr, ConstData_t out, ConstData_t w, Data_t in) const
+    auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t in, ConstData_t w, ConstData_t out) const
     {
         return conv_ptr->MakeArgumentPointer(out,
                                              w,
@@ -141,7 +160,7 @@ struct CKArgs
     template <typename ConvPtr>
     auto MakeArgPtr(const ConvPtr& conv_ptr, const ConvDataTensors& tensors) const
     {
-        return MakeArgPtr(conv_ptr, tensors.in, tensors.w, tensors.out);
+        return MakeArgPtr(conv_ptr, tensors.out, tensors.w, tensors.in);
     }

     template <typename ConvPtr>
@@ -285,6 +304,12 @@ bool ConvHipImplicitGemmGroupBwdXdlops::IsValidPerformanceConfig(
     return config.IsValid(problem);
 }

+size_t ConvHipImplicitGemmGroupBwdXdlops::GetWorkspaceSize(const ExecutionContext&,
+                                                           const ProblemDescription& problem) const
+{
+    return GetWorkspaceSizeLayoutTransformConv(problem);
+}
+
 PerformanceConfigHipImplicitGemmGroupBwdXdlops
 ConvHipImplicitGemmGroupBwdXdlops::Search(const ExecutionContext& ctx,
                                           const ProblemDescription& problem,
@@ -304,8 +329,6 @@ bool ConvHipImplicitGemmGroupBwdXdlops::IsApplicable(
         return false;
     if(problem.HasMixedDataTypes())
         return false;
-    if(problem.HasNonPackedTensors())
-        return false;
     if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
         return false;
     if(!problem.Is2d())
         return false;
-    if(!problem.IsLayoutNHWC())
+    if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault()))
+        return false;
+    // needed because layout transpose kernel does not support non-packed tensors
+    if(problem.IsLayoutDefault() && problem.HasNonPackedTensors())
         return false;
     if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName()))
         return false;
@@ -339,29 +365,27 @@ ConvSolution ConvHipImplicitGemmGroupBwdXdlops::GetSolution(
     [[maybe_unused]] const PerformanceConfigHipImplicitGemmGroupBwdXdlops& config) const
 {
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
-    switch(problem.GetInDataType())
-    {
-    case miopenInt8:
-        return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>(
-            problem, config.kernel_id);
-    case miopenHalf:
-        return MakeInvokerFactory,
-                                  CKArgs,
-                                  miopen::conv::DataInvokeParams>(problem, config.kernel_id);
-    case miopenFloat:
-        return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>(
-            problem, config.kernel_id);
-    case miopenInt32:
-    case miopenBFloat16:
-    case miopenDouble:
-    case miopenFloat8:
-    case miopenBFloat8:
-    default:
-        MIOPEN_THROW(miopenStatusInternalError,
-                     "ConvHipImplicitGemmBwdXdlops operation not implemented for this data type");
-    }
-#endif
+    return MakeSolutionGroupConvImplicitGemmXdlops(
+        problem,
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryBwdNCHW<2,
+                                             DeviceOpGBwdPtrs<T>,
+                                             CKArgs,
+                                             miopen::conv::DataInvokeParams>(
+                ctx, problem, config.kernel_id);
+        },
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryNHWC<DeviceOpGBwdPtrs<T>,
+                                          CKArgs,
+                                          miopen::conv::DataInvokeParams>(
+                ctx, problem, config.kernel_id);
+        });
+
+#else
     return {};
+#endif
 }

 } // namespace conv
diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
index fa1f2a8e21..3c1ea8fbeb 100644
--- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
@@ -144,10 +144,10 @@ struct CKArgs
     int G;
     int N;
+    int K1;
+    int C1;
     int K;
     int C;
-    int C1;
-    int K1;
     int Hi;
     int Wi;
     int Ho;
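Moving K1 and C1 above K and C in the grouped-forward CKArgs is not cosmetic if any constructor initializer reads one member while another is being set up: C++ initializes non-static data members in declaration order, regardless of how the mem-initializer list is written. A minimal illustration with hypothetical fields (the PR's actual initializers are outside this hunk):

#include <iostream>

struct Args
{
    int K1; // declared before K, so it is always initialized first
    int K;

    // Members are initialized in declaration order (K1, then K), no matter
    // the mem-initializer list order; K may therefore safely depend on K1.
    explicit Args(int k1) : K1(k1), K(K1 * 2) {}
};

int main()
{
    Args a(8);
    std::cout << a.K << '\n'; // prints 16; had K been declared before K1, it would read garbage
    return 0;
}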
@@ -414,6 +414,12 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsValidPerformanceConfig(
     return config.IsValid(problem);
 }

+size_t ConvHipImplicitGemmGroupFwdXdlops::GetWorkspaceSize(const ExecutionContext&,
+                                                           const ProblemDescription& problem) const
+{
+    return GetWorkspaceSizeLayoutTransformConv(problem);
+}
+
 PerformanceConfigHipImplicitGemmGroupFwdXdlops
 ConvHipImplicitGemmGroupFwdXdlops::Search(const ExecutionContext& ctx,
                                           const ProblemDescription& problem,
@@ -443,7 +449,10 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable(
         return false;
     if(!problem.Is2d())
         return false;
-    if(!problem.IsLayoutNHWC())
+    if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault()))
+        return false;
+    // needed because layout transpose kernel does not support non-packed tensors
+    if(problem.IsLayoutDefault() && problem.HasNonPackedTensors())
         return false;
     const std::string& arch = ctx.GetStream().GetDeviceName();
     if(!(arch == "gfx908" || arch == "gfx90a"))
         return false;
@@ -469,29 +478,26 @@ ConvSolution ConvHipImplicitGemmGroupFwdXdlops::GetSolution(
     [[maybe_unused]] const PerformanceConfigHipImplicitGemmGroupFwdXdlops& config) const
 {
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
-    switch(problem.GetInDataType())
-    {
-    case miopenHalf:
-        return MakeInvokerFactory,
-                                  CKArgs,
-                                  miopen::conv::DataInvokeParams>(problem, config.kernel_id);
-    case miopenFloat:
-        return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>(
-            problem, config.kernel_id);
-    case miopenInt8:
-        return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>(
-            problem, config.kernel_id);
-    case miopenInt32:
-    case miopenBFloat16:
-    case miopenDouble:
-    case miopenFloat8:
-    case miopenBFloat8:
-    default:
-        MIOPEN_THROW(miopenStatusInternalError,
-                     "ConvHipImplicitGemmFwdXdlops operation not implemented for this data type");
-    }
-#endif
+    return MakeSolutionGroupConvImplicitGemmXdlops(
+        problem,
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryFwdNCHW<2,
+                                             DeviceOpGFwdPtrs<T>,
+                                             CKArgs,
+                                             miopen::conv::DataInvokeParams>(
+                ctx, problem, config.kernel_id);
+        },
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryNHWC<DeviceOpGFwdPtrs<T>,
+                                          CKArgs,
+                                          miopen::conv::DataInvokeParams>(
+                ctx, problem, config.kernel_id);
+        });
+#else
     return {};
+#endif
 }

 } // namespace conv
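All three grouped solvers now implement GetWorkspaceSize by delegating to GetWorkspaceSizeLayoutTransformConv, whose body lives in implicitgemm_ck_util.hpp and is outside these hunks. A plausible sketch of its contract, an assumption rather than the PR's code, is: zero for NHWC problems, and otherwise the summed, alignment-padded sizes of the NHWC-transposed input, weight, and output buffers (the 256-byte alignment below is illustrative):

#include <cstddef>

// Round sz up to the next multiple of align.
static std::size_t PadUp(std::size_t sz, std::size_t align)
{
    return (sz + align - 1) / align * align;
}

std::size_t WorkspaceForLayoutTransform(bool is_default_layout,
                                        std::size_t in_bytes,
                                        std::size_t wei_bytes,
                                        std::size_t out_bytes,
                                        std::size_t align = 256)
{
    if(!is_default_layout)
        return 0; // NHWC problems feed the CK kernel directly; no transpose buffers
    return PadUp(in_bytes, align) + PadUp(wei_bytes, align) + PadUp(out_bytes, align);
}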
diff --git a/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
index 741744ad67..5f8a805027 100644
--- a/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
@@ -85,18 +85,37 @@ struct CKArgs
         output = {G, N, K, Ho, Wo};
         weight = {G, K, C, Y, X};

-        // miopen strides to CK strides
-        // On a backward pass, problem.GetIn() means y(or out),
-        // and problem.GetOut means x(or in)
-        auto miopen_in_strides  = problem.GetOut().GetStrides();
-        auto miopen_out_strides = problem.GetIn().GetStrides();
-        auto miopen_wei_strides = problem.GetWeights().GetStrides();
-        miopen_in_strides.insert(miopen_in_strides.begin(), C);
-        miopen_out_strides.insert(miopen_out_strides.begin(), K);
-        miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
-        std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
-        std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
-        std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());
+        // CK strides are in GNCDHW order
+        if(problem.IsLayoutNHWC())
+        {
+            // first entry reserved for G's stride
+            auto copy_strides = [](const auto& src, auto& dst) {
+                assert(dst.size() == (src.size() + 1));
+                std::copy(src.begin(), src.end(), dst.begin() + 1);
+            };
+            copy_strides(problem.GetIn().GetStrides(), in_strides);
+            copy_strides(problem.GetOut().GetStrides(), out_strides);
+            copy_strides(problem.GetWeights().GetStrides(), wei_strides);
+
+            // On a backward pass, problem.GetIn() means y (or out),
+            // and problem.GetOut() means x (or in)
+            /// \todo remove this when we stop swapping in and out tensors/descriptors
+            std::swap(in_strides, out_strides);
+
+            // Now compute G's stride
+            in_strides[0]  = C;
+            out_strides[0] = K;
+            wei_strides[0] = K * wei_strides[1];
+        }
+        else
+        {
+            assert(problem.IsLayoutDefault()); // already checked in IsApplicable
+            // for default layout, we produce packed strides for NHWC layout
+            // because we transpose to NHWC layout before calling the CK kernel
+            in_strides  = {C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
+            out_strides = {K, Ho * Wo * G * K, 1, Wo * G * K, G * K};
+            wei_strides = {K * Y * X * C, Y * X * C, 1, X * C, C};
+        }

         strides = {ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
                    ProblemInterpreter::GetAdjustedConvolutionStrideW(problem)};
@@ -281,6 +300,12 @@ bool ConvHipImplicitGemmGroupWrwXdlops::IsValidPerformanceConfig(
     return config.IsValid(problem);
 }

+size_t ConvHipImplicitGemmGroupWrwXdlops::GetWorkspaceSize(const ExecutionContext&,
+                                                           const ProblemDescription& problem) const
+{
+    return GetWorkspaceSizeLayoutTransformConv(problem);
+}
+
 PerformanceConfigHipImplicitGemmGroupWrwXdlops
 ConvHipImplicitGemmGroupWrwXdlops::Search(const ExecutionContext& ctx,
                                           const ProblemDescription& problem,
@@ -300,15 +325,16 @@ bool ConvHipImplicitGemmGroupWrwXdlops::IsApplicable(
         return false;
     if(problem.HasMixedDataTypes())
         return false;
-    if(problem.HasNonPackedTensors())
-        return false;
     if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsDirectionBackwardWrW())
         return false;
     if(!problem.Is2d())
         return false;
-    if(!problem.IsLayoutNHWC())
+    if(!(problem.IsLayoutNHWC() || problem.IsLayoutDefault()))
+        return false;
+    // needed because layout transpose kernel does not support non-packed tensors
+    if(problem.IsLayoutDefault() && problem.HasNonPackedTensors())
         return false;
     if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName()))
         return false;
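One line that deserves a sanity check is wei_strides[0] = K * wei_strides[1]. Viewing the weights as (G, K, C, Y, X) with K counted per group and packed K-Y-X-C storage inside each group, advancing by one group skips K filters of Y*X*C elements each. Illustrative numbers (made up, not from the PR):

#include <cassert>

int main()
{
    const long G = 4, K = 16, C = 8, Y = 3, X = 3; // per-group K; sizes are hypothetical
    const long k_stride = Y * X * C;               // wei_strides[1]: one filter, packed K-Y-X-C
    const long g_stride = K * k_stride;            // wei_strides[0]: one whole group of filters
    assert(g_stride == 1152);                      // 16 * 72 elements between consecutive groups
    assert(G * g_stride == G * K * Y * X * C);     // fully packed: groups tile the buffer exactly
    return 0;
}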
@@ -333,29 +359,27 @@ ConvSolution ConvHipImplicitGemmGroupWrwXdlops::GetSolution(
     [[maybe_unused]] const PerformanceConfigHipImplicitGemmGroupWrwXdlops& config) const
 {
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
-    switch(problem.GetInDataType())
-    {
-    case miopenInt8:
-        return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>(
-            problem, config.kernel_id);
-    case miopenHalf:
-        return MakeInvokerFactory,
-                                  CKArgs,
-                                  miopen::conv::WrWInvokeParams>(problem, config.kernel_id);
-    case miopenFloat:
-        return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>(
-            problem, config.kernel_id);
-    case miopenInt32:
-    case miopenBFloat16:
-    case miopenFloat8:
-    case miopenBFloat8:
-    case miopenDouble:
-    default:
-        MIOPEN_THROW(miopenStatusInternalError,
-                     "ConvHipImplicitGemmWrwXdlops operation not implemented for this data type");
-    }
-#endif
+    return MakeSolutionGroupConvImplicitGemmXdlops(
+        problem,
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryWrwNCHW<2,
+                                             DeviceOpGWrwPtrs<T>,
+                                             CKArgs,
+                                             miopen::conv::WrWInvokeParams>(
+                ctx, problem, config.kernel_id);
+        },
+        [&](auto data_type_val) {
+            using T = decltype(data_type_val);
+            return InitInvokerFactoryNHWC<DeviceOpGWrwPtrs<T>,
+                                          CKArgs,
+                                          miopen::conv::WrWInvokeParams>(
+                ctx, problem, config.kernel_id);
+        });
+
+#else
     return {};
+#endif
 }

 } // namespace conv
diff --git a/test/gpu_nchw_nhwc_transpose.cpp b/test/gpu_nchw_nhwc_transpose.cpp
index 6c0dbf5fb7..a412ab7992 100644
--- a/test/gpu_nchw_nhwc_transpose.cpp
+++ b/test/gpu_nchw_nhwc_transpose.cpp
@@ -26,7 +26,6 @@
 #include
 #include
-#include <miopen/util_sol.hpp>
 #include
 #include
 #include
diff --git a/test/gtest/group_conv.hpp b/test/gtest/group_conv.hpp
index 789fc1cab4..74b3da6d6b 100644
--- a/test/gtest/group_conv.hpp
+++ b/test/gtest/group_conv.hpp
@@ -38,19 +38,6 @@ namespace group_conv {

 using Direction = miopen::conv::Direction;

-// Works by detecting if Solver has a method named GetWorkSpaceSize
-template
-struct NeedsWorkspace
-{
-    constexpr static bool value = false;
-};
-
-template
-struct NeedsWorkspace
-{
-    constexpr static bool value = true;
-};
-
 template
 struct GroupConvTestConfig
 {
@@ -338,12 +325,9 @@ struct GroupConvTestFix
             GTEST_SKIP() << solv.SolverDbId() << "Not Applicable for this problem" << conv_config;
         }

-        if constexpr(NeedsWorkspace::value)
+        if(solv.MayNeedWorkspace())
         {
-            if(solv.MayNeedWorkspace())
-            {
-                wspace.resize(solv.GetWorkSpaceSize(ctx, problem));
-            }
+            wspace.resize(solv.GetWorkspaceSize(ctx, problem));
         }

         const auto invoke_params = InvokeParamType{tensors, wspace.ptr(), wspace.size(), false};
@@ -503,11 +487,11 @@ std::vector<miopenTensorLayout_t> GetLayoutValues()
     static_assert(NDIM == 2u || NDIM == 3u);
     if constexpr(NDIM == 2u)
     {
-        return {miopenTensorNHWC};
+        return {miopenTensorNHWC, miopenTensorNCHW};
     }
     else
     {
-        return {miopenTensorNDHWC};
+        return {miopenTensorNDHWC, miopenTensorNCDHW};
     }
 }
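The deleted NeedsWorkspace helper (its template parameter lists did not survive extraction above) was, per its own comment, a compile-time detector for a GetWorkSpaceSize member, made redundant once every solver under test exposes a uniform workspace API. A generic C++17 sketch of that detection idiom, not the verbatim original, with stand-in Ctx/Problem types:

#include <cstddef>
#include <type_traits>
#include <utility>

struct Ctx {};     // stand-ins for ExecutionContext / ProblemDescription
struct Problem {};

// Primary template: assume no workspace method.
template <typename Solver, typename = void>
struct NeedsWorkspace : std::false_type
{
};

// Specialization selected only when GetWorkSpaceSize(ctx, problem) is well-formed.
template <typename Solver>
struct NeedsWorkspace<Solver,
                      std::void_t<decltype(std::declval<const Solver&>().GetWorkSpaceSize(
                          std::declval<const Ctx&>(), std::declval<const Problem&>()))>>
    : std::true_type
{
};

// Usage: a solver with the member is detected, one without is not.
struct WithWs
{
    std::size_t GetWorkSpaceSize(const Ctx&, const Problem&) const { return 0; }
};
struct WithoutWs
{
};
static_assert(NeedsWorkspace<WithWs>::value);
static_assert(!NeedsWorkspace<WithoutWs>::value);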