Skip to content

Commit

Permalink
Update CK-based 2d/3d convolution solvers to support nchw/ncdhw layout (
Browse files Browse the repository at this point in the history
  • Loading branch information
amberhassaan authored Feb 8, 2024
1 parent 5c7617c commit 3112e55
Show file tree
Hide file tree
Showing 27 changed files with 1,084 additions and 395 deletions.
20 changes: 11 additions & 9 deletions src/buffer_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,31 @@ BuffInfo::BuffInfo(MemLayout_t layout, int nk, int c, int h, int w, int g, int _
}
}

MultiBufferWorkspaceTraits::MultiBufferWorkspaceTraits(std::initializer_list<size_t> v_size_,
size_t alignment_)
: v_size(v_size_), alignment(alignment_)
MultiBufferWorkspaceTraits::MultiBufferWorkspaceTraits(std::initializer_list<size_t> v_size_)
{

assert(v_size_.size() > 0);
size_t each_offset = 0;
v_offset.push_back(each_offset);
for(auto each_size : v_size)
for(auto each_size : v_size_)
{
size_t padding = (alignment - (each_size % alignment)) % alignment;
each_offset += each_size + padding;
auto padded_size = (each_size + max_padding) & (~max_padding);
each_offset += padded_size;
v_offset.push_back(each_offset);
}
}

size_t MultiBufferWorkspaceTraits::GetSize() const
{
return (&v_offset.back())[-1] + v_size.back();
assert(v_offset.size() > 1);
// last location contains the sum of all padded sizes
return v_offset.back();
}

size_t MultiBufferWorkspaceTraits::GetOffset(size_t index) const
{
if(index >= v_offset.size())
MIOPEN_THROW("index given overflows");
// last location contains the sum of all padded sizes
MIOPEN_THROW_IF(index >= (v_offset.size() - 1), "index given overflows");
return v_offset[index];
}

Expand Down
10 changes: 3 additions & 7 deletions src/conv/invokers/impl_gemm_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <miopen/handle.hpp>
#include <miopen/tensor_ops.hpp>
#include <miopen/solver/implicitgemm_util.hpp>
#include <miopen/util_sol.hpp>
#include <miopen/batched_transpose_sol.hpp>
#include <boost/any.hpp>

