Skip to content

Commit

Permalink
Mha Bias added to Find 2.0 for forward pass (#3240)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vsevolod1983 authored Sep 10, 2024
1 parent 8acc858 commit 1ffcdf3
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 25 deletions.
51 changes: 26 additions & 25 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -5407,33 +5407,34 @@ typedef enum
miopenTensorMhaAmaxDK = 33,
miopenTensorMhaAmaxDV = 34,
miopenTensorMhaAmaxDS = 35,
miopenTensorMhaBias = 36,

#ifdef MIOPEN_BETA_API
miopenTensorActivationX = 36,
miopenTensorActivationY = 37,
miopenTensorActivationDX = 38,
miopenTensorActivationDY = 39,
miopenTensorBiasX = 40,
miopenTensorBiasY = 41,
miopenTensorBias = 42,
miopenTensorSoftmaxX = 43,
miopenTensorSoftmaxY = 44,
miopenTensorSoftmaxDX = 45,
miopenTensorSoftmaxDY = 46,
miopenTensorBatchnormX = 47,
miopenTensorBatchnormY = 48,
miopenTensorBatchnormRunningMean = 49,
miopenTensorBatchnormRunningVariance = 50,
miopenTensorBatchnormSavedMean = 51,
miopenTensorBatchnormSavedVariance = 52,
miopenTensorBatchnormScale = 53,
miopenTensorBatchnormScaleDiff = 54,
miopenTensorBatchnormEstimatedMean = 55,
miopenTensorBatchnormEstimatedVariance = 56,
miopenTensorBatchnormBias = 57,
miopenTensorBatchnormBiasDiff = 58,
miopenTensorBatchnormDX = 59,
miopenTensorBatchnormDY = 60,
miopenTensorActivationX = 37,
miopenTensorActivationY = 38,
miopenTensorActivationDX = 39,
miopenTensorActivationDY = 40,
miopenTensorBiasX = 41,
miopenTensorBiasY = 42,
miopenTensorBias = 43,
miopenTensorSoftmaxX = 44,
miopenTensorSoftmaxY = 45,
miopenTensorSoftmaxDX = 46,
miopenTensorSoftmaxDY = 47,
miopenTensorBatchnormX = 48,
miopenTensorBatchnormY = 49,
miopenTensorBatchnormRunningMean = 50,
miopenTensorBatchnormRunningVariance = 51,
miopenTensorBatchnormSavedMean = 52,
miopenTensorBatchnormSavedVariance = 53,
miopenTensorBatchnormScale = 54,
miopenTensorBatchnormScaleDiff = 55,
miopenTensorBatchnormEstimatedMean = 56,
miopenTensorBatchnormEstimatedVariance = 57,
miopenTensorBatchnormBias = 58,
miopenTensorBatchnormBiasDiff = 59,
miopenTensorBatchnormDX = 60,
miopenTensorBatchnormDY = 61,
#endif

miopenTensorArgumentIsScalar = 1U << 31,
Expand Down
1 change: 1 addition & 0 deletions src/api/find2_0_commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ inline std::ostream& operator<<(std::ostream& stream, const miopenTensorArgument
case miopenTensorMhaAmaxDK: stream << "miopenTensorMhaAmaxDK"; break;
case miopenTensorMhaAmaxDV: stream << "miopenTensorMhaAmaxDV"; break;
case miopenTensorMhaAmaxDS: stream << "miopenTensorMhaAmaxDS"; break;
case miopenTensorMhaBias: stream << "miopenTensorMhaBias"; break;
case miopenTensorSoftmaxX: stream << "SoftmaxX"; break;
case miopenTensorSoftmaxY: stream << "SoftmaxY"; break;
case miopenTensorSoftmaxDX: stream << "SoftmaxDX"; break;
Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/graphapi/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ inline std::string_view tensorEnumIdToStr(miopenTensorArgumentId_t id)
ENUM_CASE(miopenTensorMhaAmaxDK)
ENUM_CASE(miopenTensorMhaAmaxDV)
ENUM_CASE(miopenTensorMhaAmaxDS)
ENUM_CASE(miopenTensorMhaBias)
default: MIOPEN_THROW(miopenStatusInternalError, "unknown tensor enum id");
}
#undef ENUM_CASE
Expand Down
4 changes: 4 additions & 0 deletions src/include/miopen/mha/mha.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ struct MhaInputDescsForward
TensorDescriptor dropoutSeedDesc;
TensorDescriptor dropoutOffsetDesc;

TensorDescriptor biasDesc;

// output tensors
TensorDescriptor oDesc;
TensorDescriptor amaxODesc;
Expand Down Expand Up @@ -129,6 +131,8 @@ struct MhaDataForward
ConstData_t dropoutSeedData;
ConstData_t dropoutOffsetData;

ConstData_t biasData;

// output tensors
Data_t oData;
Data_t amaxOData;
Expand Down
1 change: 1 addition & 0 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ mha::ProblemDescription Problem::AsMha() const
dpDesc,
dsDesc,
doffDesc,
GetTensorDescriptor(miopenTensorMhaBias, TensorDescriptor()),
oDesc,
GetTensorDescriptorChecked(miopenTensorMhaAmaxO, "miopenTensorMhaAmaxO"),
GetTensorDescriptorChecked(miopenTensorMhaAmaxS, "miopenTensorMhaAmaxS"),
Expand Down
9 changes: 9 additions & 0 deletions src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,14 @@ void Solution::RunImpl(Handle& handle,
auto dropoutOffset =
get_input_checked(miopenTensorMhaDropoutOffset, "miopenTensorMhaDropoutOffset");

// reading bias buffer as an optional parameter
Data_t biasBuffer = nullptr;
const auto& found = inputs.find(miopenTensorMhaBias);
if(found != inputs.end())
{
biasBuffer = found->second.buffer;
}

const auto invoke_ctx = [&]() -> AnyInvokeParams {
switch(problem_casted.GetDirection())
{
Expand All @@ -322,6 +330,7 @@ void Solution::RunImpl(Handle& handle,
dropoutProbability.buffer,
dropoutSeed.buffer,
dropoutOffset.buffer,
biasBuffer,
o.buffer,
amaxO.buffer,
amaxS.buffer,
Expand Down
6 changes: 6 additions & 0 deletions test/gtest/mha_find20.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class MhaFind20Test
CreateTensor(miopenTensorMhaDropoutSeed).InitWithInt64Value(0);
CreateTensor(miopenTensorMhaDropoutOffset).InitWithInt64Value(0);

CreateTensor(miopenTensorMhaBias, test_n, test_h, test_s, test_s).InitWithRandom();

if(isForward)
{
CreateTensor(miopenTensorMhaQ, test_n, test_h, test_s, test_d).InitWithRandom();
Expand Down Expand Up @@ -423,6 +425,8 @@ class MhaFind20Test
const auto& mhads = tensors[miopenTensorMhaDropoutSeed];
const auto& mhado = tensors[miopenTensorMhaDropoutOffset];

const auto& mhabias = tensors[miopenTensorMhaBias];

mha::MhaInputDescsForward inputDescs = {
mhaK->GetTensorDescriptor(),
mhaQ->GetTensorDescriptor(),
Expand All @@ -437,6 +441,7 @@ class MhaFind20Test
mhadp->GetTensorDescriptor(),
mhads->GetTensorDescriptor(),
mhado->GetTensorDescriptor(),
mhabias->GetTensorDescriptor(),
tensors[miopenTensorMhaO]->GetTensorDescriptor(),
tensors[miopenTensorMhaAmaxO]->GetTensorDescriptor(),
tensors[miopenTensorMhaAmaxS]->GetTensorDescriptor(),
Expand All @@ -459,6 +464,7 @@ class MhaFind20Test
mhadp->gpuBuffer.get(),
mhads->gpuBuffer.get(),
mhado->gpuBuffer.get(),
mhabias->gpuBuffer.get(),
outputResultsMap[miopenTensorMhaO]->gpuBuffer.get(),
outputResultsMap[miopenTensorMhaAmaxO]->gpuBuffer.get(),
outputResultsMap[miopenTensorMhaAmaxS]->gpuBuffer.get(),
Expand Down

0 comments on commit 1ffcdf3

Please sign in to comment.