Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: enable tracking on GPUs #28

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions examples/001d_instability_wake_table_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import hilbert
from scipy.stats import linregress

import xtrack as xt
import xpart as xp
import xwakes as xw
import xfields as xf
import xobjects as xo

context = xo.ContextCupy(device=3)
#context = xo.ContextCpu()

# Simulation settings
n_turns = 1000

wake_table_filename = xf.general._pkg_root.joinpath(
'../test_data/HLLHC_wake.dat')
wake_file_columns = ['time', 'longitudinal', 'dipole_x', 'dipole_y',
'quadrupole_x', 'quadrupole_y', 'dipole_xy',
'quadrupole_xy', 'dipole_yx', 'quadrupole_yx',
'constant_x', 'constant_y']
mytable = xw.read_headtail_file(
wake_file=wake_table_filename,
wake_file_columns=wake_file_columns
)
wf = xw.WakeFromTable(
table=mytable,
columns=['dipole_x', 'dipole_y', 'quadrupole_x', 'quadrupole_y'],
)
wf.configure_for_tracking(
zeta_range=(-0.375, 0.375),
num_slices=100,
num_turns=1,
_context=context
)

one_turn_map = xt.LineSegmentMap(
length=27e3, betx=70., bety=80.,
qx=62.31, qy=60.32,
longitudinal_mode='linear_fixed_qs',
dqx=-10., dqy=-10., # <-- to see fast mode-0 instability
qs=2e-3, bets=731.27,
_context=context
)

# Generate line
line = xt.Line(elements=[one_turn_map, wf],
element_names=['one_turn_map', 'wf'])


line.particle_ref = xt.Particles(p0c=7e12, _context=context)
line.build_tracker(_context=context)

# Generate particles
particles = xp.generate_matched_gaussian_bunch(line=line,
num_particles=40_000_000, total_intensity_particles=2.3e11,
nemitt_x=2e-6, nemitt_y=2e-6, sigma_z=0.075,
_context=context)

# Apply a distortion to the bunch to trigger an instability
amplitude = 1e-3
particles.x += amplitude
particles.y += amplitude

flag_plot = True

mean_x_xt = np.zeros(n_turns)
mean_y_xt = np.zeros(n_turns)

plt.ion()

fig1 = plt.figure(figsize=(6.4*1.7, 4.8))
ax_x = fig1.add_subplot(121)
line1_x, = ax_x.plot(mean_x_xt, 'r-', label='average x-position')
line2_x, = ax_x.plot(mean_x_xt, 'm-', label='exponential fit')
ax_x.set_ylim(-3.5, -1)
ax_x.set_xlim(0, n_turns)
ax_y = fig1.add_subplot(122, sharex=ax_x)
line1_y, = ax_y.plot(mean_y_xt, 'b-', label='average y-position')
line2_y, = ax_y.plot(mean_y_xt, 'c-', label='exponential fit')
ax_y.set_ylim(-3.5, -1)
ax_y.set_xlim(0, n_turns)

plt.xlabel('turn')
plt.ylabel('log10(average x-position)')
plt.legend()

turns = np.linspace(0, n_turns - 1, n_turns)

import time

line.track(particles, num_turns=1)

start = time.time()

for i_turn in range(n_turns):
line.track(particles, num_turns=1)

if i_turn % 50 == 0:
print(f'Turn: {i_turn}')

