Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Remove all ctx parameters in memory operators #5862

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/pyabacus/src/hsolver/py_diago_cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ class PyDiagoCG
const int nrow = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
const int nbands = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
syncmem_z2z_h2h_op()(
this->ctx,
this->ctx,
spsi_out.data<std::complex<double>>(),
psi_in.data<std::complex<double>>(),
static_cast<size_t>(nrow * nbands)
Expand Down
2 changes: 1 addition & 1 deletion python/pyabacus/src/hsolver/py_diago_david.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class PyDiagoDavid
const int nrow,
const int nbands
) {
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
syncmem_op()(spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
};

obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>(
Expand Down
8 changes: 4 additions & 4 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv

auto* swap = new T[notconv * nbase_x];
auto* target = new T[notconv * nbase_x];
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x);
if (base_device::get_current_precision(swap) == "single")
{
MPI_Reduce(swap,
Expand All @@ -97,8 +97,8 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);
syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x);
syncmem_complex_op()(swap, scc + nbase * nbase_x, notconv * nbase_x);

if (base_device::get_current_precision(swap) == "single")
{
Expand All @@ -121,7 +121,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x);
delete[] swap;
delete[] target;
}
Expand Down
20 changes: 10 additions & 10 deletions source/module_base/kernels/test/math_op_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,13 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
std::vector<double> ylm(expected_ylm.size(), 0.0);
double * d_ylm = nullptr, * d_g = nullptr, * d_p = nullptr;

resmem_var_op()(gpu_ctx, d_g, g.size());
resmem_var_op()(gpu_ctx, d_p, p.size());
resmem_var_op()(gpu_ctx, d_ylm, ylm.size());
resmem_var_op()(d_g, g.size());
resmem_var_op()(d_p, p.size());
resmem_var_op()(d_ylm, ylm.size());

syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_g, g.data(), g.size());
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_p, p.data(), p.size());
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_ylm, ylm.data(), ylm.size());
syncmem_var_h2d_op()(d_g, g.data(), g.size());
syncmem_var_h2d_op()(d_p, p.data(), p.size());
syncmem_var_h2d_op()(d_ylm, ylm.data(), ylm.size());

ModuleBase::cal_ylm_real_op<double, base_device::DEVICE_GPU>()(gpu_ctx,
ng,
Expand All @@ -326,15 +326,15 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
d_p,
d_ylm);

syncmem_var_d2h_op()(cpu_ctx, gpu_ctx, ylm.data(), d_ylm, ylm.size());
syncmem_var_d2h_op()(ylm.data(), d_ylm, ylm.size());

for (int ii = 0; ii < ylm.size(); ii++) {
EXPECT_LT(fabs(ylm[ii] - expected_ylm[ii]), 6e-5);
}

delmem_var_op()(gpu_ctx, d_g);
delmem_var_op()(gpu_ctx, d_p);
delmem_var_op()(gpu_ctx, d_ylm);
delmem_var_op()(d_g);
delmem_var_op()(d_p);
delmem_var_op()(d_ylm);
}

