diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp
index 2199d88b01..efb0d16b96 100644
--- a/src/include/miopen/solver/implicitgemm_ck_util.hpp
+++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 
 #if MIOPEN_USE_COMPOSABLEKERNEL
 #include
@@ -96,6 +97,55 @@ bool IsCKApplicable(const ProblemDescriptionType& problem)
         ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); });
 }
+#define WORKAROUND_CK_ISSUE_1184 1
+#if WORKAROUND_CK_ISSUE_1184
+/// Scoped HIP-event timer used in place of the CK invoker's own timing.
+///
+/// When profiling is enabled on the handle, construction records a start event
+/// on the handle's stream; destruction records a stop event, waits for it, then
+/// resets the handle's kernel time and accumulates the elapsed event time.
+/// When profiling is disabled no event is recorded and the stream is never
+/// synchronized, so the handle's kernel time is left untouched.
+struct HipEventProfiler
+{
+    const Handle& handle;
+    float event_time;
+    HipEventPtr start;
+    HipEventPtr stop;
+    /// Profiling state sampled once at construction, so the constructor and the
+    /// destructor always agree on whether a start event was recorded.
+    const bool enabled;
+
+    HipEventProfiler(const Handle& handle_)
+        : handle(handle_),
+          event_time(0.0f),
+          start(make_hip_event()),
+          stop(make_hip_event()),
+          enabled(handle_.IsProfilingEnabled())
+    {
+        if(enabled)
+        {
+            hipEventRecord(start.get(), handle.GetStream());
+        }
+    }
+    ~HipEventProfiler()
+    {
+        // Timing unconditionally would block the host in hipEventSynchronize()
+        // on every invocation and overwrite the kernel time even though
+        // profiling was never requested.
+        if(!enabled)
+        {
+            return;
+        }
+        hipEventRecord(stop.get(), handle.GetStream());
+        hipEventSynchronize(stop.get());
+        hipEventElapsedTime(&event_time, start.get(), stop.get());
+        handle.ResetKernelTime();
+        handle.AccumKernelTime(event_time);
+    }
+};
+#endif
+
 template ();
             auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, data_ctx.tensors);
             auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer();
-
-            const auto enable_profiling = handle.IsProfilingEnabled();
-            float elapsed_time =
-                invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
-            if(enable_profiling)
             {
-                handle.ResetKernelTime();
-                handle.AccumKernelTime(elapsed_time);
+                HipEventProfiler pfr(handle);
+                if constexpr(std::is_same::value)
+                {
+                    auto zero = 0.0f;
+                    const auto& tensors = data_ctx.tensors;
+                    SetTensor(handle, tensors.dwDesc, tensors.dw, &zero);
+                }
+                invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), false});
             }
         };
     };
@@ -605,15 +656,14 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx,
                 std::swap(conv_tensors.x, conv_tensors.y);
                 std::swap(conv_tensors.xDesc, conv_tensors.yDesc);
             }
-
+            HipEventProfiler pfr(handle);
             input1_tr_inst.ConvertFrom(handle, kernels, conv_tensors);
 
             input2_tr_inst.ConvertFrom(handle, kernels, conv_tensors);
 
             output_init_tr_inst.ConvertFrom(handle, kernels, conv_tensors);
 
-            /// \todo: Fix NHWC Wrw invokers to also issue a zero-out kernel. Will
-            /// need SetTensor() to properly zero out non-packed tensors
+            /// \todo: Will need SetTensor() to properly zero out non-packed tensors
             if(output_tr_inst.GetConvOperandTag() == internal::ConvOperandTag::Weights)
             {
                 output_tr_inst.ZeroOutBuffer();
@@ -632,15 +682,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx,
                                                    tr_ptrs[0]->GetBufferPtr(),
                                                    tr_ptrs[1]->GetBufferPtr(),
                                                    tr_ptrs[2]->GetBufferPtr());
-            float conv_time = 0;
-            conv_time += invoker_ptr->Run(argument_ptr.get(),
-                                          {handle.GetStream(), handle.IsProfilingEnabled()});
-
-            if(handle.IsProfilingEnabled())
-            {
-                handle.AccumKernelTime(conv_time);
-            }
-
+            invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), false});
             output_tr_inst.ConvertTo(handle, kernels, conv_tensors);
         };
     };