Skip to content

Commit

Permalink
* fix common unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 3, 2024
1 parent 10dbcd5 commit 37aec6a
Show file tree
Hide file tree
Showing 13 changed files with 645 additions and 73 deletions.
16 changes: 4 additions & 12 deletions dsa2000_cal/dsa2000_cal/common/bbs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from typing import Tuple, Dict, Optional, List

import numpy as np
from africanus.model.coherency import convert
from astropy import coordinates as ac
from astropy import units as au
from h5parm.utils import parse_coordinates_bbs
from pydantic import Field

from dsa2000_cal.common.serialise_utils import SerialisableBaseModel
from dsa2000_cal.common.wsclean_util import parse_coordinates_bbs


class SourceModel(SerialisableBaseModel):
Expand All @@ -18,9 +17,8 @@ class SourceModel(SerialisableBaseModel):
lm: np.ndarray = Field(
description="Source direction cosines of shape [source, 2]",
)
corrs: List[List[str]] = Field(
corrs: List[str] = Field(
description="Correlations in the source model",
default=[['XX', 'XY'], ['YX', 'YY']],
)
freqs: np.ndarray = Field(
description="Frequencies of shape [chan]",
Expand Down Expand Up @@ -188,20 +186,14 @@ def _get_stokes_param(param: str):
axis=1
) # [source, corr]

output_corrs = [['XX', 'XY'], ['YX', 'YY']]
image_corr = convert(
stokes_image,
['I', 'Q', 'U', 'V'],
output_corrs
) # [source, 2, 2]
image_corr = np.tile(image_corr[:, None, :, :], [1, len(self.channels), 1, 1]) # [source, chan, 2, 2]
image_corr = np.tile(stokes_image[:, None, :], [1, len(self.channels), 1]) # [source, chan, 4]
if 'ReferenceFrequency' in data_dict and 'SpectralIndex' in data_dict:
## TODO: Add spectral model if necessary
pass

return SourceModel(
image=image_corr,
lm=direction_cosines,
corrs=output_corrs,
corrs=['I', 'Q', 'U', 'V'],
freqs=self.channels
)
3 changes: 2 additions & 1 deletion dsa2000_cal/dsa2000_cal/common/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ class InterpolatedArray:

def __post_init__(self):

print(self.x)
if len(np.shape(self.x)) != 1:
raise ValueError(f"Times must be 1D, got {np.shape(self.x)}.")
raise ValueError(f"x must be 1D, got {np.shape(self.x)}.")

def _assert_shape(x):
if np.shape(x)[self.axis] != np.size(self.x):
Expand Down
2 changes: 1 addition & 1 deletion dsa2000_cal/dsa2000_cal/common/serialise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def parse_obj(cls: Type[C], obj: Dict[str, Any]) -> C:

# Deserialise InterpolatedArray
elif field.type_ is InterpolatedArray and isinstance(obj.get(name), dict) and obj[name].get(
"type") == 'dsa2000_cal.uvw.uvw_utils.InterpolatedArray':
"type") == 'dsa2000_cal.common.interp_utils.InterpolatedArray':
obj[name] = deserialise_interpolated_array(obj[name])
continue

Expand Down
38 changes: 8 additions & 30 deletions dsa2000_cal/dsa2000_cal/common/tests/test_bbs_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import os

import astropy.units as au
from astropy.coordinates import SkyCoord
from h5parm.utils import parse_coordinates_bbs

from dsa2000_cal.common.bbs_utils import BBSSkyModel


def test_bbs_sky_model_I_only():
def test_bbs_sky_model_I_only(tmp_path):
sky_model_bbs = "# (Name, Type, Ra, Dec, I) = format\n" \
"A, POINT, 00:00:00.123456, +37.07.47.12345, 1.0\n" \
"B, POINT, 00:00:00.123456, +37.37.47.12345, 1.0\n" \
"C, POINT, 00:00:00.123456, +38.07.47.12345, 1.0\n"
sky_model_file = 'test_sky_model.txt'
sky_model_file = tmp_path / 'test_sky_model.txt'
with open(sky_model_file, 'w') as f:
f.write(sky_model_bbs)
pointing_centre = parse_coordinates_bbs("00:00:00.0", "+37.07.47.0")
Expand All @@ -22,19 +19,18 @@ def test_bbs_sky_model_I_only():
num_channels=5
)
source_model = bbs_sky_model.get_source()
assert source_model.corrs == [['XX', 'XY'], ['YX', 'YY']]
assert source_model.image.shape == (3, 5, 2, 2)
assert source_model.corrs == ['I', 'Q', 'U', 'V']
assert source_model.image.shape == (3, 5, 4)
assert source_model.lm.shape == (3, 2)
assert source_model.freqs.shape == (5,)
os.remove(sky_model_file)