#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
88 changes: 44 additions & 44 deletions source/module_base/math_chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ Chebyshev<REAL, Device>::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord
coefc_cpu = new std::complex<REAL>[norder];
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->coef_real, norder);
resmem_complex_op()(this->ctx, this->coef_complex, norder);
resmem_var_op()(this->coef_real, norder);
resmem_complex_op()(this->coef_complex, norder);
}
else
{
Expand All @@ -84,8 +84,8 @@ Chebyshev<REAL, Device>::~Chebyshev()
delete[] polytrace;
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->coef_real);
delmem_complex_op()(this->ctx, this->coef_complex);
delmem_var_op()(this->coef_real);
delmem_complex_op()(this->coef_complex);
}
else
{
Expand Down Expand Up @@ -129,29 +129,29 @@ REAL Chebyshev<REAL, Device>::ddot_real(const std::complex<REAL>* psi_L,
pL = (REAL*)psi_L;
pR = (REAL*)psi_R;
REAL* dot_device = nullptr;
resmem_var_op()(this->ctx, dot_device, 1);
resmem_var_op()(dot_device, 1);
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result, dot_device, 1);
delmem_var_op()(this->ctx, dot_device);
syncmem_var_d2h_op()(&result, dot_device, 1);
delmem_var_op()(dot_device);
}
else
{
REAL *pL, *pR;
pL = (REAL*)psi_L;
pR = (REAL*)psi_R;
REAL* dot_device = nullptr;
resmem_var_op()(this->ctx, dot_device, 1);
resmem_var_op()(dot_device, 1);
for (int i = 0; i < m; ++i)
{
int dim2 = 2 * N;
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
REAL result_temp = 0;
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result_temp, dot_device, 1);
syncmem_var_d2h_op()(&result_temp, dot_device, 1);
result += result_temp;
pL += 2 * LDA;
pR += 2 * LDA;
}
delmem_var_op()(this->ctx, dot_device);
delmem_var_op()(dot_device);
}
return result;
}
Expand Down Expand Up @@ -211,7 +211,7 @@ void Chebyshev<REAL, Device>::calcoef_real(std::function<REAL(REAL)> fun)

if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
{
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, coef_real, coefr_cpu, norder);
syncmem_var_h2d_op()(coef_real, coefr_cpu, norder);
}

getcoef_real = true;
Expand Down Expand Up @@ -301,7 +301,7 @@ void Chebyshev<REAL, Device>::calcoef_complex(std::function<std::complex<REAL>(s
}
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
{
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
}

getcoef_complex = true;
Expand Down Expand Up @@ -392,7 +392,7 @@ void Chebyshev<REAL, Device>::calcoef_pair(std::function<REAL(REAL)> fun1, std::

if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
{
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
}

getcoef_complex = true;
Expand Down Expand Up @@ -427,17 +427,17 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
ndmxt = LDA * m;
}

resmem_complex_op()(this->ctx, arraynp1, ndmxt);
resmem_complex_op()(this->ctx, arrayn, ndmxt);
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
resmem_complex_op()(arraynp1, ndmxt);
resmem_complex_op()(arrayn, ndmxt);
resmem_complex_op()(arrayn_1, ndmxt);

memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);

funA(arrayn_1, arrayn, m);

// 0th- and 1st-order terms
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
setmem_complex_op()(waveout, 0, ndmxt);
std::complex<REAL> coef0 = std::complex<REAL>(coefr_cpu[0], 0);
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coef0, arrayn_1, 1, waveout, 1);
std::complex<REAL> coef1 = std::complex<REAL>(coefr_cpu[1], 0);
Expand All @@ -462,9 +462,9 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
arrayn = arraynp1;
arraynp1 = tem;
}
delmem_complex_op()(this->ctx, arraynp1);
delmem_complex_op()(this->ctx, arrayn);
delmem_complex_op()(this->ctx, arrayn_1);
delmem_complex_op()(arraynp1);
delmem_complex_op()(arrayn);
delmem_complex_op()(arrayn_1);
return;
}

Expand Down Expand Up @@ -496,16 +496,16 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
ndmxt = LDA * m;
}

resmem_complex_op()(this->ctx, arraynp1, ndmxt);
resmem_complex_op()(this->ctx, arrayn, ndmxt);
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
resmem_complex_op()(arraynp1, ndmxt);
resmem_complex_op()(arrayn, ndmxt);
resmem_complex_op()(arrayn_1, ndmxt);

memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
memcpy_complex_op()(arrayn_1, wavein, ndmxt);

funA(arrayn_1, arrayn, m);

// 0th- and 1st-order terms
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
setmem_complex_op()(waveout, 0, ndmxt);
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[0], arrayn_1, 1, waveout, 1);
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[1], arrayn, 1, waveout, 1);
// for (int i = 0; i < ndmxt; ++i)
Expand All @@ -527,9 +527,9 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
arrayn = arraynp1;
arraynp1 = tem;
}
delmem_complex_op()(this->ctx, arraynp1);
delmem_complex_op()(this->ctx, arrayn);
delmem_complex_op()(this->ctx, arrayn_1);
delmem_complex_op()(arraynp1);
delmem_complex_op()(arrayn);
delmem_complex_op()(arrayn_1);
return;
}

