-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
430 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. | ||
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details. | ||
// | ||
// SPDX-License-Identifier: BSD-2-Clause | ||
// | ||
// This file is part of CEED: http://github.com/ceed | ||
|
||
#include <ceed.h> | ||
#include <ceed/backend.h> | ||
#include <stdbool.h> | ||
#include <string.h> | ||
|
||
#include "ceed-sve.h" | ||
|
||
//------------------------------------------------------------------------------ | ||
// Backend Init | ||
//------------------------------------------------------------------------------ | ||
static int CeedInit_Sve(const char *resource, Ceed ceed) { | ||
if (strcmp(resource, "/cpu/self") && strcmp(resource, "/cpu/self/sve") && strcmp(resource, "/cpu/self/sve/blocked")) { | ||
// LCOV_EXCL_START | ||
return CeedError(ceed, CEED_ERROR_BACKEND, "SVE backend cannot use resource: %s", resource); | ||
// LCOV_EXCL_STOP | ||
} | ||
CeedCallBackend(CeedSetDeterministic(ceed, true)); | ||
|
||
// Create reference CEED that implementation will be dispatched | ||
// through unless overridden | ||
Ceed ceed_ref; | ||
CeedInit("/cpu/self/opt/blocked", &ceed_ref); | ||
CeedCallBackend(CeedSetDelegate(ceed, ceed_ref)); | ||
|
||
if (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) { | ||
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "TensorContractCreate", CeedTensorContractCreate_f64_Sve)); | ||
} else { | ||
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "TensorContractCreate", CeedTensorContractCreate_f32_Sve); | ||
} | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Backend Register | ||
//------------------------------------------------------------------------------ | ||
CEED_INTERN int CeedRegister_Sve_Blocked(void) { return CeedRegister("/cpu/self/sve/blocked", CeedInit_Sve, 30); } | ||
//------------------------------------------------------------------------------ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. | ||
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details. | ||
// | ||
// SPDX-License-Identifier: BSD-2-Clause | ||
// | ||
// This file is part of CEED: http://github.com/ceed | ||
|
||
#include <ceed.h> | ||
#include <ceed/backend.h> | ||
#include <stdbool.h> | ||
#include <string.h> | ||
|
||
#include "ceed-sve.h" | ||
|
||
//------------------------------------------------------------------------------ | ||
// Backend Init | ||
//------------------------------------------------------------------------------ | ||
static int CeedInit_Sve(const char *resource, Ceed ceed) { | ||
if (strcmp(resource, "/cpu/self") && strcmp(resource, "/cpu/self/sve/serial")) { | ||
// LCOV_EXCL_START | ||
return CeedError(ceed, CEED_ERROR_BACKEND, "SVE backend cannot use resource: %s", resource); | ||
// LCOV_EXCL_STOP | ||
} | ||
CeedCallBackend(CeedSetDeterministic(ceed, true)); | ||
|
||
// Create reference CEED that implementation will be dispatched | ||
// through unless overridden | ||
Ceed ceed_ref; | ||
CeedInit("/cpu/self/opt/serial", &ceed_ref); | ||
CeedCallBackend(CeedSetDelegate(ceed, ceed_ref)); | ||
|
||
if (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) { | ||
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "TensorContractCreate", CeedTensorContractCreate_f64_Sve)); | ||
} else { | ||
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "TensorContractCreate", CeedTensorContractCreate_f32_Sve)); | ||
} | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Backend Register | ||
//------------------------------------------------------------------------------ | ||
CEED_INTERN int CeedRegister_Sve_Serial(void) { return CeedRegister("/cpu/self/sve/serial", CeedInit_Sve, 35); } | ||
//------------------------------------------------------------------------------ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. | ||
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details. | ||
// | ||
// SPDX-License-Identifier: BSD-2-Clause | ||
// | ||
// This file is part of CEED: http://github.com/ceed | ||
|
||
#include <ceed.h> | ||
#include <ceed/backend.h> | ||
#ifdef __ARM_FEATURE_SVE | ||
#include <arm_sve.h> | ||
#endif | ||
#include <stdbool.h> | ||
|
||
#include "ceed-sve.h" | ||
|
||
//------------------------------------------------------------------------------ | ||
// Blocked Tensor Contract | ||
//------------------------------------------------------------------------------ | ||
static inline int CeedTensorContract_Sve_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const float *restrict t, | ||
CeedTransposeMode t_mode, const CeedInt add, const float *restrict u, float *restrict v, | ||
const CeedInt JJ) { | ||
CeedInt t_stride_0 = B, t_stride_1 = 1; | ||
if (t_mode == CEED_TRANSPOSE) { | ||
t_stride_0 = 1; | ||
t_stride_1 = J; | ||
} | ||
|
||
for (CeedInt a = 0; a < A; a++) { | ||
for (CeedInt b = 0; b < B; b++) { | ||
// Blocks of JJ rows | ||
for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { | ||
for (CeedInt jj = 0; jj < JJ; jj++) { // unroll | ||
// C vectorization by compiler | ||
for (int32_t c = 0; c < C; c += svcntd()) { | ||
svbool_t pg = svwhilelt_b32(c, C); | ||
// Load u, v into vectors | ||
svfloat32_t u_vec = svld1(pg, &u[(a * B + b) * C + c]); | ||
svfloat32_t v_vec = svld1(pg, &v[(a * J + j + jj) * C + c]); | ||
// Basis matrix value | ||
float tq = t[(j + jj) * t_stride_0 + b * t_stride_1]; | ||
// fmadd | ||
svst1(pg, &v[(a * J + j + jj) * C + c], svmla_x(pg, v_vec, u_vec, tq)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Remainder of rows | ||
CeedInt j = (J / JJ) * JJ; | ||
if (j < J) { | ||
for (CeedInt a = 0; a < A; a++) { | ||
for (CeedInt b = 0; b < B; b++) { | ||
// Blocks of JJ rows | ||
for (CeedInt jj = 0; jj < J - j; jj++) { // not unrolled | ||
// C vectorization by compiler | ||
for (int32_t c = 0; c < C; c += svcntd()) { | ||
svbool_t pg = svwhilelt_b32(c, C); | ||
// Load u, v into vectors | ||
svfloat32_t u_vec = svld1(pg, &u[(a * B + b) * C + c]); | ||
svfloat32_t v_vec = svld1(pg, &v[(a * J + j + jj) * C + c]); | ||
// Basis matrix value | ||
float tq = t[(j + jj) * t_stride_0 + b * t_stride_1]; | ||
// fmadd | ||
svst1(pg, &v[(a * J + j + jj) * C + c], svmla_x(pg, v_vec, u_vec, tq)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Blocked Tensor Contract | ||
//------------------------------------------------------------------------------ | ||
static inline int CeedTensorContract_Sve_Serial(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const float *restrict t, | ||
CeedTransposeMode t_mode, const CeedInt add, const float *restrict u, float *restrict v, | ||
const CeedInt JJ) { | ||
CeedInt t_stride_0 = B, t_stride_1 = 1; | ||
if (t_mode == CEED_TRANSPOSE) { | ||
t_stride_0 = 1; | ||
t_stride_1 = J; | ||
} | ||
|
||
for (CeedInt a = 0; a < A; a++) { | ||
for (CeedInt b = 0; b < B; b++) { | ||
for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { | ||
for (CeedInt jj = 0; jj < JJ; jj++) { // unroll | ||
v[a * J + (j + jj)] += t[(j + jj) * t_stride_0 + b * t_stride_1] * u[a * B + b]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
CeedInt j = (J / JJ) * JJ; | ||
if (j < J) { | ||
for (CeedInt a = 0; a < A; a++) { | ||
for (CeedInt b = 0; b < B; b++) { | ||
for (CeedInt jj = 0; jj < J - j; jj++) { // not unrolled | ||
v[a * J + (j + jj)] += t[(j + jj) * t_stride_0 + b * t_stride_1] * u[a * B + b]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Tensor Contract - Common Sizes | ||
//------------------------------------------------------------------------------ | ||
static int CeedTensorContract_Sve_Blocked_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const float *restrict t, | ||
CeedTransposeMode t_mode, const CeedInt add, const float *restrict u, float *restrict v) { | ||
return CeedTensorContract_Sve_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 8); | ||
} | ||
static int CeedTensorContract_Sve_Serial_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const float *restrict t, | ||
CeedTransposeMode t_mode, const CeedInt add, const float *restrict u, float *restrict v) { | ||
return CeedTensorContract_Sve_Serial(contract, A, B, C, J, t, t_mode, add, u, v, 8); | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Tensor Contract Apply | ||
//------------------------------------------------------------------------------ | ||
static int CeedTensorContractApply_Sve(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const float *restrict t, | ||
CeedTransposeMode t_mode, const CeedInt add, const float *restrict u, float *restrict v) { | ||
if (!add) { | ||
for (CeedInt q = 0; q < A * J * C; q++) v[q] = (float)0.0; | ||
} | ||
|
||
if (C == 1) CeedTensorContract_Sve_Serial_8(contract, A, B, C, J, t, t_mode, true, u, v); | ||
else CeedTensorContract_Sve_Blocked_8(contract, A, B, C, J, t, t_mode, true, u, v); | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
// Tensor Contract Create | ||
//------------------------------------------------------------------------------ | ||
int CeedTensorContractCreate_f32_Sve(CeedBasis basis, CeedTensorContract contract) { | ||
Ceed ceed; | ||
CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); | ||
|
||
CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Sve)); | ||
|
||
return CEED_ERROR_SUCCESS; | ||
} | ||
//------------------------------------------------------------------------------ |
Oops, something went wrong.