Skip to content

Commit

Permalink
* fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 28, 2024
1 parent ebdb70f commit ad19ba0
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _get_antennas(self) -> ac.EarthLocation:
all_antennas = array.get_antennas()
array_centre = array.get_array_location()
all_antennas_itrs = all_antennas.get_itrs()
all_antennas_itrs_xyz = all_antennas_itrs.T
all_antennas_itrs_xyz = all_antennas_itrs.cartesian.xyz.T
max_baseline = np.max(
np.linalg.norm(
all_antennas_itrs_xyz[:, None, :] - all_antennas_itrs_xyz[None, :, :],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import jax
import numpy as np
from astropy import units as au

from dsa2000_cal.assets.rfi.lte_rfi.lwa_cell_tower import LWACellTower
from dsa2000_cal.assets.rfi.lte_rfi.mock_cell_tower import MockCellTower


def test_lte_rfi_source_factory():
model = MockCellTower(seed='test')
import pylab as plt
source_params = model.make_source_params(freqs=np.linspace(700, 800, 50) * au.MHz)
plt.plot(source_params.delay_acf.x, source_params.delay_acf.values[:, 0])
delays = np.linspace(-1e7, 1e7, 1000)
print(source_params.delay_acf)
plt.plot(delays, jax.vmap(source_params.delay_acf)(delays)[:, 0, 0])
plt.xlabel('Delay [s]')
plt.ylabel('Auto-correlation function')
plt.show()
assert source_params.delay_acf.regular_grid
plt.show()
6 changes: 3 additions & 3 deletions dsa2000_cal/dsa2000_cal/calibration/multi_step_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class MultiStepLevenbergMarquardt(Generic[X, Y]):
# Improvement threshold
p_any_improvement: FloatArray = 0.06 # p0 > 0
p_less_newton: FloatArray = 0.88 # p2 -- less than sufficient improvement
p_sufficient_improvement: FloatArray = 1. # p1 > p0
p_sufficient_improvement: FloatArray = 0.99 # p1 > p0
p_more_newton: FloatArray = 1. # p3 -- more than sufficient improvement

# Damping alteration factors 0 < c_more_newton < 1 < c_less_newton
Expand All @@ -113,11 +113,11 @@ def __post_init__(self):
self.p_sufficient_improvement,
self.p_more_newton,
self.p_less_newton))) and not (
(0. < self.p_any_improvement)
(0. <= self.p_any_improvement)
and (self.p_any_improvement < self.p_less_newton)
and (self.p_less_newton < self.p_sufficient_improvement)
and (self.p_sufficient_improvement < self.p_more_newton)
and (self.p_more_newton < 1.)
and (self.p_more_newton <= 1.)
):
raise ValueError(
"Improvement thresholds must satisfy 0 < p(any) < p(less) < p(sufficient) < p(more) < 1, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# export XLA_FLAGS="--xla_gpu_enable_mem_tracing --xla_hlo_profile"

# os.environ['XLA_FLAGS'] = '--xla_gpu_enable_mem_tracing --xla_hlo_profile'
config.update("jax_explain_cache_misses", True)
# config.update("jax_explain_cache_misses", True)
import pytest

from dsa2000_cal.assets.content_registry import fill_registries
Expand Down
24 changes: 12 additions & 12 deletions dsa2000_cal/dsa2000_cal/common/tests/test_wgridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,18 @@ def test_spectral_predict(center_offset: float):
freqs = jnp.asarray([700e6, 700e6])

vis = jax.vmap(
lambda dirty, dl, dm, l0, m0, freqs:
image_to_vis(
uvw=uvw,
freqs=freqs[None],
dirty=dirty,
pixsize_l=dl,
pixsize_m=dm,
center_l=l0,
center_m=m0,
epsilon=1e-6
convert_to_ufunc(
lambda dirty, dl, dm, l0, m0, freqs:
image_to_vis(
uvw=uvw,
freqs=freqs[None],
dirty=dirty,
pixsize_l=dl,
pixsize_m=dm,
center_l=l0,
center_m=m0,
epsilon=1e-6
)
)
)(dirty, dl, dm, l0, m0, freqs)
assert np.shape(vis) == (num_freqs, len(uvw), 1)
Expand Down Expand Up @@ -297,8 +299,6 @@ def _image_to_vis(uvw, freqs, dirty, dl, dm, l0, m0, mask):
vis = _image_to_vis(uvw, freqs, dirty, dl, dm, l0, m0, mask)
assert vis.shape == (a, b, C, r, c)



@partial(
multi_vmap,
in_mapping="[a,r,3],[b,C,c=1],[C,Nl,Nm],[C],[C],[C],[C],[a,r,c=1]",
Expand Down
6 changes: 3 additions & 3 deletions dsa2000_cal/dsa2000_cal/common/wgridder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import itertools
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import jax
import jax.numpy as jnp
Expand All @@ -13,7 +12,6 @@
'vis_to_image'
]

from dsa2000_cal.common.jax_utils import convert_to_ufunc
from dsa2000_cal.common.types import FloatArray, ComplexArray
from dsa2000_cal.common.mixed_precision_utils import mp_policy

Expand Down Expand Up @@ -250,7 +248,8 @@ def compute_vis_for_channel_spectral(indices):
print(e)
raise e

batch_dims = np.shape(uvw)[:-2]
batch_dims = np.broadcast_shapes(np.shape(uvw)[:-2], np.shape(dirty)[:-2], np.shape(freqs)[:-1])
print(np.shape(uvw)[:-2], np.shape(dirty)[:-2], batch_dims)
all_indices = list(itertools.product(*[range(dim) for dim in batch_dims]))
# Put dims at end so memory ordering is nice
output_vis = np.zeros((num_rows, num_freq) + batch_dims, order='F', dtype=output_dtype)
Expand All @@ -261,6 +260,7 @@ def compute_vis_for_channel_spectral(indices):
perm = list(range(len(batch_dims) + 2))
perm.append(perm.pop(0)) # Move num_rows to the end
perm.append(perm.pop(0)) # Move num_freqs to the end
print(perm)

output_vis = np.transpose(output_vis, axes=tuple(perm))

Expand Down
4 changes: 2 additions & 2 deletions dsa2000_cal/dsa2000_cal/imaging/imagor.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def update_weights(weights):
raise ValueError(f"Unknown weighting scheme {self.weighting}")

@partial(multi_vmap,
in_mapping="[r,c,p],[r,c,p],[r,c,p]",
out_mapping="[...,p]",
in_mapping="[r,c,coh],[r,c,coh],[r,c,coh]",
out_mapping="[...,coh]",
verbose=True)
def image_single_coh(vis, weights, mask):
dirty_image = vis_to_image(
Expand Down
19 changes: 15 additions & 4 deletions dsa2000_cal/dsa2000_cal/imaging/tests/test_imagor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,19 @@ def test_evaluate_beam(tmp_path, coherencies, centre_offset: float):
dm = 0.001
center_l = center_m = centre_offset
beam = evaluate_beam(freqs, times, beam_gain_model, geodesic_model, num_l, num_m, dl, dm, center_l, center_m)
assert beam.shape == (num_l, num_m, len(times), len(freqs), 2, 2)
assert np.all(np.isfinite(beam))

assert np.all(np.isfinite(beam))
avg_beam = jnp.mean(beam, axis=(2, 3))

image = np.ones((num_l, num_m, 4))
if len(coherencies) == 1:
assert beam.shape == (num_l, num_m, len(times), len(freqs))
else:
assert beam.shape == (num_l, num_m, len(times), len(freqs), 2, 2)




image = np.ones((num_l, num_m, len(coherencies)))
# image[::10, ::10, 0] = 1.
# image[::10, ::10, 3] = 1.

Expand All @@ -83,8 +90,12 @@ def test_evaluate_beam(tmp_path, coherencies, centre_offset: float):
assert np.all(np.isfinite(pb_cor_image))
import pylab as plt


if len(coherencies) == 4:
avg_beam = avg_beam[..., 0, 0]

plt.imshow(
np.abs(avg_beam[..., 0, 0]).T,
np.abs(avg_beam).T,
origin='lower',
aspect='auto',
extent=(-0.5 * num_l * dl, 0.5 * num_l * dl, -0.5 * num_m * dm, 0.5 * num_m * dm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def apply_delay(freq, delay, w, delay_acf_val, dist10, dist20, g1, g2):
if full_stokes:
visibilities = kron_product(g1, visibilities, g2.conj().T) # [2, 2]
else:
visibilities = g1 * visibilities * g2.cong().T # []
visibilities = g1 * visibilities * g2.conj().T # []

return mp_policy.cast_to_vis(visibilities) # [[2,2]]

Expand Down

0 comments on commit ad19ba0

Please sign in to comment.