namespace miopen {
Expand Down Expand Up @@ -561,8 +561,6 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(
int trans_weight_idx = -1;
int trans_output_idx = -1;

constexpr size_t buf_alignment = 256;

if(is_nchw)
{
TransposeSolutionDefault2Nhwc trans_input(ctx, problem.GetInDataType(), n, c, hi, wi);
Expand Down Expand Up @@ -601,7 +599,7 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(
const size_t cast_size = need_cast ? miopen::GetTypeSize(miopenFloat) * n * k * ho * wo : 0;

MultiBufferWorkspaceTraits wt(
{trans_input_size, trans_weight_size, trans_output_size, cast_size}, buf_alignment);
{trans_input_size, trans_weight_size, trans_output_size, cast_size});

trans_input_offset = wt.GetOffset(0);
trans_weight_offset = wt.GetOffset(1);
Expand Down Expand Up @@ -879,8 +877,6 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
int trans_weight_idx = -1;
int trans_output_idx = -1;

constexpr size_t buf_alignment = 256;

if(is_nchw)
{
TransposeSolutionNhwc2Default trans_input(ctx, problem.GetOutDataType(), n, c, hi, wi);
Expand Down Expand Up @@ -919,7 +915,7 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
const size_t cast_size = need_cast ? miopen::GetTypeSize(miopenFloat) * n * c * hi * wi : 0;

MultiBufferWorkspaceTraits wt(
{trans_input_size, trans_weight_size, trans_output_size, cast_size}, buf_alignment);
{trans_input_size, trans_weight_size, trans_output_size, cast_size});

trans_input_offset = wt.GetOffset(0);
trans_weight_offset = wt.GetOffset(1);
Expand Down
61 changes: 61 additions & 0 deletions src/include/miopen/batched_transpose_sol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define GUARD_MIOPEN_BATCHED_TRANSPOSE_SOL_HPP

#include <miopen/miopen.h>
#include <miopen/errors.hpp>
#include <miopen/kernel_info.hpp>
#include <miopen/op_kernel_args.hpp>
#include <miopen/execution_context.hpp>
Expand Down Expand Up @@ -66,6 +67,66 @@ struct BatchedTransposeSolution
BatchedTransposeParam kernel_param_heuristic;
};

// Batched-transpose solution for a 4-D tensor: forwards (n_, c_, h_ * w_) to
// BatchedTransposeSolution, i.e. the two spatial extents are collapsed into one dimension.
// NOTE(review): judging by the name this is the default (NCHW) -> NHWC direction; confirm
// against the meaning of BatchedTransposeSolution's last two constructor arguments.
struct TransposeSolutionDefault2Nhwc : public BatchedTransposeSolution
{
TransposeSolutionDefault2Nhwc(const ExecutionContext& ctx_,
miopenDataType_t data_type_,
uint32_t n_,
uint32_t c_,
uint32_t h_,
uint32_t w_)
: BatchedTransposeSolution(ctx_, data_type_, n_, c_, h_ * w_)
{
// h_ * w_ is evaluated in uint32_t and wraps silently; compare it with the product
// computed in size_t and raise through MIOPEN_THROW_IF when the two differ.
// NOTE(review): the base class has already been constructed with the wrapped value by the
// time this check runs, and the comparison can never fail where size_t is 32 bits wide.
MIOPEN_THROW_IF(size_t(h_ * w_) != (size_t(h_) * size_t(w_)), "integer overflow");
}
};

// Batched-transpose solution for a 4-D tensor: forwards (n_, h_ * w_, c_) to
// BatchedTransposeSolution, i.e. the two spatial extents are collapsed into one dimension.
// NOTE(review): judging by the name this is the NHWC -> default (NCHW) direction; confirm
// against the meaning of BatchedTransposeSolution's last two constructor arguments.
struct TransposeSolutionNhwc2Default : public BatchedTransposeSolution
{
TransposeSolutionNhwc2Default(const ExecutionContext& ctx_,
miopenDataType_t data_type_,
uint32_t n_,
uint32_t c_,
uint32_t h_,
uint32_t w_)
: BatchedTransposeSolution(ctx_, data_type_, n_, h_ * w_, c_)
{
// h_ * w_ is evaluated in uint32_t and wraps silently; compare it with the product
// computed in size_t and raise through MIOPEN_THROW_IF when the two differ.
// NOTE(review): the base class has already been constructed with the wrapped value by the
// time this check runs, and the comparison can never fail where size_t is 32 bits wide.
MIOPEN_THROW_IF(size_t(h_ * w_) != (size_t(h_) * size_t(w_)), "integer overflow");
}
};

// Batched-transpose solution for a 5-D tensor: forwards (n_, c_, d_ * h_ * w_) to
// BatchedTransposeSolution, i.e. the three spatial extents are collapsed into one dimension.
// NOTE(review): judging by the name this is the default (NCDHW) -> NDHWC direction; confirm
// against the meaning of BatchedTransposeSolution's last two constructor arguments.
//
// Raises through MIOPEN_THROW_IF when d_ * h_ * w_ does not fit in 32 bits, because the
// value handed to the base class is computed in uint32_t and would otherwise wrap silently.
struct TransposeSolutionDefault2Ndhwc : public BatchedTransposeSolution
{
    TransposeSolutionDefault2Ndhwc(const ExecutionContext& ctx_,
                                   miopenDataType_t data_type_,
                                   uint32_t n_,
                                   uint32_t c_,
                                   uint32_t d_,
                                   uint32_t h_,
                                   uint32_t w_)
        : BatchedTransposeSolution(ctx_, data_type_, n_, c_, d_ * h_ * w_)
    {
        // Validate in two stages using 64-bit arithmetic. A product of two 32-bit values
        // always fits in 64 bits, so `hw` is exact; `dhw` is exact whenever `hw` itself fits
        // in 32 bits, and the short-circuit below never looks at `dhw` otherwise.
        // Comparing against a single three-way product is not sufficient: that product can
        // wrap past 64 bits (e.g. 2^22 * 2^21 * 2^21) and then agree with the wrapped
        // 32-bit value.
        const uint64_t hw  = static_cast<uint64_t>(h_) * w_;
        const uint64_t dhw = hw * d_;
        MIOPEN_THROW_IF(hw != static_cast<uint32_t>(hw) || dhw != static_cast<uint32_t>(dhw),
                        "integer overflow");
    }
};

// Batched-transpose solution for a 5-D tensor: forwards (n_, d_ * h_ * w_, c_) to
// BatchedTransposeSolution, i.e. the three spatial extents are collapsed into one dimension.
// NOTE(review): judging by the name this is the NDHWC -> default (NCDHW) direction; confirm
// against the meaning of BatchedTransposeSolution's last two constructor arguments.
//
// Raises through MIOPEN_THROW_IF when d_ * h_ * w_ does not fit in 32 bits, because the
// value handed to the base class is computed in uint32_t and would otherwise wrap silently.
struct TransposeSolutionNdhwc2Default : public BatchedTransposeSolution
{
    TransposeSolutionNdhwc2Default(const ExecutionContext& ctx_,
                                   miopenDataType_t data_type_,
                                   uint32_t n_,
                                   uint32_t c_,
                                   uint32_t d_,
                                   uint32_t h_,
                                   uint32_t w_)
        : BatchedTransposeSolution(ctx_, data_type_, n_, d_ * h_ * w_, c_)
    {
        // Validate in two stages using 64-bit arithmetic. A product of two 32-bit values
        // always fits in 64 bits, so `hw` is exact; `dhw` is exact whenever `hw` itself fits
        // in 32 bits, and the short-circuit below never looks at `dhw` otherwise.
        // Comparing against a single three-way product is not sufficient: that product can
        // wrap past 64 bits (e.g. 2^22 * 2^21 * 2^21) and then agree with the wrapped
        // 32-bit value.
        const uint64_t hw  = static_cast<uint64_t>(h_) * w_;
        const uint64_t dhw = hw * d_;
        MIOPEN_THROW_IF(hw != static_cast<uint32_t>(hw) || dhw != static_cast<uint32_t>(dhw),
                        "integer overflow");
    }
};

} // namespace miopen

#endif
8 changes: 5 additions & 3 deletions src/include/miopen/buffer_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,15 @@ struct WinogradBufferInfo

struct MultiBufferWorkspaceTraits
{
MultiBufferWorkspaceTraits(std::initializer_list<size_t> v_size_, size_t alignment_);
MultiBufferWorkspaceTraits(std::initializer_list<size_t> v_size_);
size_t GetSize() const;
size_t GetOffset(size_t index) const;

std::vector<size_t> v_size;
std::vector<size_t> v_offset;
size_t alignment;
// aligning and padding to 256 byte boundary
constexpr static size_t max_padding = 255ull;
static_assert((max_padding & (max_padding + 1)) == 0,
"max_padding should be 1 less than a power of 2");
};

} // namespace miopen
Expand Down
2 changes: 2 additions & 0 deletions src/include/miopen/conv/tensors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct ConvDataTensors
out(tensors.dx)
{
}