'''
mean_x_xt[i_turn] = np.mean(particles.x)
mean_y_xt[i_turn] = np.mean(particles.y)

if i_turn % 50 == 0 and i_turn > 1:
i_fit_end = np.argmax(mean_x_xt) # i_turn
i_fit_start = int(i_fit_end * 0.9)

# compute x instability growth rate
ampls_x_xt = np.abs(hilbert(mean_x_xt))
fit_x_xt = linregress(turns[i_fit_start: i_fit_end],
np.log(ampls_x_xt[i_fit_start: i_fit_end]))

# compute y instability growth rate

ampls_y_xt = np.abs(hilbert(mean_y_xt))
fit_y_xt = linregress(turns[i_fit_start: i_fit_end],
np.log(ampls_y_xt[i_fit_start: i_fit_end]))

line1_x.set_xdata(turns[:i_turn])
line1_x.set_ydata(np.log10(np.abs(mean_x_xt[:i_turn])))
line2_x.set_xdata(turns[:i_turn])
line2_x.set_ydata(np.log10(np.exp(fit_x_xt.intercept +
fit_x_xt.slope*turns[:i_turn])))

line1_y.set_xdata(turns[:i_turn])
line1_y.set_ydata(np.log10(np.abs(mean_y_xt[:i_turn])))
line2_y.set_xdata(turns[:i_turn])
line2_y.set_ydata(np.log10(np.exp(fit_y_xt.intercept +
fit_y_xt.slope*turns[:i_turn])))
print(f'xtrack h growth rate: {fit_x_xt.slope}')
print(f'xtrack v growth rate: {fit_y_xt.slope}')

fig1.canvas.draw()
fig1.canvas.flush_events()

out_folder = '.'
np.save(f'{out_folder}/mean_x.npy', mean_x_xt)
np.save(f'{out_folder}/mean_y.npy', mean_y_xt)
'''

end = time.time()
print(f'Time: {end - start}')
65 changes: 43 additions & 22 deletions tests/test_xwakes_kick_vs_pyheadtail.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
import xwakes as xw
import xtrack as xt
import xobjects as xo
from xobjects.test_helpers import for_all_test_contexts

test_data_folder = pathlib.Path(__file__).parent.joinpath(
'../test_data').absolute()

def test_xwakes_kick_vs_pyheadtail_table_dipolar():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_table_dipolar(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000))
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
_context=test_context)
p.x[p.zeta > 0] += 1e-3
p.y[p.zeta > 0] += 1e-3
p_table = p.copy()
Expand All @@ -36,7 +39,8 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
'quadrupolar_xy', 'dipolar_yx', 'quadrupolar_yx',
'constant_x', 'constant_y'])
wake_from_table = xw.WakeFromTable(table, columns=['dipolar_x', 'dipolar_y'])
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100,
_context=test_context)

# Zotter convention
assert table['dipolar_x'].values[1] > 0
Expand Down Expand Up @@ -77,11 +81,13 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
xo.assert_allclose(p_table.px, p_ref.px, atol=1e-30, rtol=2e-3)
xo.assert_allclose(p_table.py, p_ref.py, atol=1e-30, rtol=2e-3)

def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_table_quadrupolar(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000))
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
_context=test_context)
p.x[p.zeta > 0] += 1e-3
p.y[p.zeta > 0] += 1e-3
p_table = p.copy()
Expand All @@ -98,7 +104,8 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
'constant_x', 'constant_y'])

wake_from_table = xw.WakeFromTable(table, columns=['quadrupolar_x', 'quadrupolar_y'])
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100,
_context=test_context)

# This is specific of this table
assert table['quadrupolar_x'].values[1] < 0
Expand Down Expand Up @@ -141,13 +148,16 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
xo.assert_allclose(p_table.py, p_ref.py, atol=1e-30, rtol=2e-3)


def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_table_longitudinal(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles.merge([
xt.Particles(p0c=7e12, zeta=np.linspace(-1e-3, 1e-3, 100000)),
xt.Particles(p0c=7e12, zeta=1e-6+np.zeros(100000))
xt.Particles(p0c=7e12, zeta=np.linspace(-1e-3, 1e-3, 100000),
_context=test_context),
xt.Particles(p0c=7e12, zeta=1e-6+np.zeros(100000),
_context=test_context),
])

p_table = p.copy()
Expand All @@ -163,7 +173,9 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
'quadrupolar_xy', 'dipolar_yx', 'quadrupolar_yx',
'constant_x', 'constant_y'])
wake_from_table = xw.WakeFromTable(table, columns=['time', 'longitudinal'])
wake_from_table.configure_for_tracking(zeta_range=(-2e-3, 2e-3), num_slices=1000)
wake_from_table.configure_for_tracking(zeta_range=(-2e-3, 2e-3),
num_slices=1000,
_context=test_context)

