Skip to content

Commit

Permalink
CI: PyTorch Surrogate Example (#621)
Browse files Browse the repository at this point in the history
* CI: PyTorch Surrogate Example

Cover our PyTorch surrogate example in CI.

* CI: Add Extra Example Requirements (CPU)

* Silence Torch Warning

And anticipate the default change in future releases.

* PyTorch Threading Mixed with AMReX is icky

Issues as soon as we use MPI+OMP and add our `Drift` element.

* CTest: Skip Analysis/Plot if Run Failed

... to produce output

* CTest: Define 42 as Skip Return Code

Better than passing the test
https://cmake.org/cmake/help/latest/prop_test/SKIP_RETURN_CODE.html
  • Loading branch information
ax3l authored Jan 14, 2025
1 parent cd74b3a commit c00f6d2
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/dependencies/gcc-openmpi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ python3 -m pip install -U -r src/python/impactx/dashboard/requirements.txt
python3 -m pip install -U -r examples/requirements.txt
python3 -m pip install -U -r tests/python/requirements.txt

# extra tests
python3 -m pip install -U -r examples/requirements_torch_cpu.txt
python3 -m pip install -U openPMD-validator
2 changes: 2 additions & 0 deletions .github/workflows/dependencies/gcc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ python3 -m pip install -U -r src/python/impactx/dashboard/requirements.txt
python3 -m pip install -U -r examples/requirements.txt
python3 -m pip install -U -r tests/python/requirements.txt

# extra tests
python3 -m pip install -U -r examples/requirements_torch_cpu.txt
python3 -m pip install -U openPMD-validator
25 changes: 25 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)
else()
set_property(TEST ${name}.run APPEND PROPERTY ENVIRONMENT "OMP_NUM_THREADS=2")
endif()
# special return code for skipped tests (e.g., runtime prerequisite fails)
set_tests_properties(${name}.run PROPERTIES SKIP_RETURN_CODE 42)

# analysis and plots
set(THIS_Python_SCRIPT_EXE)
Expand All @@ -131,6 +133,11 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)

# make HDF5 I/O more robust on various filesystems
set_property(TEST ${name}.analysis APPEND PROPERTY ENVIRONMENT "HDF5_USE_FILE_LOCKING=FALSE")

# run test failed? Mark this as skipped
set_property(TEST ${name}.analysis PROPERTY SKIP_REGULAR_EXPRESSION
"Supplied directory is not valid: diags"
)
endif()
if(plot_script)
add_test(NAME ${name}.plot
Expand All @@ -141,6 +148,11 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)

# make HDF5 I/O more robust on various filesystems
set_property(TEST ${name}.plot APPEND PROPERTY ENVIRONMENT "HDF5_USE_FILE_LOCKING=FALSE")

# run test failed? Mark this as skipped
set_property(TEST ${name}.plot PROPERTY SKIP_REGULAR_EXPRESSION
"ValueError: No objects to concatenate"
)
endif()
endfunction()

Expand Down Expand Up @@ -1000,6 +1012,7 @@ add_impactx_test(spectrometer.py
OFF # no plot script yet
)


# Chicane with CSR ###########################################################
#
if(ImpactX_FFT)
Expand Down Expand Up @@ -1097,6 +1110,7 @@ add_impactx_test(linac-segment.py
OFF # no plot script yet
)


# Iteration of a linear one-turn map #########################################
#
# w/o space charge
Expand All @@ -1112,3 +1126,14 @@ add_impactx_test(linear-map.py
examples/linear_map/analysis_map.py
OFF # no plot script yet
)


# PyTorch Surrogate: Staged LPA ##############################################
#
# Register the PyTorch surrogate example as a CTest case.
# Positional arguments follow the helper's signature:
#   add_impactx_test(name input is_mpi analysis_script plot_script)
# The run script exits with return code 42 when PyTorch cannot be imported;
# add_impactx_test sets SKIP_RETURN_CODE 42 on the .run test, so CTest then
# reports the test as skipped instead of passed or failed.
add_impactx_test(pytorch_surrogate_model
examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
OFF # ImpactX MPI-parallel
examples/pytorch_surrogate_model/analyze_ml_surrogate_15_stage.py
examples/pytorch_surrogate_model/visualize_ml_surrogate_15_stage.py
)
# NOTE(review): presumably attaches the "slow" label to this test's CTest
# entries; label_impactx_test is defined outside the visible hunk - confirm.
label_impactx_test(pytorch_surrogate_model slow)
18 changes: 10 additions & 8 deletions examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
print("Warning: Cannot import PyTorch. Skipping test.")
import sys

sys.exit(0)
sys.exit(42) # ImpactX special return code for skipped tests

import zipfile
from urllib import request
Expand Down Expand Up @@ -100,18 +100,19 @@ def download_and_unzip(url, data_dir):
data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")

# It was found that the PyTorch multithreaded defaults interfere with MPI-enabled AMReX
# when initializing the models: https://github.com/AMReX-Codes/pyamrex/issues/322
# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP
# when initializing the models or iterating elements:
# https://github.com/AMReX-Codes/pyamrex/issues/322
# https://github.com/ECP-WarpX/impactx/issues/773#issuecomment-2585043099
# So we manually set the number of threads to serial (1).
if Config.have_mpi:
n_threads = torch.get_num_threads()
torch.set_num_threads(1)
# Torch threading is not a problem with GPUs and might work when MPI is disabled.
# It could also just be a mix of OpenMP libraries (gomp and llvm omp) when using the
# pre-built PyTorch pip packages.
torch.set_num_threads(1)
model_list = [
surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
for stage_i in range(N_stage)
]
if Config.have_mpi:
torch.set_num_threads(n_threads)

pp_amrex = amr.ParmParse("amrex")
pp_amrex.add("the_arena_init_size", 0)
Expand Down Expand Up @@ -328,6 +329,7 @@ def set_lens(self, pc, step, period):
lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
lpa.nslice = n_slice
lpa.ds = L_surrogate
lpa.threadsafe = False
lpa_stages.append(lpa)

monitor = elements.BeamMonitor("monitor")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class surrogate_model:
def __init__(self, model_file, device=None):
self.device = device
if device is None:
model_dict = torch.load(model_file, map_location="cpu")
model_dict = torch.load(model_file, map_location="cpu", weights_only=False)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = torch.load(model_file, map_location=device, weights_only=False)
self.source_means = torch.tensor(
model_dict["source_means"], device=self.device, dtype=torch.float64
)
Expand Down
6 changes: 6 additions & 0 deletions examples/requirements_torch_cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# This is for CPU CI tests with extra requirements.
#
# For PyTorch, see alternative packages, e.g., for GPU here:
# https://pytorch.org/get-started/locally/
--extra-index-url https://download.pytorch.org/whl/cpu
torch

0 comments on commit c00f6d2

Please sign in to comment.