// Implicit (non-explicit) conversion to ConvTensors, list-initialised from this object's
// members in the order inDesc, in, wDesc, w, outDesc, out.
operator ConvTensors() const { return {inDesc, in, wDesc, w, outDesc, out}; }
};

struct ConvWrwTensors
Expand Down
12 changes: 12 additions & 0 deletions src/include/miopen/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ template <class... Params>
miopen::MIOpenThrow(__FILE__, __LINE__, __VA_ARGS__); \
} while(false)

// MIOPEN_THROW_IF(condition, msg): when `condition` evaluates to true, calls
// miopen::MIOpenThrow with the current __FILE__/__LINE__, miopenStatusInternalError and the
// text `msg` followed by ", failed condition: " and the source text of `condition`
// (stringified with #condition, so rewording the expression at a call site changes the
// reported message). `msg` must be usable to construct a std::string and is only evaluated
// when the condition holds. The do { ... } while(false) wrapper lets the macro be used as a
// single statement terminated by ';'.
#define MIOPEN_THROW_IF(condition, msg) \
do \
{ \
if((condition)) \
{ \
miopen::MIOpenThrow(__FILE__, \
__LINE__, \
miopenStatusInternalError, \
std::string(msg) + ", failed condition: " #condition); \
} \
} while(false)

#define MIOPEN_THROW_CL_STATUS(...) \
MIOPEN_THROW(miopenStatusUnknownError, miopen::OpenCLErrorMessage(__VA_ARGS__))
#define MIOPEN_THROW_HIP_STATUS(...) \
Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct MIOPEN_EXPORT Handle : miopenHandle
template <class Container>
Allocator::ManageDataPtr Write(const Container& c)
{
assert(!c.empty());
using type = typename Container::value_type;
auto buf = this->Create<type>(c.size());
return std::move(
Expand Down
24 changes: 24 additions & 0 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4545,6 +4545,10 @@ struct ConvHipImplicitGemmGroupFwdXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down Expand Up @@ -4617,6 +4621,10 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down Expand Up @@ -4694,6 +4702,10 @@ struct ConvHipImplicitGemm3DGroupWrwXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down Expand Up @@ -4771,6 +4783,10 @@ struct ConvHipImplicitGemm3DGroupBwdXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down Expand Up @@ -4847,6 +4863,10 @@ struct ConvHipImplicitGemmGroupBwdXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down Expand Up @@ -4923,6 +4943,10 @@ struct ConvHipImplicitGemmGroupWrwXdlops final
return 0.02f;
};

// Workspace requirement of this solver for the given context and problem. Only declared
// here; the definition is not visible in this header.
// NOTE(review): presumably a size in bytes — confirm against the definition.
size_t GetWorkspaceSize(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
// Unconditionally reports that this solver may need a workspace.
bool MayNeedWorkspace() const override { return true; }

private:
template <typename DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
Expand Down
Loading

0 comments on commit 3112e55

Please sign in to comment.