diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp
index 0b00199d6b..7812828a73 100644
--- a/src/include/miopen/conv/problem_description.hpp
+++ b/src/include/miopen/conv/problem_description.hpp
@@ -391,9 +391,9 @@ struct ProblemDescription : ProblemDescriptionBase
                GetWeightsDataType() == GetOutDataType());
     }
 
-    bool HasAtLeastOne64BitTensor() const
+    bool AllTensorsDimsFitIntoInt() const
     {
-        return in.Is64Bit() || weights.Is64Bit() || out.Is64Bit();
+        return in.AllDimsFitIntoInt() && weights.AllDimsFitIntoInt() && out.AllDimsFitIntoInt();
     }
 
     void HeuristicUpdateLayouts();
diff --git a/src/include/miopen/fusion/utils.hpp b/src/include/miopen/fusion/utils.hpp
index e11fcbfa1a..38e1d69789 100644
--- a/src/include/miopen/fusion/utils.hpp
+++ b/src/include/miopen/fusion/utils.hpp
@@ -86,7 +86,7 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
         return false;
     if(conv_problem.HasNonPackedTensors())
         return false;
-    if(conv_problem.HasAtLeastOne64BitTensor())
+    if(!conv_problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!conv_problem.IsLayoutDefault())
         return false;
diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp
index d0713470ed..0291617891 100644
--- a/src/include/miopen/tensor.hpp
+++ b/src/include/miopen/tensor.hpp
@@ -205,7 +205,7 @@ struct MIOPEN_EXPORT TensorDescriptor : miopenTensorDescriptor
     }
 
     bool IsPacked() const;
-    bool Is64Bit() const;
+    bool AllDimsFitIntoInt() const;
 
     bool operator==(const TensorDescriptor& rhs) const;
     bool operator!=(const TensorDescriptor& rhs) const;
diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp
index a3282e7f71..410161df0a 100644
--- a/src/solver/conv_MP_bidirectional_winograd.cpp
+++ b/src/solver/conv_MP_bidirectional_winograd.cpp
@@ -327,7 +327,7 @@ bool ConvMPBidirectWinograd::IsA
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsLayoutDefault())
diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp
index b337b71a04..6405e4c4c7 100644
--- a/src/solver/conv_asm_1x1u.cpp
+++ b/src/solver/conv_asm_1x1u.cpp
@@ -530,7 +530,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData()))
         return false;
diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp
index dfe525e034..698951dc28 100644
--- a/src/solver/conv_asm_1x1u_stride2.cpp
+++ b/src/solver/conv_asm_1x1u_stride2.cpp
@@ -494,7 +494,7 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp
index 07fb601a85..b10733aacd 100644
--- a/src/solver/conv_asm_3x3u.cpp
+++ b/src/solver/conv_asm_3x3u.cpp
@@ -183,7 +183,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp
index 3b229f66a4..a523caa8a6 100644
--- a/src/solver/conv_asm_5x10u2v2b1.cpp
+++ b/src/solver/conv_asm_5x10u2v2b1.cpp
@@ -50,7 +50,7 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp
index 101099c937..618b66b58c 100644
--- a/src/solver/conv_asm_5x10u2v2f1.cpp
+++ b/src/solver/conv_asm_5x10u2v2f1.cpp
@@ -51,7 +51,7 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
index 6d559bb2a8..857fe60cc6 100644
--- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
+++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
@@ -56,7 +56,7 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp
index ddcc93f90a..1ae6ea4002 100644
--- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp
+++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp
@@ -487,7 +487,7 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp
index 0e2e5ae1f6..5c8e58305d 100644
--- a/src/solver/conv_asm_dir_BwdWrW3x3.cpp
+++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp
@@ -403,7 +403,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
index c84679cd98..f97262819f 100644
--- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
+++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
@@ -149,7 +149,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.Is2d())
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
index 5579a99b22..5b785e1ec8 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
@@ -998,7 +998,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16())
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
index d82e3de0e2..1588ac0ef8 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
@@ -956,7 +956,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16() &&
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
index e2c9060244..8c900308dc 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
@@ -1523,7 +1523,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16())
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
index ce2b812eff..e60f5f4911 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
@@ -569,7 +569,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsLayoutNCHWc())
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
index 710457ffc5..d1cdcecd85 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
@@ -901,7 +901,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16() &&
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
index a86d4a54e4..806095efb3 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
@@ -881,7 +881,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16() &&
diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
index fb32408b96..f9f8e81e1c 100644
--- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
+++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
@@ -302,7 +302,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32())
@@ -348,7 +348,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ExecutionContext&
     if(!problem.Is2d())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32())
diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp
index fa4628bfe8..f6ff1d79a7 100644
--- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp
+++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp
@@ -843,7 +843,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsFp32() && !problem.IsFp16())
diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
index 1e434e6553..05d412901d 100644
--- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
+++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
@@ -326,7 +326,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp
index 0c7f669a65..73d9a10e96 100644
--- a/src/solver/conv_bin_wino3x3U.cpp
+++ b/src/solver/conv_bin_wino3x3U.cpp
@@ -70,7 +70,7 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx,
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!problem.IsLayoutDefault())
diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp
index d82f8ab740..cb0045ab16 100644
--- a/src/solver/conv_bin_winoRxS.cpp
+++ b/src/solver/conv_bin_winoRxS.cpp
@@ -225,7 +225,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
         return false;
diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
index 966a14bf95..2b80b1bb85 100644
--- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
+++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
@@ -426,7 +426,7 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx,
         return false;
     if(conv_problem.HasNonPackedTensors())
         return false;
-    if(conv_problem.HasAtLeastOne64BitTensor())
+    if(!conv_problem.AllTensorsDimsFitIntoInt())
         return false;
     if(conv_problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
index 9517956c8a..665175c39f 100644
--- a/src/solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
+++ b/src/solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
@@ -412,7 +412,7 @@ bool ConvCKIgemmFwdBiasResAddActivFused::IsApplicable(const FusionContext& ctx,
         return false;
     if(conv_problem.HasNonPackedTensors())
         return false;
-    if(conv_problem.HasAtLeastOne64BitTensor())
+    if(!conv_problem.AllTensorsDimsFitIntoInt())
         return false;
     if(conv_problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
index fd06640fcc..4b7b688c22 100644
--- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
+++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
@@ -113,7 +113,7 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
         return false;
diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp
index 3caa6efbf3..abd286509d 100644
--- a/src/solver/conv_direct_naive_conv_bwd.cpp
+++ b/src/solver/conv_direct_naive_conv_bwd.cpp
@@ -49,7 +49,7 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx,
     if(!problem.IsDirectionBackwardData())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC())
         return false;
diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp
index 9601df1d8a..8e38537be4 100644
--- a/src/solver/conv_direct_naive_conv_fwd.cpp
+++ b/src/solver/conv_direct_naive_conv_fwd.cpp
@@ -56,7 +56,7 @@ bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx,
     if(!problem.IsDirectionForward())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_direct_naive_conv_wrw.cpp b/src/solver/conv_direct_naive_conv_wrw.cpp
index 022e4f1888..95e7e75f7b 100644
--- a/src/solver/conv_direct_naive_conv_wrw.cpp
+++ b/src/solver/conv_direct_naive_conv_wrw.cpp
@@ -56,7 +56,7 @@ bool ConvDirectNaiveConvWrw::IsApplicable(const ExecutionContext& ctx,
     if(!problem.IsDirectionBackwardWrW())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
     {
diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
index d090edab19..0908d9b850 100644
--- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
@@ -336,7 +336,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable(
         return false;
     if(miopen::IsEnabled(ENV(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC)))
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
index dd71e659db..a4fdfd7e52 100644
--- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
@@ -312,7 +312,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
         return false;
     if(problem.GetConv().attribute.deterministic)
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
index 3595d06c38..aae133335e 100644
--- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
@@ -331,7 +331,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
         return false;
     if(miopen::IsEnabled(ENV(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC)))
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
index 58ebefb723..5b8f534ed8 100644
--- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
@@ -265,7 +265,7 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
index 0493c6ed51..258f805479 100644
--- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
+++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
@@ -645,7 +645,7 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsLayoutDefault())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
index 3855158680..89239684f4 100644
--- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
@@ -792,7 +792,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext&
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp
index 72b5e438b1..91a54b3106 100644
--- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp
+++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp
@@ -760,7 +760,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx,
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp
index 661bbd5eaa..b8f020c732 100644
--- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp
@@ -843,7 +843,7 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext&
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16()))
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp
index 0bbd0fb4a6..976931f793 100644
--- a/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp
@@ -303,7 +303,7 @@ bool ConvHipImplicitGemmF16F8F16BwdXdlops::IsApplicable(
         return false;
     if(problem.GetConv().attribute.deterministic)
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp
index 743b149092..8b1491a24c 100644
--- a/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp
@@ -300,7 +300,7 @@ bool ConvHipImplicitGemmF16F8F16FwdXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsTensorsCasted())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp
index fdf658beab..065267dbcc 100644
--- a/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp
@@ -300,7 +300,7 @@ bool ConvHipImplicitGemmF16F8F16WrwXdlops::IsApplicable(
         return false;
     if(problem.GetConv().attribute.deterministic)
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp
index 55c6bf8e39..13b8c2c7e5 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp
@@ -62,7 +62,7 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsFp32() && !problem.IsFp16() && !problem.IsBfp16())
         return false;
@@ -104,7 +104,7 @@ bool ConvHipImplicitGemmV4R1WrW::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(!problem.IsDirectionBackwardWrW())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!ctx.use_hip_kernels)
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp
index cb886720da..31e24fef54 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp
@@ -590,7 +590,7 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!IsComposableKernelSupportedHardware(ctx))
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp
index 77670a6d3f..2659c5ff41 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp
@@ -996,7 +996,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext&
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp
index ace8040b22..cb5ef4b1c6 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp
@@ -1064,7 +1064,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired())
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp
index aacef3faa6..f85e023417 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp
@@ -1026,7 +1026,7 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext&
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp
index 63e6d9917d..21aecc3234 100644
--- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp
@@ -266,7 +266,7 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.HasMixedDataTypes())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp
index 2c4e806112..ca6c86a537 100644
--- a/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp
@@ -306,7 +306,7 @@ bool ConvHipImplicitGemmGroupBwdXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
index bfe6186daf..fa1f2a8e21 100644
--- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
@@ -431,7 +431,7 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
index f1000a35c4..741744ad67 100644
--- a/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp
@@ -302,7 +302,7 @@ bool ConvHipImplicitGemmGroupWrwXdlops::IsApplicable(
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsDirectionBackwardWrW())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp
index 42bcce677d..251dc6814c 100644
--- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp
+++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp
@@ -591,7 +591,7 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!problem.IsLayoutDefault())
         return false;
diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp
index 18a16b01bc..1124e5dd1e 100644
--- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp
+++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp
@@ -1064,7 +1064,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx,
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16()))
diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp
index 62f66c7c70..ad019f0e40 100644
--- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp
+++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp
@@ -1136,7 +1136,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable(
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(problem.IsTensorsCasted())
diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp
index 570827ca86..43737dc9b6 100644
--- a/src/solver/conv_mlir_igemm_bwd.cpp
+++ b/src/solver/conv_mlir_igemm_bwd.cpp
@@ -52,7 +52,7 @@ bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!IsComposableKernelSupportedHardware(ctx))
         return false;
diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp
index d8cf228dda..786d540a3c 100644
--- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp
+++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp
@@ -55,7 +55,7 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
         return false;
diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp
index e0fc124323..0e6ab3a167 100644
--- a/src/solver/conv_mlir_igemm_fwd.cpp
+++ b/src/solver/conv_mlir_igemm_fwd.cpp
@@ -172,7 +172,7 @@ bool ConvMlirIgemmFwd::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!IsComposableKernelSupportedHardware(ctx))
         return false;
diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp
index 3ce11b91af..94c03ffb2c 100644
--- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp
+++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp
@@ -69,7 +69,7 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!IsComposableKernelSupportedHardware(ctx))
         return false;
diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp
index 8a6dfa74d1..9dda9561ad 100644
--- a/src/solver/conv_mlir_igemm_wrw.cpp
+++ b/src/solver/conv_mlir_igemm_wrw.cpp
@@ -55,7 +55,7 @@ bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
         return false;
diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp
index bc42b937cb..c2f3624002 100644
--- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp
+++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp
@@ -56,7 +56,7 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
         return false;
diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp
index f59fe453fa..4b005281cb 100644
--- a/src/solver/conv_multipass_wino3x3WrW.cpp
+++ b/src/solver/conv_multipass_wino3x3WrW.cpp
@@ -469,7 +469,7 @@ bool ConvWinograd3x3MultipassWrW
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16()))
         return false;
diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp
index d5f8dc697f..9371615be5 100644
--- a/src/solver/conv_ocl_dir2D11x11.cpp
+++ b/src/solver/conv_ocl_dir2D11x11.cpp
@@ -52,7 +52,7 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp
index 5cc86294db..4e301aa44b 100644
--- a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp
+++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp
@@ -65,7 +65,7 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp
index 92618edd9e..c94cb80e19 100644
--- a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp
+++ b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp
@@ -462,7 +462,7 @@ bool ConvOclBwdWrW2::IsApplicableBase(const ExecutionContext& ctx
         return false;
     if(!problem.IsDirectionBackwardWrW())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp
index fe875912ef..0c23680b50 100644
--- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp
+++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp
@@ -56,7 +56,7 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp
index 753050afe2..58d69912b4 100644
--- a/src/solver/conv_ocl_dir2Dfwd.cpp
+++ b/src/solver/conv_ocl_dir2Dfwd.cpp
@@ -53,7 +53,7 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp
index 65ad0dc50d..64d9b31086 100644
--- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp
+++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp
@@ -64,7 +64,7 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp
index b4d6f50f96..468e89c30d 100644
--- a/src/solver/conv_ocl_dir2Dfwdgen.cpp
+++ b/src/solver/conv_ocl_dir2Dfwdgen.cpp
@@ -50,7 +50,7 @@ bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx,
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
         return false;
diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp
index 228e064020..ed11fa9aaf 100644
--- a/src/solver/conv_winoRxS.cpp
+++ b/src/solver/conv_winoRxS.cpp
@@ -654,7 +654,7 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti
         return false;
     if(problem.HasNonPackedTensors())
         return false;
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     if(!(problem.IsFp32() || problem.IsFp16()))
         return false;
diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp
index 148573621e..289c9d5c86 100644
--- a/src/solver/conv_wino_fury_RxS.cpp
+++ b/src/solver/conv_wino_fury_RxS.cpp
@@ -177,7 +177,7 @@ bool ConvWinoFuryRxS::IsApplicable(const ExecutionContext&
     if(problem.HasNonPackedTensors())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     if(is2x3() && miopen::IsDisabled(ENV(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3)))
diff --git a/src/solver/fft.cpp b/src/solver/fft.cpp
index 4bf13f4b93..e32082e34f 100644
--- a/src/solver/fft.cpp
+++ b/src/solver/fft.cpp
@@ -121,7 +121,7 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr
     if(!problem.IsLayoutDefault())
         return false;
 
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
 
     const auto is_fwd = problem.IsDirectionForward();
diff --git a/src/solver/gemm.cpp b/src/solver/gemm.cpp
index f6232b91f1..7131bf5669 100644
--- a/src/solver/gemm.cpp
+++ b/src/solver/gemm.cpp
@@ -77,7 +77,7 @@ static inline bool IsAnyBufferFp16(const TensorDescriptor& xDesc,
 bool GemmFwdBase::IsApplicable(const ExecutionContext& ctx,
                                const ProblemDescription& problem) const
 {
 #if MIOPEN_USE_GEMM
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     const auto& xDesc = problem.GetIn();
diff --git a/src/solver/gemm_bwd.cpp b/src/solver/gemm_bwd.cpp
index f2adb229f4..e68a16d6cd 100644
--- a/src/solver/gemm_bwd.cpp
+++ b/src/solver/gemm_bwd.cpp
@@ -97,7 +97,7 @@ SlowdownFactor(int n_oper, const double oper_factor, const double multiple_oper_
 bool GemmBwdBase::IsApplicable(const ExecutionContext& ctx,
                                const ProblemDescription& problem) const
 {
 #if MIOPEN_USE_GEMM
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     const auto& dyDesc = problem.GetIn();
diff --git a/src/solver/gemm_wrw.cpp b/src/solver/gemm_wrw.cpp
index b6fb73296d..b07a9200d5 100644
--- a/src/solver/gemm_wrw.cpp
+++ b/src/solver/gemm_wrw.cpp
@@ -66,7 +66,7 @@ SlowdownFactor(int n_oper, const double oper_factor, const double multiple_oper_
 bool GemmWrwBase::IsApplicable(const ExecutionContext& ctx,
                                const ProblemDescription& problem) const
 {
 #if MIOPEN_USE_GEMM
-    if(problem.HasAtLeastOne64BitTensor())
+    if(!problem.AllTensorsDimsFitIntoInt())
         return false;
     const auto& dyDesc = problem.GetIn();
diff --git a/src/tensor.cpp b/src/tensor.cpp
index 0ebfc2dc8d..bf887e0c5d 100644
--- a/src/tensor.cpp
+++ b/src/tensor.cpp
@@ -430,17 +430,21 @@ std::size_t TensorDescriptor::GetNumBytes() const
 
 bool TensorDescriptor::IsPacked() const { return this->packed; }
 
-bool TensorDescriptor::Is64Bit() const
+bool TensorDescriptor::AllDimsFitIntoInt() const
 {
     if(std::any_of(lens.cbegin(), lens.cend(), [](std::size_t x) {
           return x > std::numeric_limits<int>::max();
       }))
-        return true;
+    {
+        return false;
+    }
     if(std::any_of(strides.cbegin(), strides.cend(), [](std::size_t x) {
          return x > std::numeric_limits<int>::max();
      }))
-        return true;
-    return false;
+    {
+        return false;
+    }
+    return true;
 }
 
 bool TensorDescriptor::operator==(const TensorDescriptor& rhs) const
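
Note (not part of the patch): the following is a minimal standalone C++ sketch of the predicate that TensorDescriptor::AllDimsFitIntoInt() now expresses, namely that every tensor length and stride must be representable as a signed int, so the old HasAtLeastOne64BitTensor() guard becomes the negation of the new check. The helper name all_dims_fit_into_int and the sample values are hypothetical and are not MIOpen API.

// Standalone sketch only; not MIOpen code. Helper name and sample values are hypothetical.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// True iff every value is representable as a signed int (mirrors AllDimsFitIntoInt()).
static bool all_dims_fit_into_int(const std::vector<std::size_t>& vals)
{
    return std::none_of(vals.cbegin(), vals.cend(), [](std::size_t x) {
        return x > static_cast<std::size_t>(std::numeric_limits<int>::max());
    });
}

int main()
{
    const std::vector<std::size_t> small{1, 64, 56, 56};        // all values fit into int
    const std::vector<std::size_t> huge{1, 64, 3'000'000'000};  // 3e9 exceeds INT_MAX
    std::cout << std::boolalpha
              << all_dims_fit_into_int(small) << '\n'  // true  -> solver remains applicable
              << all_dims_fit_into_int(huge) << '\n';  // false -> IsApplicable() would return false
}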