Skip to content

Commit

Permalink
unit test for obspriors basis
Browse files Browse the repository at this point in the history
  • Loading branch information
sblunt committed Feb 29, 2024
1 parent 5e23a70 commit 77c8b5e
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions tests/test_basis_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,33 @@ def test_period_basis():
assert np.allclose(original, sample_copy)


def test_obspriors_basis():
"""Check that converting from the ObsPriors basis and back results in the
same starting condition."""

filename = "{}/GJ504.csv".format(DATADIR)
data_table = read_input.read_file(filename)
my_system = system.System(1, data_table, 1.75, 51.44, fitting_basis="ObsPriors")

num_samples = 1
samples = np.empty([len(my_system.sys_priors), num_samples])
for i in range(len(my_system.sys_priors)):
if hasattr(my_system.sys_priors[i], "draw_samples"):
samples[i, :] = my_system.sys_priors[i].draw_samples(num_samples)
else:
samples[i, :] = my_system.sys_priors[i] * np.ones(num_samples)
sample_copy = samples.copy()

# MCMC Format
test = samples[:, 0].copy()
initial_tp = test[-3]

conversion = my_system.basis.to_standard_basis(test)
original = my_system.basis.to_obspriors_basis(conversion, after_date=initial_tp - 1)

assert np.allclose(original, sample_copy[:, 0])


def test_semi_amp_basis():
"""
For both MCMC and OFTI param formats, make the conversion to the standard
Expand Down Expand Up @@ -240,6 +267,7 @@ def test_xyz_basis():


if __name__ == "__main__":
test_period_basis()
test_semi_amp_basis()
test_xyz_basis()
# test_period_basis()
# test_semi_amp_basis()
# test_xyz_basis()
test_obspriors_basis()

0 comments on commit 77c8b5e

Please sign in to comment.