cuda blas backward compatibility #1747

Open · wants to merge 6 commits into base: master
44 changes: 28 additions & 16 deletions src/neural/cuda/layers.cc
@@ -1423,7 +1423,7 @@ template <typename DataType>
AttentionPolicyHead<DataType>::AttentionPolicyHead(
BaseLayer<DataType>* ip, const MultiHeadWeights::PolicyHead& weights,
void* scratch, bool attention_body, ActivationFunction act,
- int max_batch_size)
+ int max_batch_size, bool use_gemm_ex)
: BaseLayer<DataType>(64 * 64 + 24 * 8, 1, 1, ip),
attention_body_(attention_body),
// Old networks without attention body (e.g. T79) use hardcoded SELU
@@ -1475,8 +1475,8 @@ AttentionPolicyHead<DataType>::AttentionPolicyHead(
nullptr, 0, // smolgen weights not implemented in
// policy encoder heads yet.
max_batch_size, ACTIVATION_SWISH, act_,
- 1e-6); // attentionbody nets don't have policy encoders, so using old
- // epsilon for backward compatibility with T78.
+ 1e-6, // attentionbody nets don't have policy encoders, so
+ use_gemm_ex); // using old epsilon for backward compatibility with T78.
encoder_weights_.emplace_back(pW);
}
}
@@ -1486,15 +1486,16 @@ EncoderBlock<DataType>::EncoderBlock(
const MultiHeadWeights::EncoderLayer& cpu_weights, void* scratch, int heads,
int size, float alpha, DataType* smolgen_global_scratch,
int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act,
- ActivationFunction ffn_act, float default_eps)
+ ActivationFunction ffn_act, float default_eps, bool use_gemm_ex)
: embedding_op_size_(size),
encoder_heads_(heads),
alpha_(alpha),
has_smolgen_(cpu_weights.mha.has_smolgen),
smolgen_activation_(smolgen_act),
ffn_activation_(ffn_act),
max_batch_size_(max_batch_size),
- default_eps_(default_eps) {
+ default_eps_(default_eps),
+ use_gemm_ex_(use_gemm_ex) {
mha_q_size_ = cpu_weights.mha.q_b.size();
mha_k_size_ = cpu_weights.mha.k_b.size();
mha_v_size_ = cpu_weights.mha.v_b.size();
@@ -1606,7 +1607,8 @@ static void cublasXGemmStridedBatched(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, float alpha, const void* A, int lda,
long long int strideA, const void* B, int ldb, long long int strideB,
- float beta, void* C, int ldc, long long int strideC, int batchCount) {
+ float beta, void* C, int ldc, long long int strideC, int batchCount,
+ bool use_gemm_ex) {
const bool fp16 = std::is_same<half, DataType>::value;
if (fp16) {
unsigned short alpha_h = FP32toFP16(alpha);
@@ -1616,10 +1618,17 @@
B, CUDA_R_16F, ldb, strideB, &beta_h, C, CUDA_R_16F, ldc, strideC,
batchCount, CUDA_R_16F, CUBLAS_GEMM_DEFAULT));
} else {
- ReportCUBLASErrors(cublasGemmStridedBatchedEx(
- handle, transa, transb, m, n, k, &alpha, A, CUDA_R_32F, lda, strideA, B,
- CUDA_R_32F, ldb, strideB, &beta, C, CUDA_R_32F, ldc, strideC,
- batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
+ if (use_gemm_ex) {
+ ReportCUBLASErrors(cublasGemmStridedBatchedEx(
+ handle, transa, transb, m, n, k, &alpha, A, CUDA_R_32F, lda, strideA,
+ B, CUDA_R_32F, ldb, strideB, &beta, C, CUDA_R_32F, ldc, strideC,
+ batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
+ } else {
+ ReportCUBLASErrors(cublasSgemmStridedBatched(
+ handle, transa, transb, m, n, k, &alpha, (const float*)A, lda,
+ strideA, (const float*)B, ldb, strideB, &beta, (float*)C, ldc,
+ strideC, batchCount));
+ }
}
}

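Note (not part of the diff): the same guard can be read in isolation for a plain, non-batched FP32 GEMM. The sketch below is illustrative only; the helper name cublasXgemmFloat is invented, ReportCUBLASErrors is the error-check helper already used in layers.cc above, and cublasGemmEx / cublasSgemm are standard cuBLAS calls. cublasGemmEx requires a GPU of compute capability 5.0 or newer, which is why the fallback exists.

static void cublasXgemmFloat(cublasHandle_t handle, cublasOperation_t transa,
                             cublasOperation_t transb, int m, int n, int k,
                             float alpha, const float* A, int lda,
                             const float* B, int ldb, float beta, float* C,
                             int ldc, bool use_gemm_ex) {
  if (use_gemm_ex) {
    // Extended GEMM entry point; supported only on compute capability >= 5.0.
    ReportCUBLASErrors(cublasGemmEx(handle, transa, transb, m, n, k, &alpha, A,
                                    CUDA_R_32F, lda, B, CUDA_R_32F, ldb, &beta,
                                    C, CUDA_R_32F, ldc, CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT));
  } else {
    // Plain SGEMM: available on every GPU generation and cuBLAS version.
    ReportCUBLASErrors(cublasSgemm(handle, transa, transb, m, n, k, &alpha, A,
                                   lda, B, ldb, &beta, C, ldc));
  }
}

The design mirrors the strided-batched wrapper above: one boolean decided at construction time, with no per-call capability queries in the hot path.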
@@ -1737,7 +1746,8 @@ void EncoderBlock<DataType>::Eval(int N, DataType* in_out_tensor,
cublasXGemmStridedBatched<DataType>(
cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f,
mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor,
- num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * max_batch, 3);
+ num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * max_batch, 3,
+ use_gemm_ex_);
addBiasBatched<DataType>(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs,
max_batch, ACTIVATION_NONE, stream);
}
@@ -1930,7 +1940,7 @@ void AttentionPolicyHead<DataType>::Eval(
cublasXGemmStridedBatched<DataType>(
cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f,
wqk_w_, num_inputs, num_inputs * num_outputs, input2_tensor, num_inputs,
- 0, 0.0f, wq, num_outputs, num_outputs * batch, 2);
+ 0, 0.0f, wq, num_outputs, num_outputs * batch, 2, use_gemm_ex_);

addBiasBatched<DataType>(wq, wq, wqk_b_, 2, batch, num_outputs,
ACTIVATION_NONE, stream);
@@ -1953,7 +1963,7 @@
wk /*A*/, policy_d_model_ /*LDA*/, 64 * policy_d_model_, /*strideA*/
wq /*B*/, policy_d_model_ /*LDB*/, 64 * policy_d_model_, /*strideB*/
0.0f, output /*C*/, // output (policy_attn_logits)
- 64 /*LDC*/, 64 * 64 + 8 * 24 /*strideC*/, N);
+ 64 /*LDC*/, 64 * 64 + 8 * 24 /*strideC*/, N, use_gemm_ex_);
}

// Compute promotion_logits in a single kernel (and put the result just after
@@ -2046,8 +2056,10 @@ AttentionBody<DataType>::AttentionBody(const MultiHeadWeights& weights,
void* scratch, Activations activations,
int num_res_blocks, int input_c,
int max_batch_size,
- bool is_pe_dense_embedding)
- : BaseLayer<DataType>(weights.ip_emb_b.size(), 8, 8, nullptr),
+ bool is_pe_dense_embedding,
+ bool use_gemm_ex)
+ : BaseLayer<DataType>(weights.ip_emb_b.size(), 8, 8, nullptr, false,
+ use_gemm_ex),
embedding_op_size_(weights.ip_emb_b.size()),
encoder_head_count_(weights.encoder_head_count),
activations_(activations),
@@ -2111,7 +2123,7 @@ AttentionBody<DataType>::AttentionBody(const MultiHeadWeights& weights,
enc, scratch, encoder_head_count_, embedding_op_size_, alpha,
smolgen_global_, smolgen_global_size_, max_batch_size,
activations_.smolgen_activation, activations_.ffn_activation,
- is_pe_dense_embedding_ ? 1e-3 : 1e-6);
+ is_pe_dense_embedding_ ? 1e-3 : 1e-6, use_gemm_ex);
encoder_weights_.emplace_back(pW);
}
}
10 changes: 7 additions & 3 deletions src/neural/cuda/layers.h
@@ -340,7 +340,7 @@ class EncoderBlock {
int heads, int size, float alpha,
DataType* smolgen_global_scratch, int smolgen_global_size,
int max_batch_size, ActivationFunction smolgen_act,
- ActivationFunction ffn_act, float default_eps);
+ ActivationFunction ffn_act, float default_eps, bool use_gemm_ex);
~EncoderBlock();

void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1,
@@ -393,6 +393,7 @@ class EncoderBlock {
int smol_global_size_;

const int max_batch_size_;
+ const bool use_gemm_ex_;
};

// The Attention policy head implementation
@@ -406,12 +407,14 @@ class AttentionPolicyHead : public BaseLayer<DataType> {
using BaseLayer<DataType>::GetC;
using BaseLayer<DataType>::GetH;
using BaseLayer<DataType>::GetW;
+ using BaseLayer<DataType>::use_gemm_ex_;

public:
AttentionPolicyHead(BaseLayer<DataType>* ip,
const MultiHeadWeights::PolicyHead& weights,
void* scratch, bool attention_body,
- ActivationFunction act, int max_batch_size);
+ ActivationFunction act, int max_batch_size,
+ bool use_gemm_ex);
~AttentionPolicyHead();
void Eval(int N, DataType* output, const DataType* input,
const DataType* input2, void* scratch, size_t scratch_size,
@@ -476,7 +479,8 @@ class AttentionBody : public BaseLayer<DataType> {
public:
AttentionBody(const MultiHeadWeights& weights, void* scratch,
Activations activations, int num_res_blocks, int input_c,
- int max_batch_size, bool is_pe_dense_embedding);
+ int max_batch_size, bool is_pe_dense_embedding,
+ bool use_gemm_ex);
~AttentionBody();
void Eval(int N, DataType* output, const DataType* input,
const DataType* input2, void* scratch, size_t scratch_size,
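For orientation, the layers.cc hunk above constructs the base class as BaseLayer<DataType>(weights.ip_emb_b.size(), 8, 8, nullptr, false, use_gemm_ex), and AttentionPolicyHead re-exports use_gemm_ex_ with a using-declaration, so BaseLayer is assumed to already carry the flag. A sketch of that assumed shape, inferred from the calls in this diff rather than copied from the real header:

template <typename DataType>
class BaseLayer {
 public:
  // Assumed overload matching the six-argument call above.
  BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc, bool use_gemm_ex);

 protected:
  bool nhwc_;               // tensor layout flag (assumed)
  const bool use_gemm_ex_;  // whether the cublas*GemmEx entry points may be used
};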
5 changes: 3 additions & 2 deletions src/neural/cuda/network_cuda.cc
@@ -457,7 +457,8 @@ class CudaNetwork : public Network {
numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_,
static_cast<InputEmbedding>(
file.format().network_format().input_embedding()) ==
- InputEmbedding::INPUT_EMBEDDING_PE_DENSE);
+ InputEmbedding::INPUT_EMBEDDING_PE_DENSE,
+ use_gemm_ex);
network_.emplace_back(std::move(attention_body));

encoder_last_ = getLastLayer();
@@ -469,7 +470,7 @@
if (attn_policy_) {
auto AttentionPolicy = std::make_unique<AttentionPolicyHead<DataType>>(
getLastLayer(), head, scratch_mem_, attn_body_, act,
- max_batch_size_);
+ max_batch_size_, use_gemm_ex);
network_.emplace_back(std::move(AttentionPolicy));

auto policymap = std::make_unique<PolicyMapLayer<DataType>>(
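The hunks above only thread an existing use_gemm_ex value into the new constructor arguments; this diff does not show where that value is computed. Since cublasGemmEx and cublasGemmStridedBatchedEx are documented to require compute capability 5.0 or higher, a plausible derivation looks like the sketch below (the function name and its placement are assumptions, not the PR's code):

#include <cuda_runtime.h>

// Decide once per device whether the cublas*GemmEx entry points are usable.
static bool AllowGemmEx(int gpu_id) {
  cudaDeviceProp device_prop = {};
  cudaGetDeviceProperties(&device_prop, gpu_id);
  // GemmEx needs compute capability >= 5.0; older boards (e.g. Kepler)
  // take the cublasSgemm* fallback added by this PR.
  return device_prop.major >= 5;
}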
2 changes: 1 addition & 1 deletion src/neural/cuda/network_cudnn.cc
@@ -527,7 +527,7 @@ class CudnnNetwork : public Network {
if (attn_policy_) {
auto AttentionPolicy = std::make_unique<AttentionPolicyHead<DataType>>(
getLastLayer(), head, scratch_mem_, false, ACTIVATION_SELU,
- max_batch_size_);
+ max_batch_size_, use_gemm_ex);
network_.emplace_back(std::move(AttentionPolicy));

auto policymap = std::make_unique<PolicyMapLayer<DataType>>(