def test_bbs_sky_model_all_only():
def test_bbs_sky_model_all_only(tmp_path):
sky_model_bbs = "# (Name, Type, Ra, Dec, I=0, U=0, V=0) = format\n" \
"A, POINT, 00:00:00.123456, +37.07.47.12345, 1.0, , , \n" \
"B, POINT, 00:00:00.123456, +37.37.47.12345, 1.0, , , \n" \
"C, POINT, 00:00:10.123456, +37.37.47.12345, 1.0, , , "
sky_model_file = 'test_sky_model.txt'
sky_model_file = tmp_path / 'test_sky_model.txt'
with open(sky_model_file, 'w') as f:
f.write(sky_model_bbs)
pointing_centre = parse_coordinates_bbs("00:00:00.0", "+37.07.47.0")
Expand All @@ -44,25 +40,7 @@ def test_bbs_sky_model_all_only():
num_channels=5
)
source_model = bbs_sky_model.get_source()
assert source_model.corrs == [['XX', 'XY'], ['YX', 'YY']]
assert source_model.image.shape == (3, 5, 2, 2)
assert source_model.corrs == ['I', 'Q', 'U', 'V']
assert source_model.image.shape == (3, 5, 4)
assert source_model.lm.shape == (3, 2)
print(source_model)
os.remove(sky_model_file)


def test_file_creation():
filename = "test_sky_model.txt"
create_sky_model(filename, 5, 1.0, SkyCoord(ra=10 * au.degree, dec=10 * au.degree, frame='icrs'))
assert os.path.exists(filename)
os.remove(filename)


def test_correct_number_of_sources():
filename = "test_sky_model.txt"
create_sky_model(filename, 5, 1.0, SkyCoord(ra=10 * au.degree, dec=10 * au.degree, frame='icrs'))
with open(filename, 'r') as f:
lines = f.readlines()
# Subtracting 1 for the header
assert len(lines) - 1 == 5
os.remove(filename)
4 changes: 2 additions & 2 deletions dsa2000_cal/dsa2000_cal/common/tests/test_corr_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

def test_flatten_coherencies():
coherencies = jnp.asarray([[1, 2], [3, 4]])
assert jnp.alltrue(flatten_coherencies(coherencies) == jnp.asarray([1, 2, 3, 4]))
assert jnp.all(flatten_coherencies(coherencies) == jnp.asarray([1, 2, 3, 4]))


def test_unflatten_coherencies():
coherencies = jnp.asarray([1, 2, 3, 4])
assert jnp.alltrue(unflatten_coherencies(coherencies) == jnp.asarray([[1, 2], [3, 4]]))
assert jnp.all(unflatten_coherencies(coherencies) == jnp.asarray([[1, 2], [3, 4]]))


