Skip to content

Commit

Permalink
DSMC: make the check of the number of processes more robust (#5515)
Browse files Browse the repository at this point in the history
The previous code was brittle because it was not using
`AMREX_ALWAYS_ASSERT_WITH_MESSAGE` (which indeed cannot be used inside a
GPU kernel. In practice, a user could specify e.g. 10 scattering
processes and not get an error.

The new code checks the number of processes before calling the GPU
kernel, by using `AMREX_ALWAYS_ASSERT_WITH_MESSAGE`.

---------

Co-authored-by: Roelof Groenewald <[email protected]>
  • Loading branch information
RemiLehe and roelof-groenewald authored Dec 17, 2024
1 parent 2ea2dd8 commit 2cdcb77
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* @param[in] scattering processes an array of scattering processes included for consideration.
* @param[in] engine the random engine.
*/
template <typename index_type>
template <int max_process_count, typename index_type>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void CollisionPairFilter (const amrex::ParticleReal u1x, const amrex::ParticleReal u1y,
const amrex::ParticleReal u1z, const amrex::ParticleReal u2x,
Expand Down Expand Up @@ -65,11 +65,11 @@ void CollisionPairFilter (const amrex::ParticleReal u1x, const amrex::ParticleRe

// Evaluate the cross-section for each scattering process to determine
// the total collision probability.
AMREX_ASSERT_WITH_MESSAGE(
(process_count < 4), "Too many scattering processes in DSMC routine."
);
int coll_type[4] = {0, 0, 0, 0};
amrex::ParticleReal sigma_sums[4] = {0._prt, 0._prt, 0._prt, 0._prt};

// The size of the arrays below is a compile-time constant (template parameter)
// for performance reasons: it avoids dynamic memory allocation on the GPU.
int coll_type[max_process_count] = {0};
amrex::ParticleReal sigma_sums[max_process_count] = {0._prt};
for (int ii = 0; ii < process_count; ii++) {
auto const& scattering_process = scattering_processes[ii];
coll_type[ii] = int(scattering_process.m_type);
Expand Down
6 changes: 5 additions & 1 deletion Source/Particles/Collision/BinaryCollision/DSMC/DSMCFunc.H
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ public:
u1y[I1[i1]] = u1xbuf*std::sin(theta) + u1y[I1[i1]]*std::cos(theta);
#endif

CollisionPairFilter(
const int max_process_count = 4; // Pre-defined value, for performance reasons
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(
(m_process_count < max_process_count), "Too many scattering processes in DSMC routine (hardcoded to only allow 4). Update the max_process_count value in source code to allow more scattering processes."
);
CollisionPairFilter<max_process_count>(
u1x[ I1[i1] ], u1y[ I1[i1] ], u1z[ I1[i1] ],
u2x[ I2[i2] ], u2y[ I2[i2] ], u2z[ I2[i2] ],
m1, m2, w1[ I1[i1] ], w2[ I2[i2] ],
Expand Down

0 comments on commit 2cdcb77

Please sign in to comment.