diff --git a/src/pygama/hit/build_hit.py b/src/pygama/hit/build_hit.py index e531fa872..5cc34202f 100644 --- a/src/pygama/hit/build_hit.py +++ b/src/pygama/hit/build_hit.py @@ -9,7 +9,7 @@ from collections import OrderedDict import numpy as np -from lgdo import LH5Iterator, LH5Store, ls +from lgdo import Array, LH5Iterator, LH5Store, ls log = logging.getLogger(__name__) @@ -131,6 +131,28 @@ def build_hit( outtbl_obj = tbl_obj.eval(cfg["operations"]) + # make high level flags + if "aggregations" in cfg: + for high_lvl_flag, flags in cfg["aggregations"].items(): + flags_list = list(flags.values()) + n_flags = len(flags_list) + if n_flags <= 8: + flag_dtype = np.uint8 + elif n_flags <= 16: + flag_dtype = np.uint16 + elif n_flags <= 32: + flag_dtype = np.uint32 + else: + flag_dtype = np.uint64 + + df_flags = outtbl_obj.get_dataframe(flags_list) + flag_values = df_flags.values.astype(flag_dtype) + + multiplier = 2 ** np.arange(n_flags, dtype=flag_values.dtype) + flag_out = np.dot(flag_values, multiplier) + + outtbl_obj.add_field(high_lvl_flag, Array(flag_out)) + # remove or add columns according to "outputs" in the configuration # dictionary if "outputs" in cfg: diff --git a/tests/hit/configs/aggregations-hit-config.json b/tests/hit/configs/aggregations-hit-config.json new file mode 100644 index 000000000..57237ce80 --- /dev/null +++ b/tests/hit/configs/aggregations-hit-config.json @@ -0,0 +1,28 @@ +{ + "outputs": ["is_valid_rt", "is_valid_t0", "is_valid_tmax", "aggr1", "aggr2"], + "operations": { + "is_valid_rt": { + "expression": "((tp_90-tp_10)>96) & ((tp_50-tp_10)>=16)", + "parameters": {} + }, + "is_valid_t0": { + "expression": "(tp_0_est>47000) & (tp_0_est<55000)", + "parameters": {} + }, + "is_valid_tmax": { + "expression": "(tp_max>47000) & (tp_max<120000)", + "parameters": {} + } + }, + "aggregations": { + "aggr1": { + "bit0": "is_valid_rt", + "bit1": "is_valid_t0", + "bit2": "is_valid_tmax" + }, + "aggr2": { + "bit0": "is_valid_t0", + "bit1": "is_valid_tmax" + } + } +} diff --git a/tests/hit/test_build_hit.py b/tests/hit/test_build_hit.py index e497a1742..924928da6 100644 --- a/tests/hit/test_build_hit.py +++ b/tests/hit/test_build_hit.py @@ -98,6 +98,53 @@ def test_outputs_specification(dsp_test_file, tmptestdir): assert list(obj.keys()) == ["calE", "AoE", "A_max"] +def test_aggregation_outputs(dsp_test_file, tmptestdir): + outfile = f"{tmptestdir}/LDQTA_r117_20200110T105115Z_cal_geds_hit.lh5" + + build_hit( + dsp_test_file, + outfile=outfile, + hit_config=f"{config_dir}/aggregations-hit-config.json", + wo_mode="overwrite", + ) + + sto = LH5Store() + obj, _ = sto.read_object("/geds/hit", outfile) + assert list(obj.keys()) == [ + "is_valid_rt", + "is_valid_t0", + "is_valid_tmax", + "aggr1", + "aggr2", + ] + + df = store.load_dfs( + outfile, + ["is_valid_rt", "is_valid_t0", "is_valid_tmax", "aggr1", "aggr2"], + "geds/hit/", + ) + + # aggr1 consists of 3 bits --> max number can be 7, aggr2 consists of 2 bits so max number can be 3 + assert not (df["aggr1"] > 7).any() + assert not (df["aggr2"] > 3).any() + + def get_bit(x, n): + """bit numbering from right to left, starting with bit 0""" + return x & (1 << n) != 0 + + df["bit0_check"] = df.apply(lambda row: get_bit(row["aggr1"], 0), axis=1) + are_identical = df["bit0_check"].equals(df.is_valid_rt) + assert are_identical + + df["bit1_check"] = df.apply(lambda row: get_bit(row["aggr1"], 1), axis=1) + are_identical = df["bit1_check"].equals(df.is_valid_t0) + assert are_identical + + df["bit2_check"] = df.apply(lambda row: get_bit(row["aggr1"], 2), axis=1) + are_identical = df["bit2_check"].equals(df.is_valid_tmax) + assert are_identical + + def test_build_hit_spms_basic(dsp_test_file_spm, tmptestdir): out_file = f"{tmptestdir}/L200-comm-20211130-phy-spms_hit.lh5" build_hit(