-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathgemm-simple.cu
164 lines (129 loc) · 4.86 KB
/
gemm-simple.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <cmath>
#include <cstdio>
#include <stdlib.h>

#include <cuda.h>
#include <cublas_v2.h>
#include <cute/tensor.hpp>
// Host-side helper (defined at the end of this file): fills data[0 .. n) with
// rand()-driven values in [-1.00, 0.99] in steps of 0.01.
template <class T>
void gen_rand_data(T *data, int n);
// gemm_simple: computes C = A * B^T, one CTA per kTileM x kTileN output tile.
//
// Layouts (fixed by the make_stride calls below): A is m x k, B is n x k and
// C is m x n, all row-major (last dimension contiguous). Each output element
// is C(i, j) = sum_l A(i, l) * B(j, l), accumulated via cute::gemm.
//
// Launch contract:
//   - blockIdx.x selects the kTileN-wide column tile of C, blockIdx.y the
//     kTileM-tall row tile; each CTA writes exactly one output tile.
//   - threadIdx.x selects this thread's slice of TiledMMA, so blockDim.x is
//     expected to equal size(TiledMMA) (this is how main launches it).
//   - No boundary predication is done: m, n and k must be multiples of
//     kTileM, kTileN and kTileK. Partial tiles are not masked, so other sizes
//     can lead to out-of-bounds global accesses.
//     NOTE(review): the tile sizes presumably also need to be multiples of
//     the TiledMMA's M/N/K extent — confirm against the CuTe docs.
//
// Data movement: every k-iteration copies this thread's A/B fragments straight
// from global memory into register fragments; no shared memory is used.
// The accumulator element type comes from TiledMMA's C operand.
template <typename T, int kTileM, int kTileN, int kTileK, typename TiledMMA>
__global__ void gemm_simple(T *Cptr, const T *Aptr, const T *Bptr, int m, int n, int k) {
using namespace cute;
// Global-memory views of the full row-major operands.
Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(m, k), make_stride(k, Int<1>{}));
Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(n, k), make_stride(k, Int<1>{}));
Tensor C = make_tensor(make_gmem_ptr(Cptr), make_shape(m, n), make_stride(n, Int<1>{}));
// ix walks tiles along N, iy walks tiles along M.
int ix = blockIdx.x;
int iy = blockIdx.y;
// Select this CTA's tiles; `_` keeps every k-tile for A and B so the loop
// below can iterate over them.
Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}), make_coord(iy, _));
Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}), make_coord(ix, _));
Tensor gC = local_tile(C, make_tile(Int<kTileM>{}, Int<kTileN>{}), make_coord(iy, ix));
// gA(kTileM, kTileK, num_tile_k)
// gB(kTileN, kTileK, num_tile_k)
// gC(kTileM, kTileN)
// Per-thread view of the tiled MMA.
TiledMMA tiled_mma;
auto thr_mma = tiled_mma.get_slice(threadIdx.x);
// Global-memory partitions owned by this thread.
auto tAgA = thr_mma.partition_A(gA); // (MMA, MMA_M, MMA_K, num_tile_k)
auto tBgB = thr_mma.partition_B(gB); // (MMA, MMA_N, MMA_K, num_tile_k)
auto tCgC = thr_mma.partition_C(gC); // (MMA, MMA_M, MMA_N)
// Register fragments shaped like one k-tile's partition (A, B) and the
// output tile (C).
auto tArA = thr_mma.partition_fragment_A(gA(_, _, 0)); // (MMA, MMA_M, MMA_K)
auto tBrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // (MMA, MMA_N, MMA_K)
auto tCrC = thr_mma.partition_fragment_C(gC(_, _)); // (MMA, MMA_M, MMA_N)
// Zero the accumulator before the k reduction.
clear(tCrC);
int num_tile_k = size<2>(gA);
// Keep the k-tile loop rolled (deliberate; presumably to limit code size).
#pragma unroll 1
for(int itile = 0; itile < num_tile_k; ++itile) {
// global -> register, then accumulate tCrC += tArA * tBrB.
cute::copy(tAgA(_, _, _, itile), tArA);
cute::copy(tBgB(_, _, _, itile), tBrB);
cute::gemm(tiled_mma, tCrC, tArA, tBrB, tCrC);
}
// register -> global: write this thread's part of the output tile.
cute::copy(tCrC, tCgC);
}
// Benchmark + correctness harness for gemm_simple.
//
// Builds row-major fp16 inputs A (m x k) and B (n x k) with a fixed seed.
// It runs gemm_simple 100 times and cublasHgemm 100 times; both compute
// C = A * B^T into an m x n row-major buffer. It then compares the two results
// element-wise on the host and prints the top-left 8x8 tile of each.
//
// Returns EXIT_FAILURE if the problem size is not a multiple of the kernel's
// tile sizes, or if any CUDA / cuBLAS / host-allocation call fails. Otherwise
// it returns 0. Numerical mismatches are reported but do not change the exit
// code.
int main() {
  srand(10086);
  using T = cute::half_t;
  using namespace cute;

  // Fail fast on CUDA runtime errors: errors are sticky, so continuing would
  // only make every later call fail with a misleading message.
  auto check_cuda = [](cudaError_t status, const char *what) {
    if (status != cudaSuccess) {
      fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(status));
      exit(EXIT_FAILURE);
    }
  };
  auto check_alloc = [](const void *ptr, const char *what) {
    if (ptr == nullptr) {
      fprintf(stderr, "host allocation failed: %s\n", what);
      exit(EXIT_FAILURE);
    }
  };
  // Wait for all queued work and report both launch-configuration errors
  // (cudaGetLastError) and asynchronous execution faults (synchronize).
  auto sync_and_report = []() {
    cudaError_t status = cudaDeviceSynchronize();
    cudaError_t last = cudaGetLastError();
    if (status == cudaSuccess) {
      status = last;
    }
    printf("err = %d, str = %s\n", static_cast<int>(status), cudaGetErrorString(status));
    return status;
  };

  int m = 81920;
  int n = 256;
  int k = 256;

  constexpr int kTileM = 128;
  constexpr int kTileN = 128;
  constexpr int kTileK = 32;

  // gemm_simple has no boundary predication: with a ragged M/N the truncating
  // grid below would silently skip output rows/columns, and a ragged K would
  // be read out of bounds. Reject such sizes up front.
  if (m % kTileM != 0 || n % kTileN != 0 || k % kTileK != 0) {
    fprintf(stderr,
            "problem size (m=%d, n=%d, k=%d) must be a multiple of the tile "
            "size (%d, %d, %d)\n",
            m, n, k, kTileM, kTileN, kTileK);
    return EXIT_FAILURE;
  }

  T *Cptr = nullptr;
  T *Aptr = nullptr;
  T *Bptr = nullptr;
  check_cuda(cudaMalloc(&Cptr, sizeof(T) * m * n), "cudaMalloc(C)");
  check_cuda(cudaMalloc(&Aptr, sizeof(T) * m * k), "cudaMalloc(A)");
  check_cuda(cudaMalloc(&Bptr, sizeof(T) * k * n), "cudaMalloc(B)");

  T *Aptr_host = (T *)malloc(sizeof(T) * m * k);
  T *Bptr_host = (T *)malloc(sizeof(T) * n * k);
  check_alloc(Aptr_host, "A host buffer");
  check_alloc(Bptr_host, "B host buffer");
  gen_rand_data(Aptr_host, m * k);
  gen_rand_data(Bptr_host, n * k);
  check_cuda(cudaMemcpy(Aptr, Aptr_host, sizeof(T) * m * k, cudaMemcpyHostToDevice),
             "cudaMemcpy(A)");
  check_cuda(cudaMemcpy(Bptr, Bptr_host, sizeof(T) * n * k, cudaMemcpyHostToDevice),
             "cudaMemcpy(B)");

  using mma_op = SM80_16x8x16_F16F16F16F16_TN;
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom = MMA_Atom<mma_traits>;
  using MMA = decltype(make_tiled_mma(mma_atom{},
                                     make_layout(Shape<_2, _2, _1>{}),
                                     make_layout(Shape<_1, _2, _1>{})));

  // One CTA per kTileM x kTileN output tile: grid.x walks N, grid.y walks M.
  // The divisibility check above makes these divisions exact.
  dim3 block(size(MMA{}));
  dim3 grid(n / kTileN, m / kTileM);
  for (int i = 0; i < 100; ++i) {
    gemm_simple<T, kTileM, kTileN, kTileK, MMA><<<grid, block>>>(Cptr, Aptr, Bptr, m, n, k);
  }
  check_cuda(sync_and_report(), "gemm_simple");

  // cublas
  T *Cptr_cublas = nullptr;
  check_cuda(cudaMalloc(&Cptr_cublas, sizeof(T) * m * n), "cudaMalloc(C_cublas)");
  cublasHandle_t handle;
  cublasStatus_t create_status = cublasCreate(&handle);
  if (create_status != CUBLAS_STATUS_SUCCESS) {
    printf("blas err = %d, str = %s\n", create_status, cublasGetStatusString(create_status));
    return EXIT_FAILURE;
  }
  half alpha = half(1.f);
  half beta = half(0.f);
  // cuBLAS is column-major. The row-major m x n product C = A * B^T is, in
  // column-major terms, the n x m product C^T = B * A^T. Viewed as
  // column-major, the row-major B buffer is B^T (k x n, ld k), hence
  // CUBLAS_OP_T. The row-major A buffer is A^T (k x m, ld k), used as-is.
  for (int i = 0; i < 100; ++i) {
    cublasStatus_t ret = cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                                     n, m, k,
                                     &alpha,
                                     (half *)Bptr, k,
                                     (half *)Aptr, k,
                                     &beta,
                                     (half *)Cptr_cublas, n);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      // Comparing against an unwritten reference buffer would be meaningless.
      printf("blas err = %d, str = %s\n", ret, cublasGetStatusString(ret));
      return EXIT_FAILURE;
    }
  }
  check_cuda(sync_and_report(), "cublasHgemm");

  T *Cptr_host = (T *)malloc(sizeof(T) * m * n);
  T *Cptr_cublas_host = (T *)malloc(sizeof(T) * m * n);
  check_alloc(Cptr_host, "C host buffer");
  check_alloc(Cptr_cublas_host, "C_cublas host buffer");
  // compare
  check_cuda(cudaMemcpy(Cptr_host, Cptr, sizeof(T) * m * n, cudaMemcpyDeviceToHost),
             "cudaMemcpy(C)");
  check_cuda(cudaMemcpy(Cptr_cublas_host, Cptr_cublas, sizeof(T) * m * n, cudaMemcpyDeviceToHost),
             "cudaMemcpy(C_cublas)");
  // Both paths accumulate in fp16 with different summation orders, so use an
  // absolute tolerance instead of exact equality.
  const float threshold = 0.1f;
  // Cap per-element output; a broken kernel would otherwise print m * n lines.
  const long long kMaxReported = 16;
  long long num_mismatch = 0;
  for (int i = 0; i < m * n; ++i) {
    float v1 = Cptr_host[i];
    float v2 = Cptr_cublas_host[i];
    // Written as !(diff <= threshold) so that NaN results count as mismatches.
    if (!(std::fabs(v2 - v1) <= threshold)) {
      if (num_mismatch < kMaxReported) {
        printf("v1 = %f, v2 = %f\n", v1, v2);
      }
      ++num_mismatch;
    }
  }
  printf("mismatches = %lld / %d\n", num_mismatch, m * n);

  Tensor tensor_C = make_tensor(Cptr_host, make_shape(m, n), make_stride(n, 1));
  Tensor tensor_C_cublas = make_tensor(Cptr_cublas_host, make_shape(m, n), make_stride(n, 1));
  auto tile = make_tile(8, 8);
  auto coor = make_coord(0, 0);
  Tensor tc1 = local_tile(tensor_C, tile, coor);
  Tensor tc1_cublas = local_tile(tensor_C_cublas, tile, coor);
  print_tensor(tc1);
  print_tensor(tc1_cublas);

  // Release everything that was allocated above.
  free(Aptr_host);
  free(Bptr_host);
  free(Cptr_host);
  free(Cptr_cublas_host);
  cublasStatus_t destroy_status = cublasDestroy(handle);
  if (destroy_status != CUBLAS_STATUS_SUCCESS) {
    printf("blas err = %d, str = %s\n", destroy_status, cublasGetStatusString(destroy_status));
  }
  check_cuda(cudaFree(Cptr_cublas), "cudaFree(C_cublas)");
  check_cuda(cudaFree(Cptr), "cudaFree(C)");
  check_cuda(cudaFree(Aptr), "cudaFree(A)");
  check_cuda(cudaFree(Bptr), "cudaFree(B)");
  return 0;
}
// Writes n pseudo-random values, in order, into data[0], data[1], ...
// Each value is an integer drawn from rand() and mapped into [-100, 99], then
// scaled by 0.01 (in double), rounded to float, and finally converted to T.
// The resulting range is [-1.00, 0.99]. Writes nothing when n <= 0.
// The sequence is reproducible for a given srand() seed.
template <class T>
void gen_rand_data(T *data, int n) {
  T *cursor = data;
  for (int remaining = n; remaining > 0; --remaining) {
    const int bucket = rand() % 200 - 100;
    const float value = bucket * 0.01;
    *cursor++ = value;
  }
}