add old qkv code as fallback for gpus with cc < 5.3
borg323 committed Apr 13, 2023
1 parent a6ce7b6 commit f1bb681
Showing 3 changed files with 41 additions and 4 deletions.
src/neural/cuda/layers.cc (37 additions, 2 deletions)
@@ -1772,7 +1772,7 @@ void EncoderBlock<DataType>::Eval(int N, DataType* in_out_tensor,
float factor = 1.0f / sqrt((float)depth);

// matmul_qk = tf.matmul(q, k, transpose_b=True)
-  {
+  if (use_gemm_ex_) {
if (*offset_pointers == nullptr) {
std::vector<DataType*> offsets(encoder_heads_ * max_batch_size_ * 5);
for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) {
@@ -1816,6 +1816,27 @@ void EncoderBlock<DataType>::Eval(int N, DataType* in_out_tensor,
64 /*LDC*/,
// 64 * 64 /*strideC*/,
N * encoder_heads_);
+  } else {
+    for (int i = 0; i < encoder_heads_; i++) {
+      int offset = i * depth;
+      // layout of the output: encoder_heads_ * Batch * 64 * 64
+      int outOffset = i * N * 64 * 64;
+      cublasXGemmStridedBatched<DataType>(
+          cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/,
+          depth /*K*/,  // A/B, and M/N are swapped for row-major to
+                        // col-major transform
+          factor,  // to handle "/ tf.math.sqrt(dk)"
+          mha_k + offset /*A*/,
+          d_model /*LDA*/,  // (d_model = depth * encoder_heads_) to skip
+                            // over other "depth" slices / heads
+          64 * d_model, /*strideA*/
+          mha_q + offset /*B*/,
+          d_model /*LDB*/,  // to skip over other "depth" slices / heads
+          64 * d_model, /*strideB*/
+          0.0f,
+          buffer1 + outOffset /*C*/,  // output (matmul_qk) goes to buffer1
+          64 /*LDC*/, 64 * 64 /*strideC*/, N, false);
+    }
+  }

// attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1)
@@ -1829,7 +1850,7 @@ void EncoderBlock<DataType>::Eval(int N, DataType* in_out_tensor,
(const DataType*)nullptr, stream);
}

-  {
+  if (use_gemm_ex_) {
cublasXGemmBatched<DataType>(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, 1.0f,
*offset_pointers + encoder_heads_ * max_batch_size_ *
Expand All @@ -1846,6 +1867,20 @@ void EncoderBlock<DataType>::Eval(int N, DataType* in_out_tensor,
d_model /*LDC*/,
// 64 * d_model /*strideC*/,
N * encoder_heads_);
+  } else {
+    for (int i = 0; i < encoder_heads_; i++) {
+      int offset = i * depth;  // for output and "v" matrix
+      // layout: encoder_heads_ * Batch * 64 * 64
+      int weightsOffset = i * N * 64 * 64;
+      cublasXGemmStridedBatched<DataType>(
+          cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/,
+          1.0f, mha_v + offset /*A*/,  // "v" matrix
+          d_model /*LDA*/,  // to skip over other "depth" slices / heads
+          64 * d_model, /*strideA*/
+          buffer1 + weightsOffset /*B*/, 64 /*LDB*/, 64 * 64, /*strideB*/
+          0.0f, buffer2 + offset /*C*/,  // output goes to buffer2
+          d_model /*LDC*/, 64 * d_model /*strideC*/, N, false);
+    }
+  }

// #final dense layer (mha_dense), buffer2 -> buffer1
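For context, a minimal standalone sketch of the fallback pattern the new else branches use: one strided-batched GEMM per attention head, with a depth-sized column offset selecting that head's slice of the interleaved Q/K buffers. This is an illustration in plain fp32 cuBLAS under assumed names (QKPerHead, num_heads, etc.), not lc0's templated cublasXGemmStridedBatched wrapper; the second product (attention weights times V) follows the same per-head pattern with both operands untransposed.

// Hypothetical sketch, not lc0 code: per-head QK^T via strided-batched
// GEMM, assuming row-major [batch, 64, d_model] Q/K buffers where
// d_model = num_heads * depth.
#include <cublas_v2.h>
#include <cmath>

void QKPerHead(cublasHandle_t cublas, const float* q, const float* k,
               float* out, int batch, int num_heads, int depth) {
  const int d_model = num_heads * depth;
  const float alpha = 1.0f / std::sqrt(static_cast<float>(depth));
  const float beta = 0.0f;
  for (int h = 0; h < num_heads; ++h) {
    // Head h occupies columns [h * depth, (h + 1) * depth) of Q and K, so
    // leading dimension d_model skips the other heads' slices; the output
    // is head-major: num_heads * batch * 64 * 64.
    cublasSgemmStridedBatched(
        cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, depth /*K*/,
        &alpha, k + h * depth, d_model, 64LL * d_model /*strideA*/,
        q + h * depth, d_model, 64LL * d_model /*strideB*/, &beta,
        out + static_cast<long long>(h) * batch * 64 * 64, 64 /*LDC*/,
        64 * 64 /*strideC*/, batch);
  }
}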
src/neural/cuda/network_cuda.cc (2 additions, 1 deletion)
@@ -305,7 +305,8 @@ class CudaNetwork : public Network {
use_res_block_winograd_fuse_opt_ = options.Get<bool>("res_block_fusing");
}

-    const bool use_gemm_ex = deviceProp.major >= 5;
+    const bool use_gemm_ex = (deviceProp.major > 5) ||
+                             (deviceProp.major == 5 && deviceProp.minor >= 3);

// 0. Check for SE.
has_se_ = false;
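For reference, the gate reads the device properties CUDA reports; compute capability 5.3 is where cuBLAS gains native fp16 arithmetic, which is presumably why cc 5.0/5.2 Maxwell parts now take the fallback path. A minimal sketch of an equivalent standalone check (function name hypothetical):

// Hypothetical standalone version of the check added in this commit.
#include <cuda_runtime.h>

bool HasGemmExSupport(int device_id) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_id);
  // Half-precision GemmEx paths want compute capability 5.3 or newer;
  // older devices fall back to the per-head strided-batched GEMMs above.
  return (prop.major > 5) || (prop.major == 5 && prop.minor >= 3);
}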
src/neural/cuda/network_cudnn.cc (2 additions, 1 deletion)
@@ -317,7 +317,8 @@ class CudnnNetwork : public Network {
}
}

-    const bool use_gemm_ex = deviceProp.major >= 5;
+    const bool use_gemm_ex = (deviceProp.major > 5) ||
+                             (deviceProp.major == 5 && deviceProp.minor >= 3);

// 0. Check for SE.
has_se_ = false;
