Skip to content

Commit

Permalink
* add RFI cal to LWA forward model
Browse files Browse the repository at this point in the history
* performance test FITS predict. It's too slow.
  • Loading branch information
Joshuaalbert committed Sep 10, 2024
1 parent 6b6c182 commit 0d6ba3b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class FullyParameterisedRFIHorizonEmitter(AbstractRFIPriorModel):
height_max: au.Quantity = 120. * au.m
luminosity_min: au.Quantity = 1e-13 * (au.W / au.MHz) # W / MHz
luminosity_max: au.Quantity = 1e-10 * (au.W / au.MHz) # W / MHz
full_stokes: bool = True

def __post_init__(self):
if not self.distance_min.unit.is_equivalent("km"):
Expand Down Expand Up @@ -126,14 +127,24 @@ def prior_model():
regular_grid=True
) # [ E]

luminosity = yield Prior(
tfpd.Uniform(
low=quantity_to_jnp(self.luminosity_min, 'Jy*m^2') * jnp.ones((self.num_emitters, 2, 2)),
high=quantity_to_jnp(self.luminosity_max, 'Jy*m^2') * jnp.ones((self.num_emitters, 2, 2))
),
name='luminosity'
).parametrised()
luminosity = jnp.tile(luminosity[:, None, :, :], (1, len(freqs), 1, 1)) # [num_source, num_chan, 2, 2]
if self.full_stokes:
luminosity = yield Prior(
tfpd.Uniform(
low=quantity_to_jnp(self.luminosity_min, 'Jy*m^2') * jnp.ones((self.num_emitters, 2, 2)),
high=quantity_to_jnp(self.luminosity_max, 'Jy*m^2') * jnp.ones((self.num_emitters, 2, 2))
),
name='luminosity'
).parametrised()
luminosity = jnp.tile(luminosity[:, None, :, :], (1, len(freqs), 1, 1)) # [num_source, num_chan, 2, 2]
else:
luminosity = yield Prior(
tfpd.Uniform(
low=quantity_to_jnp(self.luminosity_min, 'Jy*m^2') * jnp.ones((self.num_emitters)),
high=quantity_to_jnp(self.luminosity_max, 'Jy*m^2') * jnp.ones((self.num_emitters))
),
name='luminosity'
).parametrised()
luminosity = jnp.tile(luminosity[:, None], (1, len(freqs))) # [num_source, num_chan]

