From 7f7bfeaeffd0da880f0aa821cfa127f8c734c82b Mon Sep 17 00:00:00 2001 From: "Li, Tingqian" Date: Mon, 7 Nov 2022 03:43:41 +0530 Subject: [PATCH] cpu: x64: binary injector: add rhs_addr_cache_reg to forked/diverged kernels --- src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp | 2 +- src/cpu/x64/jit_gemm_convolution_utils.cpp | 2 +- src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp index 03e441c5aa9..11717077a7f 100644 --- a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp @@ -591,7 +591,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() { % (cpu_isa_traits::vlen / sizeof(float)); static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r10, r11, preserve_gpr, + helper_vmm_idx, r10, r11, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_gemm_convolution_utils.cpp b/src/cpu/x64/jit_gemm_convolution_utils.cpp index 51c8c860251..11e214e130f 100644 --- a/src/cpu/x64/jit_gemm_convolution_utils.cpp +++ b/src/cpu/x64/jit_gemm_convolution_utils.cpp @@ -66,7 +66,7 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator { static constexpr size_t tail_size = 0; static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r13, r14, preserve_gpr, + helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp index fdbb5e4a254..9b4dfb4878a 100644 --- a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp @@ -772,7 +772,7 @@ void jit_uni_fork_dw_conv_fwd_kernel_f32::generate() { % (cpu_isa_traits::vlen / sizeof(float)); static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r10, r11, preserve_gpr, + helper_vmm_idx, r10, r11, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};