def test_linear_to_linear():
Expand Down
8 changes: 4 additions & 4 deletions dsa2000_cal/dsa2000_cal/common/tests/test_fits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_image_model():
x = ImageModel(
phase_tracking=ac.ICRS(0 * au.deg, 0 * au.deg),
obs_time=obs_time,
dl=au.Quantity(-1, au.dimensionless_unscaled),
dl=au.Quantity(1, au.dimensionless_unscaled),
dm=au.Quantity(1, au.dimensionless_unscaled),
freqs=au.Quantity([100, 200, 300], au.MHz),
image=au.Quantity(np.ones((10, 10, 3, 4)), au.Jy),
Expand All @@ -203,12 +203,12 @@ def test_image_model():
bandwidth=au.Quantity(300, au.MHz)
)
_ = ImageModel.parse_raw(x.json())
# dl positive
# dl neg
with pytest.raises(ValueError):
_ = ImageModel(
phase_tracking=ac.ICRS(0 * au.deg, 0 * au.deg),
obs_time=obs_time,
dl=au.Quantity(1, au.dimensionless_unscaled),
dl=au.Quantity(-1, au.dimensionless_unscaled),
dm=au.Quantity(1, au.dimensionless_unscaled),
freqs=au.Quantity([100, 200, 300], au.MHz),
image=au.Quantity(np.ones((10, 10, 3, 4)), au.Jy),
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_save_image_to_fits(tmp_path):
image_model = ImageModel(
phase_tracking=ac.ICRS(0 * au.deg, 0 * au.deg),
obs_time=obs_time,
dl=au.Quantity(-0.01, au.dimensionless_unscaled),
dl=au.Quantity(0.01, au.dimensionless_unscaled),
dm=au.Quantity(0.01, au.dimensionless_unscaled),
freqs=au.Quantity([100, 200, 300], au.MHz),
image=au.Quantity(np.random.normal(size=(10, 10, 3, 4)), au.Jy),
Expand Down
5 changes: 3 additions & 2 deletions dsa2000_cal/dsa2000_cal/common/tests/test_fourier_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jax import numpy as jnp

from dsa2000_cal.common.fourier_utils import ApertureTransform, find_optimal_fft_size
import numpy as np


@pytest.mark.parametrize('convention', ['physical', 'casa'])
Expand Down Expand Up @@ -30,7 +31,7 @@ def test_fourier_conventions(convention):
plt.show()

# This passes for both conventions
jnp.testing.assert_allclose(f_aperture, rec_f_aperture, atol=1e-6)
np.testing.assert_allclose(f_aperture, rec_f_aperture, atol=1e-4)

# If we run with 'casa' convention, the plots all have mode in centre

Expand All @@ -51,7 +52,7 @@ def test_fourier_conventions(convention):
plt.show()

# This passes for both conventions
jnp.testing.assert_allclose(f_image, rec_f_image, atol=1e-6)
np.testing.assert_allclose(f_image, rec_f_image, atol=1e-5)


def test_find_next_magic_size():
Expand Down
2 changes: 1 addition & 1 deletion dsa2000_cal/dsa2000_cal/common/tests/test_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_multilinear_interp_2d():
[3., 3.3333333, 3.6666666, 4.]
]
)
np.testing.assert_allclose(multilinear_interp_2d(x, y, xp, yp, z), expected)
np.testing.assert_allclose(multilinear_interp_2d(x, y, xp, yp, z), expected, atol=1e-6)

# within_bounds_2d
xp = jnp.linspace(0, 10, 11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,36 @@
from dsa2000_cal.common.interp_utils import InterpolatedArray


class TestModelInt(SerialisableBaseModel):
class MockModelInt(SerialisableBaseModel):
value: int


def test_serialise_deserialise_model():
model = TestModelInt(value=10)
model = MockModelInt(value=10)
serialized_data = pickle.dumps(model)
deserialized_model = pickle.loads(serialized_data)

assert isinstance(deserialized_model, TestModelInt)
assert isinstance(deserialized_model, MockModelInt)
assert deserialized_model.value == model.value


def test_config_values():
assert TestModelInt.Config.validate_assignment is True
assert TestModelInt.Config.arbitrary_types_allowed is True
assert TestModelInt.Config.json_loads == ujson.loads
assert MockModelInt.Config.validate_assignment is True
assert MockModelInt.Config.arbitrary_types_allowed is True
assert MockModelInt.Config.json_loads == ujson.loads


class TestModelNp(SerialisableBaseModel):
class MockModelNp(SerialisableBaseModel):
array: np.ndarray


def test_numpy_array_json_serialization():
model = TestModelNp(array=np.array([1, 2, 3]))
model = MockModelNp(array=np.array([1, 2, 3]))
serialized_data = model.json()

# Deserialize from the serialized data
# deserialized_model = TestModelNp.model_validate_json(serialized_data)
deserialized_model = TestModelNp.parse_raw(serialized_data)
deserialized_model = MockModelNp.parse_raw(serialized_data)

# Assert that the reconstructed numpy array is correct
np.testing.assert_array_equal(deserialized_model.array, model.array)
Expand Down
6 changes: 3 additions & 3 deletions dsa2000_cal/dsa2000_cal/common/tests/test_vec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
def test_vec():
a = jnp.asarray([[1, 2],
[3, 4]])
assert jnp.alltrue(vec(a) == jnp.asarray([1, 3, 2, 4]))
assert jnp.all(vec(a) == jnp.asarray([1, 3, 2, 4]))

assert jnp.alltrue(unvec(vec(a), (2, 2)) == a)
assert jnp.alltrue(unvec(vec(a)) == a)
assert jnp.all(unvec(vec(a), (2, 2)) == a)
assert jnp.all(unvec(vec(a)) == a)


def test_kron_product():
Expand Down
Loading

0 comments on commit 37aec6a

Please sign in to comment.