geodesics = self.geodesic_model.compute_near_field_geodesics(
times=times,
Expand Down
23 changes: 21 additions & 2 deletions dsa2000_cal/dsa2000_cal/forward_models/lwa_forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
from dsa2000_cal.calibration.probabilistic_models.gain_prior_models import DiagonalUnconstrainedGain, \
ScalarUnconstrainedGain
from dsa2000_cal.calibration.probabilistic_models.gains_per_facet_model import GainsPerFacet
from dsa2000_cal.calibration.probabilistic_models.horizon_rfi_model import HorizonRFIModel
from dsa2000_cal.calibration.probabilistic_models.probabilistic_model import AbstractProbabilisticModel
from dsa2000_cal.calibration.probabilistic_models.rfi_prior_models import FullyParameterisedRFIHorizonEmitter
from dsa2000_cal.forward_models.forward_model import BaseForwardModel
from dsa2000_cal.forward_models.synthetic_sky_model.synthetic_sky_model_producer import SyntheticSkyModelProducer
from dsa2000_cal.forward_models.systematics.dish_effects_simulation import DishEffectsParams
from dsa2000_cal.gain_models.gain_model import GainModel
from dsa2000_cal.measurement_sets.measurement_set import MeasurementSet
from dsa2000_cal.visibility_model.facet_model import FacetModel
from dsa2000_cal.visibility_model.rime_model import RIMEModel
from dsa2000_cal.visibility_model.source_models.rfi.rfi_emitter_source_model import RFIEmitterPredict


@dataclasses.dataclass(eq=False)
Expand Down Expand Up @@ -186,15 +189,31 @@ def _build_calibration_probabilistic_models(

if ms.is_full_stokes():
gain_prior_model = DiagonalUnconstrainedGain()
rfi_prior_model = FullyParameterisedRFIHorizonEmitter(
beam_gain_model=ms.beam_gain_model,
geodesic_model=ms.geodesic_model,
full_stokes=True
)
else:
gain_prior_model = ScalarUnconstrainedGain()
rfi_prior_model = FullyParameterisedRFIHorizonEmitter(
beam_gain_model=ms.beam_gain_model,
geodesic_model=ms.geodesic_model,
full_stokes=False
)

gains_per_facet = GainsPerFacet(
gain_prior_model=gain_prior_model,
rime_model=rime_model
)
# TODO: Construct RFI parameterisation per RFI source.
horizon_rfi = HorizonRFIModel(
rfi_prior_model=rfi_prior_model,
rfi_predict=RFIEmitterPredict(
delay_engine=ms.near_field_delay_engine,
convention=ms.meta.convention
)
)

probabilistic_models = [gains_per_facet]
probabilistic_models = [gains_per_facet, horizon_rfi]

return probabilistic_models
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import itertools
import time

import astropy.coordinates as ac
import astropy.units as au
import jax
import numpy as np
import pytest
from jax import numpy as jnp

from dsa2000_cal.common.types import complex_type
from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.assets.registries import source_model_registry
from dsa2000_cal.common.types import complex_type, mp_policy
from dsa2000_cal.delay_models.far_field import VisibilityCoords
from dsa2000_cal.visibility_model.source_models.celestial.fits_source_model import FITSSourceModel, FITSPredict
from dsa2000_cal.visibility_model.source_models.celestial.gaussian_source_model import GaussianPredict, \
GaussianModelData
from dsa2000_cal.visibility_model.source_models.celestial.point_source_model import PointPredict, PointModelData
Expand Down Expand Up @@ -149,3 +155,52 @@ def test_benchmark_performance_point_sources(di_gains):
visibilities = f(model_data=model_data, visibility_coords=visibility_coords).block_until_ready()
t1 = time.time()
print(f"Time taken for {num_ant} antennas, {num_chan} channels, {num_source} sources: {t1 - t0:.6f} s")


def build_mock_visibility_coord(ant: int, time: int) -> VisibilityCoords:
rows = (ant * (ant - 1) // 2) * time
uvw = 20e3 * jax.random.normal(jax.random.PRNGKey(42), (rows, 3))
uvw = uvw.at[:, 2].mul(1e-3)
time_obs = jnp.zeros((rows,))
antenna_1 = jax.random.randint(jax.random.PRNGKey(42), (rows,), 0, ant)
antenna_2 = jax.random.randint(jax.random.PRNGKey(43), (rows,), 0, ant)
time_idx = jax.random.randint(jax.random.PRNGKey(44), (rows,), 0, time)

visibility_coords = VisibilityCoords(
uvw=mp_policy.cast_to_length(uvw),
time_obs=mp_policy.cast_to_time(time_obs),
antenna_1=mp_policy.cast_to_index(antenna_1),
antenna_2=mp_policy.cast_to_index(antenna_2),
time_idx=mp_policy.cast_to_index(time_idx)
)
return visibility_coords


@pytest.mark.parametrize('source', ['cas_a', 'cyg_a', 'tau_a', 'vir_a'])
@pytest.mark.parametrize('chan', [1, 16])
@pytest.mark.parametrize('ant', [256, 2048])
def test_benchmark_fits_predict(source, chan: int, ant: int):
fill_registries()
wsclean_fits_files = source_model_registry.get_instance(
source_model_registry.get_match(source)).get_wsclean_fits_files()
# -04:00:28.608,40.43.33.595
phase_tracking = ac.ICRS(ra=-4 * au.hour, dec=40 * au.deg)

freqs = au.Quantity(np.linspace(55, 70, chan), 'MHz')

fits_sources = FITSSourceModel.from_wsclean_model(wsclean_fits_files=wsclean_fits_files,
phase_tracking=phase_tracking, freqs=freqs, full_stokes=False)

visibility_coords = build_mock_visibility_coord(ant, 1)
model_data = fits_sources.get_model_data()

def run(model_data, visibility_coords):
faint_predict = FITSPredict(num_threads=8)
return faint_predict.predict(model_data=model_data, visibility_coords=visibility_coords)

run = jax.jit(run).lower(model_data, visibility_coords).compile()

t0 = time.time()
jax.block_until_ready(run(model_data, visibility_coords))
t1 = time.time()
print(f"Time taken for {source}: {t1 - t0:.6f} s")

0 comments on commit 0d6ba3b

Please sign in to comment.