Skip to content

Commit

Permalink
Initial porting of Op4dTensorGeneric (#3404)
Browse files Browse the repository at this point in the history
* initial changes

* fixing parameters for Op4dTensorGeneric kernel

* minor changes

* minor changes for hip tidy

---------

Co-authored-by: Alex Eremin <[email protected]>
  • Loading branch information
novakovicdj and CAHEK7 authored Dec 31, 2024
1 parent 75ea67e commit 532bf75
Show file tree
Hide file tree
Showing 2 changed files with 709 additions and 1 deletion.
117 changes: 116 additions & 1 deletion src/kernels/MIOpenTensorKernelsHip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,119 @@ extern "C" __global__ void Op3dTensorGeneric(const MIOPEN_TYPE* a,
}
}

#endif
#endif

#ifdef USE_4D_TENSOR_GENERIC
// NCHW
extern "C" __global__ void Op4dTensorGeneric(MIOPEN_TYPE* a,
const int a_nstride,
const int a_cstride,
const int a_hstride,
MIOPEN_TYPE* b,
const int b_c,
const int b_h,
const int b_w,
const int b_nstride,
const int b_cstride,
const int b_hstride,
MIOPEN_TYPE* c,
const int c_c,
const int c_h,
const int c_w,
const int c_nstride,
const int c_cstride,
const int c_hstride,
const MIOPEN_TYPE alpha0,
const MIOPEN_TYPE alpha1,
const MIOPEN_TYPE beta,
const unsigned int bitmap,
const int work_per_wg,
const long Aoffset,
const long Boffset,
const long Coffset,
const int num_wg)
{
int gid = blockIdx.x;

MIOPEN_TYPE* a_off = a + Aoffset;
MIOPEN_TYPE* b_off = b + Boffset;
MIOPEN_TYPE* c_off = c + Coffset;

// MIOPEN_TYPE operand = b[gid + Boffset];
// num_wg: the number of workgroups should be launched
// MAX_NUM_WG: the maximum number of workgroups actually launched
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
if(beta == static_cast<MIOPEN_TYPE>(0))
#pragma clang diagnostic pop
{
for(; gid < num_wg; gid += MAX_NUM_WG)
{
int lid = threadIdx.x;

int o_h_div = (bitmap & (1 << 0)) ? 1 : c_w;
int o_c_div = o_h_div * ((bitmap & (1 << 1)) ? 1 : c_h);
int o_n_div = o_c_div * ((bitmap & (1 << 2)) ? 1 : c_c);

int o_w_gid_off = gid % b_w;
int o_h_gid_off = (gid / b_w) % b_h;
int o_c_gid_off = (gid / b_w / b_h) % b_c;
int o_n_gid_off = (gid / b_w / b_h) / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off * b_cstride +
o_h_gid_off * b_hstride + o_w_gid_off;
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
int o_w = (bitmap & (1 << 0)) ? o_w_gid_off : lid % c_w;
int o_h = (bitmap & (1 << 1)) ? o_h_gid_off : (lid / o_h_div) % c_h;
int o_c = (bitmap & (1 << 2)) ? o_c_gid_off : (lid / o_c_div) % c_c;
int o_n = (bitmap & (1 << 3)) ? o_n_gid_off : lid / o_n_div;

int aindex = o_n * a_nstride + o_c * a_cstride + o_h * a_hstride + o_w;
int cindex = o_n * c_nstride + o_c * c_cstride + o_h * c_hstride + o_w;
c_off[cindex] = MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand);

lid += blockDim.x;
}
}
}
else
{
for(; gid < num_wg; gid += MAX_NUM_WG)
{
int lid = threadIdx.x;

int o_h_div = (bitmap & (1 << 0)) ? 1 : c_w;
int o_c_div = o_h_div * ((bitmap & (1 << 1)) ? 1 : c_h);
int o_n_div = o_c_div * ((bitmap & (1 << 2)) ? 1 : c_c);

int o_w_gid_off = gid % b_w;
int o_h_gid_off = (gid / b_w) % b_h;
int o_c_gid_off = (gid / b_w / b_h) % b_c;
int o_n_gid_off = (gid / b_w / b_h) / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off * b_cstride +
o_h_gid_off * b_hstride + o_w_gid_off;
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
int o_w = (bitmap & (1 << 0)) ? o_w_gid_off : lid % c_w;
int o_h = (bitmap & (1 << 1)) ? o_h_gid_off : (lid / o_h_div) % c_h;
int o_c = (bitmap & (1 << 2)) ? o_c_gid_off : (lid / o_c_div) % c_c;
int o_n = (bitmap & (1 << 3)) ? o_n_gid_off : lid / o_n_div;

int aindex = o_n * a_nstride + o_c * a_cstride + o_h * a_hstride + o_w;
int cindex = o_n * c_nstride + o_c * c_cstride + o_h * c_hstride + o_w;
c_off[cindex] =
MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand) + beta * c_off[cindex];

lid += blockDim.x;
}
}
}
}

#endif
Loading

0 comments on commit 532bf75

Please sign in to comment.