assert len(wake_from_table.components) == 1
assert wake_from_table.components[0].plane == 'z'
Expand Down Expand Up @@ -196,12 +208,13 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
assert np.max(p_ref.delta) > 1e-12
xo.assert_allclose(p_table.delta, p_ref.delta, atol=1e-14, rtol=0)

def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_resonator_dipolar(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
weight=1e14)
weight=1e14, _context=test_context)
p.x[p.zeta > 0] += 1e-3
p.y[p.zeta > 0] += 1e-3
p_table = p.copy()
Expand All @@ -212,7 +225,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
r=1e8, q=1e7, f_r=1e9,
kind=xw.Yokoya('circular'), # equivalent to: kind=['dipolar_x', 'dipolar_y'],
)
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
_context=test_context)

assert len(wake.components) == 2
assert wake.components[0].plane == 'x'
Expand Down Expand Up @@ -242,7 +256,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
table = pd.DataFrame({'time': t_samples, 'dipolar_x': w_dipole_x_samples,
'dipolar_y': w_dipole_y_samples})
wake_from_table = xw.WakeFromTable(table)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
_context=test_context)

assert len(wake_from_table.components) == 2
assert wake_from_table.components[0].plane == 'x'
Expand Down Expand Up @@ -282,12 +297,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
xo.assert_allclose(p_table.px, p_ref.px, rtol=0, atol=2e-3*np.max(np.abs(p_ref.px)))
xo.assert_allclose(p_table.py, p_ref.py, rtol=0, atol=2e-3*np.max(np.abs(p_ref.py)))

def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
weight=1e14)
weight=1e14, _context=test_context)
p.x[p.zeta > 0] += 1e-3
p.y[p.zeta > 0] += 1e-3
p_table = p.copy()
Expand All @@ -298,7 +314,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
r=1e8, q=1e7, f_r=1e9,
kind=['quadrupolar_x', 'quadrupolar_y'],
)
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
_context=test_context)

assert len(wake.components) == 2
assert wake.components[0].plane == 'x'
Expand Down Expand Up @@ -327,7 +344,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
table = pd.DataFrame({'time': t_samples, 'quadrupolar_x': w_quadrupole_x_samples,
'quadrupolar_y': w_quadrupole_y_samples})
wake_from_table = xw.WakeFromTable(table)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
_context=test_context)

assert len(wake_from_table.components) == 2
assert wake_from_table.components[0].plane == 'x'
Expand Down Expand Up @@ -370,12 +388,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
xo.assert_allclose(p_table.px, p_ref.px, rtol=0, atol=2e-3*np.max(np.abs(p_ref.px)))
xo.assert_allclose(p_table.py, p_ref.py, rtol=0, atol=2e-3*np.max(np.abs(p_ref.py)))

def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
@for_all_test_contexts(excluding="ContextPyopencl")
def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal(test_context):

from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles

p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
weight=1e14)
weight=1e14, _context=test_context)
p.x[p.zeta > 0] += 1e-3
p.y[p.zeta > 0] += 1e-3
p_table = p.copy()
Expand All @@ -386,7 +405,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
r=1e8, q=1e7, f_r=1e9,
kind='longitudinal'
)
wake.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50)
wake.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50,
_context=test_context)

assert len(wake.components) == 1
assert wake.components[0].plane == 'z'
Expand All @@ -407,7 +427,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
w_longitudinal_x_samples[0] *= 2 # Undo sampling weight
table = pd.DataFrame({'time': t_samples, 'longitudinal': w_longitudinal_x_samples})
wake_from_table = xw.WakeFromTable(table)
wake_from_table.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50)
wake_from_table.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50,
_context=test_context)

assert len(wake_from_table.components) == 1
assert wake_from_table.components[0].plane == 'z'
Expand Down
Loading
Loading