Skip to content

Commit

Permalink
[Improvements] Replace HasAtLeastOne64BitTensor() with AllTensorsDimsFitIntoInt() (#2731)
Browse files Browse the repository at this point in the history
  • Loading branch information
averinevg authored Feb 7, 2024
1 parent 365d183 commit 5c7617c
Show file tree
Hide file tree
Showing 74 changed files with 84 additions and 80 deletions.
4 changes: 2 additions & 2 deletions src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,9 @@ struct ProblemDescription : ProblemDescriptionBase
GetWeightsDataType() == GetOutDataType());
}

bool HasAtLeastOne64BitTensor() const
bool AllTensorsDimsFitIntoInt() const
{
return in.Is64Bit() || weights.Is64Bit() || out.Is64Bit();
return in.AllDimsFitIntoInt() && weights.AllDimsFitIntoInt() && out.AllDimsFitIntoInt();
}

void HeuristicUpdateLayouts();
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/fusion/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
return false;
if(conv_problem.HasNonPackedTensors())
return false;
if(conv_problem.HasAtLeastOne64BitTensor())
if(!conv_problem.AllTensorsDimsFitIntoInt())
return false;
if(!conv_problem.IsLayoutDefault())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ struct MIOPEN_EXPORT TensorDescriptor : miopenTensorDescriptor
}

bool IsPacked() const;
bool Is64Bit() const;
bool AllDimsFitIntoInt() const;

bool operator==(const TensorDescriptor& rhs) const;
bool operator!=(const TensorDescriptor& rhs) const;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_MP_bidirectional_winograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ bool ConvMPBidirectWinograd<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsA
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsLayoutDefault())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_1x1u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData()))
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_1x1u_stride2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_3x3u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_5x10u2v2b1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_5x10u2v2f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx

if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(problem.IsTensorsCasted())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_dir_BwdWrW1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_dir_BwdWrW3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.Is2d())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable(
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsLayoutNCHWc())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32())
Expand Down Expand Up @@ -348,7 +348,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ExecutionContext&
if(!problem.Is2d())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(problem.IsTensorsCasted())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_bin_wino3x3U.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx,

if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(!problem.IsLayoutDefault())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_bin_winoRxS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsTensorsCasted())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx,
return false;
if(conv_problem.HasNonPackedTensors())
return false;
if(conv_problem.HasAtLeastOne64BitTensor())
if(!conv_problem.AllTensorsDimsFitIntoInt())
return false;
if(conv_problem.HasMixedDataTypes())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ bool ConvCKIgemmFwdBiasResAddActivFused::IsApplicable(const FusionContext& ctx,
return false;
if(conv_problem.HasNonPackedTensors())
return false;
if(conv_problem.HasAtLeastOne64BitTensor())
if(!conv_problem.AllTensorsDimsFitIntoInt())
return false;
if(conv_problem.HasMixedDataTypes())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsTensorsCasted())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_direct_naive_conv_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx,

if(!problem.IsDirectionBackwardData())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_direct_naive_conv_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx,
if(!problem.IsDirectionForward())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(problem.IsTensorsCasted())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_direct_naive_conv_wrw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ bool ConvDirectNaiveConvWrw::IsApplicable(const ExecutionContext& ctx,

if(!problem.IsDirectionBackwardWrW())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.IsTensorsCasted())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable(
return false;
if(miopen::IsEnabled(ENV(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC)))
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
return false;
if(problem.GetConv().attribute.deterministic)
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
return false;
if(miopen::IsEnabled(ENV(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC)))
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!problem.IsLayoutDefault())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext&
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(problem.IsTensorsCasted())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx,
if(problem.HasNonPackedTensors())
return false;

if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;

if(problem.IsTensorsCasted())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext&
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16()))
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ bool ConvHipImplicitGemmF16F8F16BwdXdlops::IsApplicable(
return false;
if(problem.GetConv().attribute.deterministic)
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ bool ConvHipImplicitGemmF16F8F16FwdXdlops::IsApplicable(
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!problem.IsTensorsCasted())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ bool ConvHipImplicitGemmF16F8F16WrwXdlops::IsApplicable(
return false;
if(problem.GetConv().attribute.deterministic)
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(problem.HasMixedDataTypes())
return false;
Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx,
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!problem.IsFp32() && !problem.IsFp16() && !problem.IsBfp16())
return false;
Expand Down Expand Up @@ -104,7 +104,7 @@ bool ConvHipImplicitGemmV4R1WrW::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.IsDirectionBackwardWrW())
return false;
if(problem.HasAtLeastOne64BitTensor())
if(!problem.AllTensorsDimsFitIntoInt())
return false;
if(!ctx.use_hip_kernels)
return false;
Expand Down
Loading

0 comments on commit 5c7617c

Please sign in to comment.