Merge pull request baidu-research#2 from sharan/volta_updates
DeepBench updates for Volta.
Sharan Narang authored and GitHub Enterprise committed Nov 29, 2017
2 parents 5abbfc6 + ec79df0 commit 0cb389f
Showing 7 changed files with 120 additions and 25 deletions.
code/nvidia/Makefile: 7 changes (4 additions, 3 deletions)
@@ -17,6 +17,7 @@ CONV_LIBRARY?=cudnn
CONV_PATH?=$
KERNELS_DIR=../kernels/
PAD_KERNELS?=1
+USE_TENSOR_CORES?=0
COMMA=,
NVCC_ARCH_ARGS=$(foreach a,$(subst $(COMMA), ,$(ARCH)),--generate-code arch=$(patsubst sm_%,compute_%,$(a)),code=$(a))

@@ -26,15 +27,15 @@ all: gemm conv rnn all_reduce

gemm:
$(MKDIR) $(BIN_DIR)
-$(CUDA_PATH)/bin/$(NVCC) gemm_bench.cu -DPAD_KERNELS=$(PAD_KERNELS) -o $(BIN_DIR)/gemm_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -L $(BLAS_PATH) -l$(BLAS_LIBRARY) -L $(CUDA_LIB64) -lcurand $(NVCC_ARCH_ARGS) -std=c++11
+$(CUDA_PATH)/bin/$(NVCC) gemm_bench.cu -DUSE_TENSOR_CORES=$(USE_TENSOR_CORES) -DPAD_KERNELS=$(PAD_KERNELS) -o $(BIN_DIR)/gemm_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -L $(BLAS_PATH) -l$(BLAS_LIBRARY) -L $(CUDA_LIB64) -lcurand $(NVCC_ARCH_ARGS) -std=c++11

conv:
$(MKDIR) $(BIN_DIR)
-$(CUDA_PATH)/bin/$(NVCC) conv_bench.cu -DPAD_KERNELS=$(PAD_KERNELS) -o $(BIN_DIR)/conv_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -I $(CUDNN_PATH)/include/ -L $(CUDNN_PATH)/lib64/ -L $(CUDA_LIB64) -lcurand -lcudnn $(NVCC_ARCH_ARGS) -std=c++11
+$(CUDA_PATH)/bin/$(NVCC) conv_bench.cu -DUSE_TENSOR_CORES=$(USE_TENSOR_CORES) -DPAD_KERNELS=$(PAD_KERNELS) -o $(BIN_DIR)/conv_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -I $(CUDNN_PATH)/include/ -L $(CUDNN_PATH)/lib64/ -L $(CUDA_LIB64) -lcurand -lcudnn $(NVCC_ARCH_ARGS) -std=c++11

rnn:
$(MKDIR) $(BIN_DIR)
-$(CUDA_PATH)/bin/$(NVCC) rnn_bench.cu -o $(BIN_DIR)/rnn_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -I $(CUDNN_PATH)/include/ -L $(CUDNN_PATH)/lib64/ -L $(CUDA_LIB64) -lcurand -lcudnn $(NVCC_ARCH_ARGS) -std=c++11
+$(CUDA_PATH)/bin/$(NVCC) rnn_bench.cu -DUSE_TENSOR_CORES=$(USE_TENSOR_CORES) -o $(BIN_DIR)/rnn_bench -I $(KERNELS_DIR) -I $(CUDA_PATH)/include -I $(CUDNN_PATH)/include/ -L $(CUDNN_PATH)/lib64/ -L $(CUDA_LIB64) -lcurand -lcudnn $(NVCC_ARCH_ARGS) -std=c++11

all_reduce: nccl_single nccl_mpi

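With this change, tensor-core builds are opt-in per benchmark. For example, on a Volta part one might run `make gemm ARCH=sm_70 USE_TENSOR_CORES=1` (assuming ARCH is the Makefile's architecture list; the flag simply feeds the -DUSE_TENSOR_CORES define shown above).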
code/nvidia/conv_bench.cu: 69 changes (57 additions, 12 deletions)
@@ -21,6 +21,14 @@
#define PAD_KERNELS 1
#endif

+#ifndef USE_TENSOR_CORES
+#if CUDNN_MAJOR >= 7
+#define USE_TENSOR_CORES 1
+#else
+#define USE_TENSOR_CORES 0
+#endif
+#endif


/*
Usage:
@@ -104,6 +112,9 @@ public:
x_desc_ = TensorDescriptor4d<T1>(format, n, c, h, w);
w_desc_ = FilterDescriptor4d<T1>(format, k, c, r, s);

+#if (CUDNN_MAJOR >= 7) && (USE_TENSOR_CORES)
+cudnnSetConvolutionMathType(conv_desc_.desc(), CUDNN_TENSOR_OP_MATH);
+#endif
// Get output dimensions
CHECK_CUDNN_ERROR(cudnnGetConvolution2dForwardOutputDim(conv_desc_.desc(),
x_desc_.desc(),
@@ -113,11 +124,7 @@ public:
&out_h,
&out_w));

-if (std::is_same<T1, uint8_t>::value) {
-h_desc_ = TensorDescriptor4d<T1>(CUDNN_TENSOR_NHWC, out_n, out_c, out_h, out_w);
-} else {
-h_desc_ = TensorDescriptor4d<T1>(CUDNN_TENSOR_NCHW, out_n, out_c, out_h, out_w);
-}
+h_desc_ = TensorDescriptor4d<T1>(format, out_n, out_c, out_h, out_w);

output_dims_ = {out_w, out_h, out_c, out_n};

@@ -155,6 +162,10 @@ public:
&fwd_perf));
fwd_algo_ = fwd_perf.algo;
}
#endif
+#if (CUDNN_MAJOR >= 7) && (USE_TENSOR_CORES)
+// Tensor Op math only supports IMPLICIT_PRECOMP_GEMM algorithm
+fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+#endif
if (std::is_same<T1, uint8_t>::value) {
//Note: cudnn workspace size function doesn't work for INT8_CONFIG
@@ -201,6 +212,11 @@ public:
&filter_perf));
bwd_params_algo_ = filter_perf.algo;
#endif
+#if (CUDNN_MAJOR >= 7) && (USE_TENSOR_CORES)
+// Tensor Op math only supports this algorithm.
+bwd_params_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+#endif

// Backward params workspace
CHECK_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_.handle(),
x_desc_.desc(),
@@ -236,6 +252,11 @@ public:
&data_perf));
bwd_inputs_algo_ = data_perf.algo;
#endif
+#if (CUDNN_MAJOR >= 7) && (USE_TENSOR_CORES)
+// Tensor Op math only supports this algorithm.
+bwd_inputs_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+#endif

CHECK_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_.handle(),
w_desc_.desc(),
h_desc_.desc(),
@@ -475,7 +496,7 @@ int main(int argc, char **argv) {
std::cout << "total_time (usec)";
}

-if (PAD_KERNELS && precision == "int8" && inference)
+if (PAD_KERNELS && ((precision == "int8" && inference) || (USE_TENSOR_CORES && !inference)))
std::cout << " pad_kernels ";

std::cout << " fwd_algo " << std::endl;
@@ -506,24 +527,41 @@ int main(int argc, char **argv) {

#if CUDNN_MAJOR >= 6
int padded_c, padded_w, padded_h;
+int pad_value;

padded_c = c;
padded_h = h;
padded_w = w;

if (precision == "int8") {
-if (c%4 || w%4 || h%4) {
+pad_value = 4;
+if (c % pad_value || w % pad_value || h % pad_value) {
pad_kernels_count++;
if (PAD_KERNELS) {
-pad_dim(padded_c);
-pad_dim(padded_h);
-pad_dim(padded_w);
+pad_dim(padded_c, pad_value);
+pad_dim(padded_h, pad_value);
+pad_dim(padded_w, pad_value);
need_padding = true;
} else {
skip_kernel = true;
}
}
}
+#if (USE_TENSOR_CORES)
+// Tensor cores need channels to be a multiple of 8, so some kernels are padded.
+if (!inference) {
+pad_value = 8;
+if (c % pad_value) {
+pad_kernels_count++;
+if (PAD_KERNELS) {
+pad_dim(padded_c, pad_value);
+need_padding = true;
+} else {
+skip_kernel = true;
+}
+}
+}
+#endif
#endif

int fwd_time, bwd_inputs_time, bwd_params_time;
@@ -535,10 +573,10 @@ int main(int argc, char **argv) {
#if CUDNN_MAJOR >= 6
if (precision == "float") {
std::tie(fwd_time, bwd_inputs_time, bwd_params_time, fwd_algo_s) =
-time_cnn<float, float>(k, c, r, s, n, h, w, pad_h, pad_w, hstride, wstride, num_repeats, curand_gen, inference);
+time_cnn<float, float>(k, padded_c, r, s, n, padded_h, padded_w, pad_h, pad_w, hstride, wstride, num_repeats, curand_gen, inference);
} else if (precision == "half") {
std::tie(fwd_time, bwd_inputs_time, bwd_params_time, fwd_algo_s) =
-time_cnn<uint16_t, uint16_t>(k, c, r, s, n, h, w, pad_h, pad_w, hstride, wstride, num_repeats, curand_gen, inference);
+time_cnn<uint16_t, uint16_t>(k, padded_c, r, s, n, padded_h, padded_w, pad_h, pad_w, hstride, wstride, num_repeats, curand_gen, inference);
} else if ((precision == "int8") && inference) {
if (!skip_kernel) {
std::tie(fwd_time, bwd_inputs_time, bwd_params_time, fwd_algo_s) =
@@ -578,12 +616,19 @@ int main(int argc, char **argv) {
std::cout << std::setw(15) << need_padding;
}



if (!inference) {
std::cout << std::setw(24) << std::setprecision(7) << bwd_inputs_time;
std::cout << std::setw(24) << std::setprecision(7) << bwd_params_time;
std::cout << std::setw(19) << std::setprecision(8) << fwd_time + bwd_inputs_time + bwd_params_time;
}

+if (USE_TENSOR_CORES && PAD_KERNELS && !inference) {
+std::cout << std::setw(15) << need_padding;
+}


std::cout << std::setw(25) << fwd_algo_s;
std::cout << std::endl;
}
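The conv_bench.cu changes follow one pattern: when tensor cores are requested on cuDNN 7+, set the math type on the convolution descriptor and pin each pass to the one algorithm that supports tensor-op math. A minimal sketch of that pattern, not part of the commit (the helper name is hypothetical):

    // Opt a cuDNN convolution into tensor-core math (cuDNN 7+).
    // Tensor/filter descriptor setup is assumed to happen elsewhere.
    #include <cudnn.h>

    void use_tensor_op_conv(cudnnConvolutionDescriptor_t conv_desc,
                            cudnnConvolutionFwdAlgo_t* fwd_algo) {
    #if CUDNN_MAJOR >= 7
        // Allow tensor-core kernels for this convolution.
        cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH);
        // Per the commit's comments, tensor-op math pairs with
        // IMPLICIT_PRECOMP_GEMM on the forward pass.
        *fwd_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    #endif
    }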
code/nvidia/cudnn_helper.h: 13 changes (11 additions, 2 deletions)
@@ -323,7 +323,7 @@ class RNNDescriptor {
RNNDescriptor() {}
RNNDescriptor(int hidden_size, int num_layers, cudnnDropoutDescriptor_t dropout_desc,
cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction,
-std::string rnn_type) {
+std::string rnn_type, cudnnHandle_t cudnn_handle) {
cudnnDataType_t type;
if (std::is_same<T, float>::value)
type = CUDNN_DATA_FLOAT;
@@ -346,12 +346,18 @@ class RNNDescriptor {
else
throw std::runtime_error("Unknown rnn type");

+#if CUDNN_MAJOR >= 7
+cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+#endif

cudnnRNNDescriptor_t * desc = new cudnnRNNDescriptor_t;

CHECK_CUDNN_ERROR(cudnnCreateRNNDescriptor(desc));


#if CUDNN_MAJOR >= 7
-CHECK_CUDNN_ERROR(cudnnSetRNNDescriptor_v5(*desc,
+CHECK_CUDNN_ERROR(cudnnSetRNNDescriptor(cudnn_handle,
+*desc,
#else
CHECK_CUDNN_ERROR(cudnnSetRNNDescriptor(*desc,
#endif
@@ -361,6 +367,9 @@
input_mode,
direction,
rnn_mode,
+#if CUDNN_MAJOR >= 7
+rnn_algo,
+#endif
type));

desc_.reset(desc, RNNDescriptorDeleter());
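For reference, the guard above tracks an API change: cuDNN 7's cudnnSetRNNDescriptor takes the cuDNN handle and an algorithm selector, which is why the constructor now threads a cudnnHandle_t through and passes CUDNN_RNN_ALGO_STANDARD. A sketch of the two signatures (argument names abbreviated):

    // cuDNN 7: cudnnSetRNNDescriptor(handle, desc, hiddenSize, numLayers,
    //                                dropout, inputMode, direction, mode,
    //                                algo, dataType)
    // cuDNN 6: cudnnSetRNNDescriptor(desc, hiddenSize, numLayers, dropout,
    //                                inputMode, direction, mode, dataType)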
code/nvidia/gemm_bench.cu: 34 changes (30 additions, 4 deletions)
@@ -22,6 +22,14 @@
#define PAD_KERNELS 1
#endif

+#ifndef USE_TENSOR_CORES
+#if __CUDACC_VER_MAJOR__ > 8
+#define USE_TENSOR_CORES 1
+#else
+#define USE_TENSOR_CORES 0
+#endif
+#endif

/*
Usage:
@@ -67,19 +75,21 @@ int time_gemm(Tensor<T1> A, Tensor<T1> B, Tensor<T2> C, bool a_t, bool b_t, cubl
int k = a_t ? A.dims()[0] : A.dims()[1];
int n = C.dims()[1];

-int numRepeats = std::max(std::ceil(1e11 / (m * k * n)), 10.);
+int numRepeats = 400;
cublasStatus_t stat;

#if (__CUDACC_VER_MAJOR__ >= 8)
cudaDataType_t A_type = CUDA_R_32F;
cudaDataType_t B_type = CUDA_R_32F;
cudaDataType_t C_type = CUDA_R_32F;
cudaDataType_t compute_type = CUDA_R_32F;
+cublasGemmAlgo_t algo;

if (std::is_same<T1, uint16_t>::value) {
A_type = CUDA_R_16F;
B_type = CUDA_R_16F;
C_type = CUDA_R_16F;
+compute_type = CUDA_R_16F;
}

if (std::is_same<T1, uint8_t>::value) {
@@ -89,6 +99,12 @@ int time_gemm(Tensor<T1> A, Tensor<T1> B, Tensor<T2> C, bool a_t, bool b_t, cubl
compute_type = CUDA_R_32I;
}

+#if (USE_TENSOR_CORES)
+algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+#else
+algo = CUBLAS_GEMM_DFALT;
+#endif

#endif

#if (__CUDACC_VER_MAJOR__ < 8)
@@ -117,7 +133,7 @@ int time_gemm(Tensor<T1> A, Tensor<T1> B, Tensor<T2> C, bool a_t, bool b_t, cubl
&beta,
C.begin(), C_type, C.dims()[0],
compute_type,
-CUBLAS_GEMM_DFALT);
+algo);
#endif

if (stat != CUBLAS_STATUS_SUCCESS) {
@@ -155,7 +171,7 @@ int time_gemm(Tensor<T1> A, Tensor<T1> B, Tensor<T2> C, bool a_t, bool b_t, cubl
&beta,
C.begin(), C_type, C.dims()[0],
compute_type,
-CUBLAS_GEMM_DFALT);
+algo);
#endif
if (stat != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error("sgemm failed");
@@ -197,6 +213,16 @@ int main(int argc, char **argv) {
std::cout << "CUBLAS init failed" << std::endl;
}

+#if (USE_TENSOR_CORES) && (__CUDACC_VER_MAJOR__ > 8)
+status = cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
+#endif

+if (status != CUBLAS_STATUS_SUCCESS) {
+std::cout << "CUBLAS math mode failed" << std::endl;
+}



curandGenerator_t curand_gen;

curandCreateGenerator(&curand_gen, CURAND_RNG_PSEUDO_DEFAULT);
@@ -237,7 +263,7 @@ int main(int argc, char **argv) {
if (pad_m%4) {
pad_kernels_count++;
if (PAD_KERNELS) {
-pad_dim(pad_m);
+pad_dim(pad_m, 4);
need_padding = true;
} else {
skip_kernel = true;
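Pulling the gemm_bench.cu changes together: on CUDA 9+, the benchmark selects CUBLAS_GEMM_DFALT_TENSOR_OP and enables CUBLAS_TENSOR_OP_MATH on the handle. A standalone sketch of that call pattern, under stated assumptions (CUDA 9+, made-up dimensions, fp32 accumulation for simplicity; the benchmark itself uses CUDA_R_16F compute for half):

    #include <cublas_v2.h>
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        const int m = 1760, n = 128, k = 1760;  // illustrative sizes

        cublasHandle_t handle;
        cublasCreate(&handle);
        // Opt in to tensor-core kernels (CUDA 9+).
        cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);

        __half *A, *B, *C;  // left uninitialized; only the call shape matters
        cudaMalloc(&A, sizeof(__half) * m * k);
        cudaMalloc(&B, sizeof(__half) * k * n);
        cudaMalloc(&C, sizeof(__half) * m * n);

        const float alpha = 1.f, beta = 0.f;  // must match the compute type
        cublasStatus_t stat = cublasGemmEx(
            handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
            &alpha, A, CUDA_R_16F, m,
                    B, CUDA_R_16F, k,
            &beta,  C, CUDA_R_16F, m,
            CUDA_R_32F,                     // compute type
            CUBLAS_GEMM_DFALT_TENSOR_OP);   // tensor-op algorithm selector
        printf("gemm status: %d\n", static_cast<int>(stat));

        cudaFree(A); cudaFree(B); cudaFree(C);
        cublasDestroy(handle);
        return 0;
    }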
code/nvidia/rnn_bench.cu: 15 changes (14 additions, 1 deletion)
@@ -44,6 +44,14 @@ float, half, int8 for inference
*/

+#ifndef USE_TENSOR_CORES
+#if CUDNN_MAJOR >= 7
+#define USE_TENSOR_CORES 1
+#else
+#define USE_TENSOR_CORES 0
+#endif
+#endif


cudnnHandle_t cudnn_handle;
curandGenerator_t curand_gen;
@@ -135,7 +143,8 @@ class cudnnRNN {
dropout_.desc(),
CUDNN_SKIP_INPUT,
CUDNN_UNIDIRECTIONAL,
-rnn_type);
+rnn_type,
+cudnn_handle);
cudnnDataType_t type;
if (std::is_same<T, float>::value)
type = CUDNN_DATA_FLOAT;
@@ -154,6 +163,10 @@ class cudnnRNN {
&weight_size_,
type) );

+#if (CUDNN_MAJOR >= 7) && (USE_TENSOR_CORES)
+CHECK_CUDNN_ERROR( cudnnSetRNNMatrixMathType(rnn_desc_.desc(), CUDNN_TENSOR_OP_MATH) );
+#endif

weights_ = rand<T>(std::vector<int>{static_cast<int>(weight_size_ / sizeof(T)), 1}, curand_gen);


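The RNN path gets the same opt-in as the convolutions, applied to the RNN descriptor after it is configured. A minimal sketch, not part of the commit (helper name hypothetical):

    #include <cudnn.h>

    void use_tensor_op_rnn(cudnnRNNDescriptor_t rnn_desc) {
    #if CUDNN_MAJOR >= 7
        // Allow tensor-core kernels for the RNN's internal GEMMs.
        cudnnSetRNNMatrixMathType(rnn_desc, CUDNN_TENSOR_OP_MATH);
    #endif
    }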
code/nvidia/tensor.h: 7 changes (4 additions, 3 deletions)
@@ -79,9 +79,10 @@ rand(std::vector<int> dims, curandGenerator_t curand_gen) {
return tensor;
}

-void pad_dim(int & dim) {
-if (dim % 4) {
-int pad = 4 - dim%4;
+void pad_dim(int & dim, int pad_v) {
+assert(pad_v > 0);
+if (dim % pad_v) {
+int pad = pad_v - dim%pad_v;
dim += pad;
}
}
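A quick usage sketch of the generalized helper (illustrative values, with a self-contained copy of pad_dim):

    #include <cassert>
    #include <cstdio>

    // Rounds dim up to the next multiple of pad_v.
    static void pad_dim(int &dim, int pad_v) {
        assert(pad_v > 0);
        if (dim % pad_v)
            dim += pad_v - dim % pad_v;
    }

    int main() {
        int c = 3;
        pad_dim(c, 8);   // rounds up: c == 8 (tensor-core channel rule)
        int k = 64;
        pad_dim(k, 8);   // already aligned: k stays 64
        printf("%d %d\n", c, k);
        return 0;
    }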
results/train/DeepBench_NV_V100.xlsx: binary file added (not shown)