Expand All @@ -553,7 +553,7 @@ void Chebyshev<REAL, Device>::calpolyvec_complex(
std::complex<REAL>*tmpin = wavein, *tmpout = arrayn_1;
for (int i = 0; i < m; ++i)
{
memcpy_complex_op()(this->ctx, this->ctx, tmpout, tmpin, N);
memcpy_complex_op()(tmpout, tmpin, N);
// ModuleBase::GlobalFunc::DCOPY(tmpin, tmpout, N);
tmpin += LDA;
tmpout += LDA;
Expand Down Expand Up @@ -595,11 +595,11 @@ void Chebyshev<REAL, Device>::tracepolyA(
ndmxt = LDA * m;
}

resmem_complex_op()(this->ctx, arraynp1, ndmxt);
resmem_complex_op()(this->ctx, arrayn, ndmxt);
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
resmem_complex_op()(arraynp1, ndmxt);
resmem_complex_op()(arrayn, ndmxt);
resmem_complex_op()(arrayn_1, ndmxt);

memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);

funA(arrayn_1, arrayn, m);
Expand All @@ -618,9 +618,9 @@ void Chebyshev<REAL, Device>::tracepolyA(
arraynp1 = tem;
}

delmem_complex_op()(this->ctx, arraynp1);
delmem_complex_op()(this->ctx, arrayn);
delmem_complex_op()(this->ctx, arrayn_1);
delmem_complex_op()(arraynp1);
delmem_complex_op()(arrayn);
delmem_complex_op()(arrayn_1);
return;
}

Expand Down Expand Up @@ -669,11 +669,11 @@ bool Chebyshev<REAL, Device>::checkconverge(
std::complex<REAL>* arrayn = nullptr;
std::complex<REAL>* arrayn_1 = nullptr;

resmem_complex_op()(this->ctx, arraynp1, LDA);
resmem_complex_op()(this->ctx, arrayn, LDA);
resmem_complex_op()(this->ctx, arrayn_1, LDA);
resmem_complex_op()(arraynp1, LDA);
resmem_complex_op()(arrayn, LDA);
resmem_complex_op()(arrayn_1, LDA);

memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, N);
memcpy_complex_op()(arrayn_1, wavein, N);
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N);

if (tmin == tmax)
Expand Down Expand Up @@ -754,9 +754,9 @@ bool Chebyshev<REAL, Device>::checkconverge(
arraynp1 = tem;
}

delmem_complex_op()(this->ctx, arraynp1);
delmem_complex_op()(this->ctx, arrayn);
delmem_complex_op()(this->ctx, arrayn_1);
delmem_complex_op()(arraynp1);
delmem_complex_op()(arrayn);
delmem_complex_op()(arrayn_1);
return converge;
}

Expand Down
8 changes: 4 additions & 4 deletions source/module_base/math_ylmreal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ void YlmReal::Ylm_Real(Device * ctx, const int lmax2, const int ng, const FPTYPE
ModuleBase::WARNING_QUIT("YLM_REAL","l>30 or l<0");
}
FPTYPE * p = nullptr, * phi = nullptr, * cost = nullptr;
resmem_var_op()(ctx, p, (lmax + 1) * (lmax + 1) * ng, "YlmReal::Ylm_Real");
resmem_var_op()(p, (lmax + 1) * (lmax + 1) * ng, "YlmReal::Ylm_Real");

cal_ylm_real_op()(
ctx,
Expand All @@ -342,9 +342,9 @@ void YlmReal::Ylm_Real(Device * ctx, const int lmax2, const int ng, const FPTYPE
p,
ylm);

delmem_var_op()(ctx, p);
delmem_var_op()(ctx, phi);
delmem_var_op()(ctx, cost);
delmem_var_op()(p);
delmem_var_op()(phi);
delmem_var_op()(cost);
} // end subroutine ylmr2

//==========================================================
Expand Down
Loading
Loading