diff --git a/libs/lczero-common b/libs/lczero-common index 55e1b382ef..e05fb7a505 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 55e1b382efadd57903e37f2a2e29caef3ea85799 +Subproject commit e05fb7a505554682acc8a197eb797c26b6db161d diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 395bab8d84..7b0bb9f277 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -37,6 +37,7 @@ namespace lczero { namespace cudnn_backend { namespace { constexpr int kInputPlanes = 112; +constexpr int kNumRpeKernelSplits = 2; } // namespace ///////////////////////////////////////////////////////////////////////////// @@ -1242,7 +1243,8 @@ __global__ void preprocess_for_attention_body_kernel( if (c >= input_size) { // concatenate from position encoding array if (is_pe_dense_embedding) { - op = (T)(encoding[n * 64 * encoding_size + hw * encoding_size + (c - input_size)]); + op = (T)(encoding[n * 64 * encoding_size + hw * encoding_size + + (c - input_size)]); } else { op = (T)(encoding[64 * hw + (c - input_size)]); } @@ -1309,6 +1311,436 @@ void applyInputGating(T* output, const T* input, const T* mult, const T* add, ReportCUDAErrors(cudaGetLastError()); } +// Get the index corresponding to position (i,j,k,l) in a tensor +// with dimensions (I,J,K,L) where the innermost dimension is L. +__device__ __forceinline__ int getTensorIndex(int i, int j, int k, int l, int I, + int J, int K, int L) { + if (i >= I || j >= J || k >= K || l >= L) return -1; + + return (((((i * J) + j) * K) + k) * L) + l; + + // int index; + // index = i; + // index *= J; + // index += j; + // index *= K; + // index += k; + // index *= L; + // index += l; + + // return index; +} + +template +__device__ __forceinline__ dT readInputTensor(const sT* input_tensor, size_t i, + size_t j, size_t k, size_t l, + size_t I, size_t J, size_t K, + size_t L) { + // i is the outermost|slowest|most-significant index, while l is the + // innermost|fastest|least-significant index. + if (i >= I || j >= J || k >= K || l >= L) return 0; + + // int index; + // index = i; + // index *= J; + // index += j; + // index *= K; + // index += k; + // index *= L; + // index += l; + // int index = getTensorIndex(i, j, k, l, I, J, K, L); + size_t index = (((((i * J) + j) * K) + k) * L) + l; + + return (dT)(input_tensor[index]); +} + +__device__ __forceinline__ float sharedDotProductSum(float val, int x, int y) { + // Sum is done along the x-axis, while y is the accumulator. + int warpPos = y % 32; + __shared__ float partialSum[32]; + if (x == 0) partialSum[warpPos] = 0.0f; + + __syncthreads(); + + // Get warp-wide sum. + float warpSum = warpReduce(val); + if (x == 0) atomicAdd(&partialSum[warpPos], warpSum); + + __syncthreads(); + + return partialSum[warpPos]; +} + +template +__global__ void rpeVectorMultiply_parallel_kernel( + const T* rpeInput, const T* rpeWeights, const T* skipAdd, T* output, int B, + int H, int Q, int K, int D, float outScale, size_t rpetype) { + int x = threadIdx.x; + int y = threadIdx.y + blockDim.y * blockIdx.y; + int z = threadIdx.z + blockDim.z * blockIdx.z; + int h = z % H; + z = z / H; + int q = z % Q; + z = z / Q; + int b = z % B; + + if (rpetype == 0) { + // RPE-Q + // rpeInput: [B, Q, H, D] -> transpose to [B, H, Q, (1, D)] + // rpeWeights: [H, Q, K, D] -> transpose to [1, H, Q, (D, K)] + // output: [B, H, Q, K] + + // Read tensors per the input layouts and write out per the output layout. 
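+    // Illustrative example (dimensions assumed here, not taken from a real
+    // net): with B=1, H=12, Q=K=64, D=32, the thread at b=0, h=2, q=5, k=7,
+    // d=3 reads rpeInput[getTensorIndex(0, 5, 2, 3, 1, 64, 12, 32)] and
+    // rpeWeights[getTensorIndex(2, 5, 7, 3, 12, 64, 64, 32)]; the products
+    // over d are reduced in shared memory into output[b, h, q, k].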
+ // Sum is along the D dimension, and K is on the x-axis. + // Each thread handles one product of the sum. Thread 0 sums the products. + int d = x; + int k = y; + if (b >= B || h >= H || q >= Q || k >= K || d >= D) return; + + const int tensorIndex = getTensorIndex(b, q, h, d, B, Q, H, D); + const int weightIndex = getTensorIndex(h, q, k, d, H, Q, K, D); + + T sum = (T)sharedDotProductSum( + (float)rpeInput[tensorIndex] * (float)rpeWeights[weightIndex], x, y); + + if (d == 0) { + int outIdx = getTensorIndex(b, h, q, k, B, H, Q, K); + output[outIdx] = (sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } else if (rpetype == 1) { + // RPE-K + // rpeInput: [B, K, H, D] -> transpose to [B, H, K, (1, D)] + // rpeWeights: [H, K, Q, D] -> transpose to [1, H, K, (D, Q)] + // output: [B, H, Q, K] + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the D dimension, and K is on the x-axis. + // Each thread handles one product of the sum. Thread 0 sums the products. + int d = x; + int k = y; + if (b >= B || h >= H || q >= Q || k >= K || d >= D) return; + + const int tensorIndex = getTensorIndex(b, k, h, d, B, K, H, D); + const int weightIndex = getTensorIndex(h, k, q, d, H, K, Q, D); + + T sum = (T)sharedDotProductSum( + (float)rpeInput[tensorIndex] * (float)rpeWeights[weightIndex], x, y); + + if (d == 0) { + int outIdx = getTensorIndex(b, h, q, k, B, H, Q, K); + output[outIdx] = (sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } else if (rpetype == 2) { + // RPE-V + // rpeInput: [B, H, Q, K] -> transpose to [B, H, Q, (1, K)] + // rpeWeights: [H, Q, D, K] -> transpose to [1, H, Q, (K, D)] + // output: [B, Q, H, D] + // The skip connection is also already in BQHD order. + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the K dimension, and D is on the x-axis. + int k = x; + int d = y; + if (b >= B || h >= H || q >= Q || k >= K || d >= D) return; + + const int tensorIndex = getTensorIndex(b, h, q, k, B, H, Q, K); + const int weightIndex = getTensorIndex(h, q, d, k, H, Q, D, K); + + T sum = (T)sharedDotProductSum( + (float)rpeInput[tensorIndex] * (float)rpeWeights[weightIndex], x, y); + + if (k == 0) { + int outIdx = getTensorIndex(b, q, h, d, B, Q, H, D); + output[outIdx] = (sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } +} + +template +__device__ __forceinline__ T dotProductSum(int x, const T* U, const T* V, + int length, bool fp16) { + assert(length >= 16); + T sum = 0; + int sublen = length / kNumRpeKernelSplits; + int lane = x & (kNumRpeKernelSplits - 1); + int start = lane * sublen; + + // Load from memory (16 elements a time) + if (fp16) { + half u[8]; + half v[8]; +#pragma unroll + for (int h = start; h < start + sublen; h += 8) { + copyAs(&u[0], &U[h]); + copyAs(&v[0], &V[h]); +#pragma unroll + for (int i = 0; i < 8; i++) { + sum += (T)u[i] * (T)v[i]; + } + } + } else { + float u[4]; + float v[4]; +#pragma unroll + for (int h = start; h < start + sublen; h += 4) { + copyAs(&u[0], &U[h]); + copyAs(&v[0], &V[h]); +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += (T)u[i] * (T)v[i]; + } + } + } + + // Warp-level reduction to sum up adjacent threads. 
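+  // With kNumRpeKernelSplits == 2 the even lane holds the first half of the
+  // dot product and the odd lane the second half; a single __shfl_down_sync
+  // with offset 1 folds the odd lane's partial sum into the even lane, and
+  // the caller writes the result from lane 0.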
+ __syncwarp(); +#pragma unroll + for (int i = 1; i < kNumRpeKernelSplits; i = i << 1) { + sum += __shfl_down_sync(0xffffffff, sum, i); + } + return sum; +} + +template +__global__ void rpeVectorMultiply_kernel(const T* rpeInput, const T* rpeWeights, + const T* skipAdd, T* output, int B, + int H, int Q, int K, int D, + float outScale, size_t rpetype) { + const int x = threadIdx.x + blockDim.x * blockIdx.x; + const int q = threadIdx.y + blockDim.y * blockIdx.y; + const int bh = threadIdx.z + blockDim.z * blockIdx.z; + const int h = bh % H; + const int b = bh / H; + const bool fp16 = std::is_same::value; + const int lane = x & (kNumRpeKernelSplits - 1); + + if (rpetype == 0) { + // RPE-Q + // rpeInput: [B, Q, H, D] -> transpose to [B, H, Q, (1, D)] + // rpeWeights: [H, Q, K, D] -> transpose to [1, H, Q, (D, K)] + // output: [B, H, Q, K] + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the D dimension, and K is on the x-axis. + int k = x / kNumRpeKernelSplits; + if (b >= B || h >= H || q >= Q || k >= K) return; + + const int tensorIndex = getTensorIndex(b, q, h, 0, B, Q, H, D); + const int weightIndex = getTensorIndex(h, q, k, 0, H, Q, K, D); + + T sum = dotProductSum(x, rpeInput + tensorIndex, rpeWeights + weightIndex, + D, fp16); + + if (lane == 0) { + int outIdx = getTensorIndex(b, h, q, k, B, H, Q, K); + output[outIdx] = ((T)sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } else if (rpetype == 1) { + // RPE-K + // rpeInput: [B, K, H, D] -> transpose to [B, H, K, (1, D)] + // rpeWeights: [H, K, Q, D] -> transpose to [1, H, K, (D, Q)] + // output: [B, H, Q, K] + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the D dimension, and K is on the x-axis. + int k = x / kNumRpeKernelSplits; + if (b >= B || h >= H || q >= Q || k >= K) return; + + const int tensorIndex = getTensorIndex(b, k, h, 0, B, K, H, D); + const int weightIndex = getTensorIndex(h, k, q, 0, H, K, Q, D); + T sum = dotProductSum(x, rpeInput + tensorIndex, rpeWeights + weightIndex, + D, fp16); + if (lane == 0) { + int outIdx = getTensorIndex(b, h, q, k, B, H, Q, K); + output[outIdx] = ((T)sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } else if (rpetype == 2) { + // RPE-V + // rpeInput: [B, H, Q, K] -> transpose to [B, H, Q, (1, K)] + // rpeWeights: [H, Q, D, K] -> transpose to [1, H, Q, (K, D)] + // output: [B, Q, H, D] + // The skip connection is also already in BQHD order. + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the K dimension, and D is on the x-axis. 
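+    // Each output element along D is produced by kNumRpeKernelSplits
+    // adjacent threads: x / kNumRpeKernelSplits selects d, while the low
+    // bits of x select which slice of the K-length dot product this thread
+    // accumulates.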
+ int d = x / kNumRpeKernelSplits; + if (b >= B || h >= H || q >= Q || d >= D) return; + + const int tensorIndex = getTensorIndex(b, h, q, 0, B, H, Q, K); + const int weightIndex = getTensorIndex(h, q, d, 0, H, Q, D, K); + T sum = dotProductSum(x, rpeInput + tensorIndex, rpeWeights + weightIndex, + K, fp16); + if (lane == 0) { + int outIdx = getTensorIndex(b, q, h, d, B, Q, H, D); + output[outIdx] = ((T)sum + (T)skipAdd[outIdx]) * (T)outScale; + } + } +} + +template +void multiplyRPEAttentionLogits(const T* rpeInput, const T* rpeWeights, + const T* attnInput, T* output, int B, int H, + int Q, int K, int D, float outScale, + size_t rpetype, cudaStream_t stream) { + if (rpetype > 2) { + throw Exception("unsupported rpetype in multiplyRPEAttentionLogits."); + } + + // rpeType: Q | K | V + // rpeInput: [B, Q, H, D] | [B, K, H, D] | [B, H, Q, K] + // rpeWeights: [H, Q, K, D] | [H, K, Q, D] | [H, Q, D, K] + // attnInput: [B, H, Q, K] | [B, H, Q, K] | [B, Q, H, D] + + // int lda = (rpetype == 2) ? K : D; + // if (lda > 64 || rpetype != 0) { +#if 1 + // 3D block structure where x-axis maps to K (or D for rpetype==2), y-axis to + // Q and z-axis to BH. Each thread calculates the vector sum-product for the + // cube at BHQK | BHQD (i.e. "d,dk->k" | "d,dq->q" | "k,kd->d"). + dim3 blockDim, gridDim; + int X = rpetype == 2 ? D : K; + blockDim.x = std::min(32, X); + blockDim.y = std::min(16, Q); + blockDim.z = std::min(std::max(512 / (blockDim.x * blockDim.y), 1u), + (unsigned int)(B * H)); + gridDim.x = DivUp(X * kNumRpeKernelSplits, blockDim.x); + gridDim.y = DivUp(Q, blockDim.y); + gridDim.z = DivUp(B * H, blockDim.z); + + rpeVectorMultiply_kernel<<>>( + rpeInput, rpeWeights, attnInput, output, B, H, Q, K, D, outScale, + rpetype); +#else + // } else { + // 3D block structure where x-axis maps to D (or K for rpetype==2), y-axis + // maps to K (or D for rpetype==2) and z-axis to BQH. Each thread calculates + // the product while the warp-sum does the sum of the vector dot-product. + dim3 blockDim, gridDim; + int Y = rpetype == 2 ? D : K; + blockDim.x = std::min(64, lda); + blockDim.y = std::min(DivUp(1024, blockDim.x), Y); + blockDim.z = std::max(1024 / (blockDim.x * blockDim.y), 1u); + gridDim.x = 1; + gridDim.y = DivUp(Y, blockDim.y); + gridDim.z = DivUp(B * H * Q, blockDim.z); + + rpeVectorMultiply_parallel_kernel<<>>( + rpeInput, rpeWeights, attnInput, output, B, H, Q, K, D, outScale, + rpetype); +#endif + // } + ReportCUDAErrors(cudaGetLastError()); +} + +template +__global__ void rpeQK_multiply_kernel(const T* rpeInputQ, const T* rpeWeightsQ, + const T* rpeInputK, const T* rpeWeightsK, + const T* skipAdd, T* output, int B, int H, + int Q, int K, int D, float outScale) { + // Fused version of rpeVectorMultiply_kernel for RPE-Q and RPE-K. + const int x = threadIdx.x + blockDim.x * blockIdx.x; + const int q = threadIdx.y + blockDim.y * blockIdx.y; + const int bh = threadIdx.z + blockDim.z * blockIdx.z; + const int h = bh % H; + const int b = bh / H; + const bool fp16 = std::is_same::value; + const int lane = x & (kNumRpeKernelSplits - 1); + + // Read tensors per the input layouts and write out per the output layout. + // Sum is along the D dimension, and K is on the x-axis. + int k = x / kNumRpeKernelSplits; + if (b >= B || h >= H || q >= Q || k >= K) return; + + // RPE-Q sum + const int tidxq = getTensorIndex(b, q, h, 0, B, Q, H, D); + const int widxq = getTensorIndex(h, q, k, 0, H, Q, K, D); + T sum1 = dotProductSum(x, rpeInputQ + tidxq, rpeWeightsQ + widxq, D, fp16); + + // RPE-K sum. 
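+  // Same split-lane dot product as the RPE-Q term above, but the key tensor
+  // is read in [B, K, H, D] order and the RPE-K weights in [H, K, Q, D]
+  // order, so the output's q coordinate enters only through the weight index.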
+ const int tidxk = getTensorIndex(b, k, h, 0, B, K, H, D); + const int widxk = getTensorIndex(h, k, q, 0, H, K, Q, D); + T sum2 = dotProductSum(x, rpeInputK + tidxk, rpeWeightsK + widxk, D, fp16); + + // Write out the result. + if (lane == 0) { + int outIdx = getTensorIndex(b, h, q, k, B, H, Q, K); + output[outIdx] = (T)(sum1 + sum2 + skipAdd[outIdx]) * (T)outScale; + } +} + +template +void multiplyRpeQKLogits(const T* rpeInputQ, const T* rpeWeightsQ, + const T* rpeInputK, const T* rpeWeightsK, + const T* attnInput, T* output, int B, int H, int Q, + int K, int D, float outScale, cudaStream_t stream) { + // rpeType: Q | K + // rpeInput: [B, Q, H, D] | [B, K, H, D] + // rpeWeights: [H, Q, K, D] | [H, K, Q, D] + // attnInput: [B, H, Q, K] | [B, H, Q, K] + dim3 blockDim, gridDim; + blockDim.x = std::min(32, K); + blockDim.y = std::min(16, Q); + blockDim.z = std::min(std::max(512 / (blockDim.x * blockDim.y), 1u), + (unsigned int)(B * H)); + gridDim.x = DivUp(K * kNumRpeKernelSplits, blockDim.x); + gridDim.y = DivUp(Q, blockDim.y); + gridDim.z = DivUp(B * H, blockDim.z); + + rpeQK_multiply_kernel<<>>( + rpeInputQ, rpeWeightsQ, rpeInputK, rpeWeightsK, attnInput, output, B, H, + Q, K, D, outScale); + + ReportCUDAErrors(cudaGetLastError()); +} + +template +__global__ void permuteTensor_kernel(T* output, const T* input, int s1, int s2, + int s3, int s4, int p1, int p2, int p3, + int p4) { + // Shared memory for shape of tensor. + __shared__ int tshp[4]; + if (threadIdx.x == 0) { + tshp[0] = s1; + tshp[1] = s2; + tshp[2] = s3; + tshp[3] = s4; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid >= s1 * s2 * s3 * s4) return; + + int tloc[] = {0, 0, 0, 0}; + int index = tid; + + tloc[3] = (index % s4); + index /= s4; + tloc[2] = index % s3; + index /= s3; + tloc[1] = index % s2; + index /= s2; + tloc[0] = index; + + int outIdx = (((((tloc[p1] * tshp[p2]) + tloc[p2]) * tshp[p3]) + tloc[p3]) * + tshp[p4]) + + tloc[p4]; + output[outIdx] = (T)input[tid]; +} + +template +void permuteTensor(T* output, const T* input, int s1, int s2, int s3, int s4, + int p1, int p2, int p3, int p4, cudaStream_t stream) { + // The order and bounds arrays are assumed to have 4 elements. + int elements = s1 * s2 * s3 * s4; + const int kBlockSize = 1024; + int blocks = DivUp(elements, kBlockSize); + + permuteTensor_kernel<<>>( + output, input, s1, s2, s3, s4, p1, p2, p3, p4); + ReportCUDAErrors(cudaGetLastError()); +} + // Template instantiation. 
template void copyTypeConverted(half* op, float* ip, int N, cudaStream_t stream); @@ -1595,5 +2027,34 @@ template void applyInputGating(float* output, const float* input, const float* mult, const float* add, int N, int C, int output_size, cudaStream_t stream); + +template void multiplyRPEAttentionLogits( + const half* rpeInput, const half* rpeWeights, const half* attnInput, + half* output, int B, int H, int Q, int K, int D, float outScale, + size_t rpetype, cudaStream_t stream); + +template void multiplyRPEAttentionLogits( + const float* rpeInput, const float* rpeWeights, const float* attnInput, + float* output, int B, int H, int Q, int K, int D, float outScale, + size_t rpetype, cudaStream_t stream); + +template void multiplyRpeQKLogits( + const half* rpeInputQ, const half* rpeWeightsQ, const half* rpeInputK, + const half* rpeWeightsK, const half* attnInput, half* output, int B, int H, + int Q, int K, int D, float outScale, cudaStream_t stream); + +template void multiplyRpeQKLogits( + const float* rpeInputQ, const float* rpeWeightsQ, const float* rpeInputK, + const float* rpeWeightsK, const float* attnInput, float* output, int B, + int H, int Q, int K, int D, float outScale, cudaStream_t stream); + +template void permuteTensor(half* output, const half* input, int s1, + int s2, int s3, int s4, int p1, int p2, + int p3, int p4, cudaStream_t stream); + +template void permuteTensor(float* output, const float* input, int s1, + int s2, int s3, int s4, int p1, int p2, + int p3, int p4, cudaStream_t stream); + } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index a1a2145737..0a5fc71b7f 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -157,5 +157,23 @@ void inputPreprocessForAttentionBody(T* output, const T* input, template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream); + +template +void multiplyRPEAttentionLogits(const T* rpeInput, const T* rpeWeights, + const T* attnInput, T* output, int B, int H, + int Q, int K, int D, float outScale, + size_t rpetype, cudaStream_t stream); + +template +void multiplyRpeQKLogits(const T* rpeInputQ, const T* rpeWeightsQ, + const T* rpeInputK, const T* rpeWeightsK, + const T* attnInput, T* output, int B, int H, int Q, + int K, int D, float outScale, + cudaStream_t stream); + +template +void permuteTensor(T* output, const T* input, int s1, int s2, int s3, int s4, + int p1, int p2, int p3, int p4, cudaStream_t stream); + } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 8543443897..c72fbe0892 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -39,59 +39,87 @@ namespace lczero { -#if 0 +#if 1 +// function to calculate mean +static float mean(float arr[], int n) { + float sum = 0; + for (int i = 0; i < n; i++) { + sum += arr[i]; + } + return sum / n; +} + +// function to calculate standard deviation +static float stdDev(float arr[], int n) { + float m = mean(arr, n); // get the mean + float var = 0; // initialize variance + for (int i = 0; i < n; i++) { + var += pow(arr[i] - m, 2); // add the squared difference from mean + } + var /= n; // divide by number of elements + return sqrt(var); // return the square root of variance +} + // debug code to dump allocation in GPU memory template -void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { - const bool fp16 = 
std::is_same::value; - printf("\n%s\n", message); - int elementSize = (int) (fp16 ? sizeof(half) : sizeof(float)); - int bytes = elements * elementSize; - void *temp = malloc(bytes); - cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); - float maxval = -std::numeric_limits::max(); - float minval = std::numeric_limits::max(); - int nans = 0; - int nanss[10] {}; - - for (int i = 0; i < elements; i++) - { - float val; - if (fp16) - { - half *arr = (half*)temp; - val = (float)arr[i]; - } - else - { - float *arr = (float *)temp; - val = arr[i]; - } - maxval = std::max(maxval, val); - minval = std::min(minval, val); +void dumpTensor(T* memory, int elements, const char* message, int lines = -1, + int start = 0) { + const bool fp16 = std::is_same::value; + printf("\n%s\n", message); + int elementSize = (int)(fp16 ? sizeof(half) : sizeof(float)); + int bytes = elements * elementSize; + void* temp = malloc(bytes); + cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); + float maxval = -std::numeric_limits::max(); + float minval = std::numeric_limits::max(); + int cnans = 0; + int nans[10]{}; + + std::vector fpArr(elements); + for (int i = 0; i < elements; i++) { + float val; + if (fp16) { + half* arr = (half*)temp; + val = (float)arr[i]; + } else { + float* arr = (float*)temp; + val = arr[i]; + } + fpArr[i] = val; + maxval = std::max(maxval, val); + minval = std::min(minval, val); - if (std::isnan(val)) { - if (nans < 10) nanss[nans] = i; - nans++; - } + if (std::isnan(val)) { + if (cnans < 10) nans[cnans] = i; + cnans++; + } - if (!only_summary || i < 2 || i == elements - 1) { - // printf("%8.4f ", val); - // if ((i % 8) == 7) printf("\n"); - printf("%i;%.6f\n", i, val); - } + if ((i >= start && (i < start + lines || lines == -1)) || + i == elements - 1) { + // printf("%8.4f ", val); + // if ((i % 8) == 7) printf("\n"); + printf("%i;%.6f\n", i, val); } - free(temp); - if (maxval == -std::numeric_limits::max()) - maxval = std::numeric_limits::quiet_NaN(); - if (minval == std::numeric_limits::max()) - minval = std::numeric_limits::quiet_NaN(); + } + free(temp); + if (maxval == -std::numeric_limits::max()) + maxval = std::numeric_limits::quiet_NaN(); + if (minval == std::numeric_limits::max()) + minval = std::numeric_limits::quiet_NaN(); + + float avg = mean(&fpArr[0], elements); + float stddev = stdDev(&fpArr[0], elements); + printf( + "Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f\n" + "NaNs: %i of %i", + maxval, minval, avg, stddev, cnans, elements); - printf("Max: %.6f, Min: %.6f, NaNs: %i of %i", maxval, minval, nans, elements); + if (cnans > 0) { printf("\nNaN indices: "); - for (int i=0; i 10) printf("......"); - printf("\n"); + for (int i = 0; i < cnans && i < 10; i++) printf("%i ", nans[i]); + if (cnans > 10) printf("......"); + } + printf("\n"); } #endif @@ -1419,6 +1447,68 @@ void allocAndUpload(DataType** gpu_dest, std::vector cpu_src, (int)cpu_src.size(), 0); } +template +static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + float alpha, const DataType* A, int lda, + const DataType* B, int ldb, float beta, DataType* C, + int ldc) { + const bool fp16 = std::is_same::value; + if (fp16) { + unsigned short alpha_h = FP32toFP16(alpha); + unsigned short beta_h = FP32toFP16(beta); + ReportCUBLASErrors(cublasHgemm( + handle, transa, transb, m, n, k, (const half*)&alpha_h, (const half*)A, + lda, (const half*)B, ldb, (const half*)&beta_h, (half*)C, ldc)); + } else { + ReportCUBLASErrors(cublasSgemm(handle, transa, 
transb, m, n, k, &alpha, + (const float*)A, lda, (const float*)B, ldb, + &beta, (float*)C, ldc)); + } +} + +template +static void cublasXGemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float alpha, const void* A, int lda, + long long int strideA, const void* B, int ldb, long long int strideB, + float beta, void* C, int ldc, long long int strideC, int batchCount) { + const bool fp16 = std::is_same::value; + if (fp16) { + unsigned short alpha_h = FP32toFP16(alpha); + unsigned short beta_h = FP32toFP16(beta); + ReportCUBLASErrors(cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &alpha_h, A, CUDA_R_16F, lda, strideA, + B, CUDA_R_16F, ldb, strideB, &beta_h, C, CUDA_R_16F, ldc, strideC, + batchCount, CUDA_R_16F, CUBLAS_GEMM_DEFAULT)); + } else { + ReportCUBLASErrors(cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &alpha, A, CUDA_R_32F, lda, strideA, B, + CUDA_R_32F, ldb, strideB, &beta, C, CUDA_R_32F, ldc, strideC, + batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + } +} + +template +static void cublasXGemmBatched(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + float alpha, DataType** A, int lda, DataType** B, + int ldb, float beta, DataType** C, int ldc, + int batchCount) { + const bool fp16 = std::is_same::value; + if (fp16) { + unsigned short alpha_h = FP32toFP16(alpha); + unsigned short beta_h = FP32toFP16(beta); + ReportCUBLASErrors(cublasHgemmBatched( + handle, transa, transb, m, n, k, (const half*)&alpha_h, (half**)A, lda, + (half**)B, ldb, (const half*)&beta_h, (half**)C, ldc, batchCount)); + } else { + ReportCUBLASErrors(cublasSgemmBatched( + handle, transa, transb, m, n, k, &alpha, (float**)A, lda, (float**)B, + ldb, &beta, (float**)C, ldc, batchCount)); + } +} + template AttentionPolicyHead::AttentionPolicyHead( BaseLayer* ip, const MultiHeadWeights::PolicyHead& weights, @@ -1579,67 +1669,87 @@ EncoderBlock::EncoderBlock( // GPU memory already allocated in AttentionBody. smol_global = smolgen_global_scratch; } -} -template -static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - float alpha, const DataType* A, int lda, - const DataType* B, int ldb, float beta, DataType* C, - int ldc) { - const bool fp16 = std::is_same::value; - if (fp16) { - unsigned short alpha_h = FP32toFP16(alpha); - unsigned short beta_h = FP32toFP16(beta); - ReportCUBLASErrors(cublasHgemm( - handle, transa, transb, m, n, k, (const half*)&alpha_h, (const half*)A, - lda, (const half*)B, ldb, (const half*)&beta_h, (half*)C, ldc)); - } else { - ReportCUBLASErrors(cublasSgemm(handle, transa, transb, m, n, k, &alpha, - (const float*)A, lda, (const float*)B, ldb, - &beta, (float*)C, ldc)); - } -} + // RPE weights. + mha_rpe_q_size_ = cpu_weights.mha.rpe_q.size(); + mha_rpe_k_size_ = cpu_weights.mha.rpe_k.size(); + mha_rpe_v_size_ = cpu_weights.mha.rpe_v.size(); + + if (mha_rpe_q_size_ > 0 || mha_rpe_k_size_ > 0 || mha_rpe_v_size_ > 0) { + // Weights factorizer. + int rows = 15 * 15; + int cols = 64 * 64; + int row, col; + std::vector rpe_map(rows * cols); + // 15 * 15 in units for distance pairs to 64 * 64 pairs of squares. + // Distance pairs mapped on rows, while square pairs mapped on columns. 
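+    // For example, i=0, j=0, k=0, l=1 (square indices 0 and 1) gives
+    // row = 15 * 7 + 6 = 111 and col = 64 * 0 + 0 * 8 + 1 = 1, so
+    // rpe_map[111 * 4096 + 1] is set to 1.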
+ for (auto i = 0; i < 8; i++) { + for (auto j = 0; j < 8; j++) { + for (auto k = 0; k < 8; k++) { + for (auto l = 0; l < 8; l++) { + row = 15 * (i - k + 7) + (j - l + 7); + col = 64 * (i * 8 + j) + k * 8 + l; + rpe_map[row * cols + col] = 1.0f; + } + } + } + } + allocAndUpload(&mha_rpe_factorizer, rpe_map, scratch); + + // We need a cublas instance for the gemm. Create a temporary one. + cublasHandle_t tmp_cublas; + ReportCUBLASErrors(cublasCreate(&tmp_cublas)); + + // Allocate RPE weights and multiply by factorizer + DataType* rpe_scratch; + int heads = encoder_heads_; + int depth = mha_q_size_ / encoder_heads_; + if (mha_rpe_q_size_ > 0) { + allocAndUpload(&rpe_scratch, cpu_weights.mha.rpe_q, scratch); + + // Gemm to factorize the RPE Q weights. + cublasXgemm(tmp_cublas, CUBLAS_OP_N, CUBLAS_OP_T, 4096, + mha_q_size_, 225, 1.0f, mha_rpe_factorizer, 4096, + rpe_scratch, mha_q_size_, 0.0f, (DataType*)scratch, + 4096); + + // Permute RPE Q weights: [D, H, Q, K] -> [H, Q, K, D] + ReportCUDAErrors( + cudaMalloc(&mha_rpe_q, mha_q_size_ * 4096 * sizeof(DataType))); + permuteTensor((DataType*)mha_rpe_q, (const DataType*)scratch, depth, + heads, 64, 64, 1, 2, 3, 0, 0); + } + if (mha_rpe_k_size_ > 0) { + allocAndUpload(&rpe_scratch, cpu_weights.mha.rpe_k, scratch); -template -static void cublasXGemmStridedBatched( - cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, float alpha, const void* A, int lda, - long long int strideA, const void* B, int ldb, long long int strideB, - float beta, void* C, int ldc, long long int strideC, int batchCount) { - const bool fp16 = std::is_same::value; - if (fp16) { - unsigned short alpha_h = FP32toFP16(alpha); - unsigned short beta_h = FP32toFP16(beta); - ReportCUBLASErrors(cublasGemmStridedBatchedEx( - handle, transa, transb, m, n, k, &alpha_h, A, CUDA_R_16F, lda, strideA, - B, CUDA_R_16F, ldb, strideB, &beta_h, C, CUDA_R_16F, ldc, strideC, - batchCount, CUDA_R_16F, CUBLAS_GEMM_DEFAULT)); - } else { - ReportCUBLASErrors(cublasGemmStridedBatchedEx( - handle, transa, transb, m, n, k, &alpha, A, CUDA_R_32F, lda, strideA, B, - CUDA_R_32F, ldb, strideB, &beta, C, CUDA_R_32F, ldc, strideC, - batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - } -} + // Gemm to factorize the RPE K weights. 
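+      // As with RPE-Q above: in cuBLAS' column-major view this computes
+      // scratch[4096 x mha_k_size_] = factorizer[4096 x 225] *
+      // rpe_k[225 x mha_k_size_], expanding the 225 distance-pair rows to
+      // all 64 * 64 square pairs before the permute below.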
+ cublasXgemm(tmp_cublas, CUBLAS_OP_N, CUBLAS_OP_T, 4096, + mha_k_size_, 225, 1.0f, mha_rpe_factorizer, 4096, + rpe_scratch, mha_k_size_, 0.0f, (DataType*)scratch, + 4096); -template -static void cublasXGemmBatched(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - float alpha, DataType** A, int lda, DataType** B, - int ldb, float beta, DataType** C, int ldc, - int batchCount) { - const bool fp16 = std::is_same::value; - if (fp16) { - unsigned short alpha_h = FP32toFP16(alpha); - unsigned short beta_h = FP32toFP16(beta); - ReportCUBLASErrors(cublasHgemmBatched( - handle, transa, transb, m, n, k, (const half*)&alpha_h, (half**)A, lda, - (half**)B, ldb, (const half*)&beta_h, (half**)C, ldc, batchCount)); - } else { - ReportCUBLASErrors(cublasSgemmBatched( - handle, transa, transb, m, n, k, &alpha, (float**)A, lda, (float**)B, - ldb, &beta, (float**)C, ldc, batchCount)); + // Permute RPE K weights: [D, H, Q, K] -> [H, K, Q, D] + ReportCUDAErrors( + cudaMalloc(&mha_rpe_k, mha_k_size_ * 4096 * sizeof(DataType))); + permuteTensor((DataType*)mha_rpe_k, (const DataType*)scratch, depth, + heads, 64, 64, 1, 3, 2, 0, 0); + } + if (mha_rpe_v_size_ > 0) { + allocAndUpload(&rpe_scratch, cpu_weights.mha.rpe_v, scratch); + + // Gemm to factorize the RPE V weights. + cublasXgemm(tmp_cublas, CUBLAS_OP_N, CUBLAS_OP_T, 4096, + mha_v_size_, 225, 1.0f, mha_rpe_factorizer, 4096, + rpe_scratch, mha_v_size_, 0.0f, (DataType*)scratch, + 4096); + + // Permute RPE V weights: [D, H, Q, K] -> [H, Q, D, K] + ReportCUDAErrors( + cudaMalloc(&mha_rpe_v, mha_v_size_ * 4096 * sizeof(DataType))); + permuteTensor((DataType*)mha_rpe_v, (const DataType*)scratch, depth, + heads, 64, 64, 1, 2, 0, 3, 0); + } + cublasDestroy(tmp_cublas); } } @@ -1790,7 +1900,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, depth /*K*/, // A/B, and M/N are swapped for row-major to col-major // transform - factor, // to handle "/ tf.math.sqrt(dk)" + mha_rpe_q_size_ > 0 || mha_rpe_k_size_ > 0 + ? 1.0f + : factor, // in RPE nets, scaling is done after RPE logits *offset_pointers, // mha_k + offset /*A*/, d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over // other "depth" slices / heads @@ -1808,6 +1920,41 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, N * encoder_heads_); } + { + // RPE Q and K. + if (mha_rpe_q_size_ > 0 && mha_rpe_k_size_ > 0) { + // Matrix-vector multiplication for query x rpe_q + key x rpe_k + attn + // Note: mha_q and mha_k here are not yet transposed, so shape is still + // BQHD/BKHD. mha_q @ rpe_q: [B, Q, H, D] x [D, H, Q, K] mha_k @ rpe_k: + // [B, K, H, D] x [D, H, Q, K] Kernel performs the required + // transpositions. + multiplyRpeQKLogits(mha_q, mha_rpe_q, mha_k, mha_rpe_k, buffer1, + buffer1, N, encoder_heads_, 64, 64, depth, + factor, stream); + } else { + // RPE Q. + if (mha_rpe_q_size_ > 0) { + // Matrix-vector multiplication for query x rpe_q + // Note: mha_q here is not yet transposed, so shape is still BQHD. + // mha_q @ rpe_q: [B, Q, H, D] x [D, H, Q, K] + // Kernel performs the required transpositions. + float outScale = mha_rpe_k_size_ == 0 ? factor : 1.0f; + multiplyRPEAttentionLogits(mha_q, mha_rpe_q, buffer1, buffer1, + N, encoder_heads_, 64, 64, depth, + outScale, 0, stream); + } + // RPE K. + if (mha_rpe_k_size_ > 0) { + // Matrix-vector multiplication for key x rpe_k + // Note: mha_k here is not yet transposed, so shape is still BKHD. 
+ // mha_k @ rpe_k: [B, K, H, D] x [D, H, Q, K] + // Kernel performs the required transpositions. + multiplyRPEAttentionLogits(mha_k, mha_rpe_k, buffer1, buffer1, + N, encoder_heads_, 64, 64, depth, + factor, 1, stream); + } + } + } // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> buffer1 if (has_smolgen_) { @@ -1838,6 +1985,17 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, N * encoder_heads_); } + if (mha_rpe_v_size_ > 0) { + // RPE-V + // Matrix-vector multiplication for attn x rpe_v + // attn @ rpe_v: [B, H, Q, K] x [D, H, Q, K] + // Kernel performs the required transpositions. + // The matmul result (buffer2) is already with BQHD shape. + multiplyRPEAttentionLogits(buffer1, mha_rpe_v, buffer2, buffer2, + N, encoder_heads_, 64, 64, depth, 1.0f, + 2, stream); + } + // #final dense layer (mha_dense), buffer2 -> buffer1 { const int num_inputs = d_model; diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index de563f9346..8eb2172181 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -353,6 +353,8 @@ class EncoderBlock { DataType *mha_v_w, *mha_v_b; DataType *mha_qkv_w, *mha_qkv_b; DataType *mha_dense_w, *mha_dense_b; + DataType *mha_rpe_q, *mha_rpe_k, *mha_rpe_v; + DataType *mha_rpe_factorizer; DataType *ln1_gammas, *ln1_betas; @@ -372,6 +374,9 @@ class EncoderBlock { int mha_k_size_; int mha_v_size_; int mha_dense_size_; + int mha_rpe_q_size_; + int mha_rpe_k_size_; + int mha_rpe_v_size_; int ffn_dense1_size_; int ffn_dense2_size_; @@ -379,7 +384,7 @@ class EncoderBlock { int embedding_op_size_; int encoder_heads_; - float alpha_; // scale to apply to skip connection add + float alpha_; // scale to apply to skip connection add float default_eps_; // value of epsilon where it wasn't specified in training const bool has_smolgen_; @@ -485,13 +490,18 @@ class AttentionBody : public BaseLayer { private: // GPU allocations to hold various weights used by the attention net body. - DataType *ip_emb_pre_w_, *ip_emb_pre_b_; // input position preprocessing weights. - DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body - DataType *ip_emb_ln_g_, *ip_emb_ln_b_; // input embedding layernorm gamma and beta - DataType *ip_mult_gate_, *ip_add_gate_; // input gating - DataType *ip_emb_ffn_d1_w_, *ip_emb_ffn_d1_b_; // input embedding FFN dense1 weights - DataType *ip_emb_ffn_d2_w_, *ip_emb_ffn_d2_b_; // input embedding FFN dense2 weights - DataType *ip_emb_ffn_ln_g_, *ip_emb_ffn_ln_b_; // input embedding FFN layernorm gamma and beta + DataType *ip_emb_pre_w_, + *ip_emb_pre_b_; // input position preprocessing weights. 
+ DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body + DataType *ip_emb_ln_g_, + *ip_emb_ln_b_; // input embedding layernorm gamma and beta + DataType *ip_mult_gate_, *ip_add_gate_; // input gating + DataType *ip_emb_ffn_d1_w_, + *ip_emb_ffn_d1_b_; // input embedding FFN dense1 weights + DataType *ip_emb_ffn_d2_w_, + *ip_emb_ffn_d2_b_; // input embedding FFN dense2 weights + DataType *ip_emb_ffn_ln_g_, + *ip_emb_ffn_ln_b_; // input embedding FFN layernorm gamma and beta DataType *smolgen_global_; // global smolgen weights for all encoder layers DataType *pos_encoding_; int embedding_dense_size_; @@ -523,8 +533,8 @@ class ValueHead : public BaseLayer { public: ValueHead(BaseLayer* ip, const MultiHeadWeights::ValueHead& weights, - void* scratch, bool attention_body, bool wdl, ActivationFunction act, - int max_batch_size, bool use_gemm_ex); + void* scratch, bool attention_body, bool wdl, + ActivationFunction act, int max_batch_size, bool use_gemm_ex); ~ValueHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -548,6 +558,5 @@ class ValueHead : public BaseLayer { ActivationFunction act_; }; - } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index cf67d1336c..420b92d1fb 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -121,8 +121,8 @@ static size_t getMaxAttentionBodySize(const MultiHeadWeights& weights, int N) { template class CudaNetworkComputation : public NetworkComputation { public: - CudaNetworkComputation(CudaNetwork* network, - bool wdl, bool moves_left); + CudaNetworkComputation(CudaNetwork* network, bool wdl, + bool moves_left); ~CudaNetworkComputation(); void AddInput(InputPlanes&& input) override { @@ -530,8 +530,8 @@ class CudaNetwork : public Network { pblczero::NetworkFormat::VALUE_WDL; BaseLayer* lastlayer = attn_body_ ? encoder_last_ : resi_last_; auto value_main = std::make_unique>( - lastlayer, head, scratch_mem_, attn_body_, wdl_, act, - max_batch_size_, use_gemm_ex); + lastlayer, head, scratch_mem_, attn_body_, wdl_, act, max_batch_size_, + use_gemm_ex); network_.emplace_back(std::move(value_main)); } diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index 53846353c6..44f79c4517 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -142,7 +142,10 @@ BaseWeights::MHA::MHA(const pblczero::Weights::MHA& mha) dense_w(LayerAdapter(mha.dense_w()).as_vector()), dense_b(LayerAdapter(mha.dense_b()).as_vector()), smolgen(Smolgen(mha.smolgen())), - has_smolgen(mha.has_smolgen()) {} + has_smolgen(mha.has_smolgen()), + rpe_q(LayerAdapter(mha.rpe_q()).as_vector()), + rpe_k(LayerAdapter(mha.rpe_k()).as_vector()), + rpe_v(LayerAdapter(mha.rpe_v()).as_vector()) {} BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) : dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()), diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 72ce67544f..7ad5db82ba 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -81,6 +81,9 @@ struct BaseWeights { Vec dense_b; Smolgen smolgen; bool has_smolgen; + Vec rpe_q; + Vec rpe_k; + Vec rpe_v; }; struct FFN {