Skip to content

Commit

Permalink
Add support for bit compounds definitions in build_hit() config file (
Browse files Browse the repository at this point in the history
#531)

see #531
---------

Co-authored-by: Rosanna Deckert <[email protected]>
  • Loading branch information
rosannadeckert and Rosanna Deckert authored Dec 5, 2023
1 parent ad1be19 commit daa915f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/pygama/hit/build_hit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import OrderedDict

import numpy as np
from lgdo import LH5Iterator, LH5Store, ls
from lgdo import Array, LH5Iterator, LH5Store, ls

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,6 +131,28 @@ def build_hit(

outtbl_obj = tbl_obj.eval(cfg["operations"])

# make high level flags
if "aggregations" in cfg:
for high_lvl_flag, flags in cfg["aggregations"].items():
flags_list = list(flags.values())
n_flags = len(flags_list)
if n_flags <= 8:
flag_dtype = np.uint8
elif n_flags <= 16:
flag_dtype = np.uint16
elif n_flags <= 32:
flag_dtype = np.uint32
else:
flag_dtype = np.uint64

df_flags = outtbl_obj.get_dataframe(flags_list)
flag_values = df_flags.values.astype(flag_dtype)

multiplier = 2 ** np.arange(n_flags, dtype=flag_values.dtype)
flag_out = np.dot(flag_values, multiplier)

outtbl_obj.add_field(high_lvl_flag, Array(flag_out))

# remove or add columns according to "outputs" in the configuration
# dictionary
if "outputs" in cfg:
Expand Down
28 changes: 28 additions & 0 deletions tests/hit/configs/aggregations-hit-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"outputs": ["is_valid_rt", "is_valid_t0", "is_valid_tmax", "aggr1", "aggr2"],
"operations": {
"is_valid_rt": {
"expression": "((tp_90-tp_10)>96) & ((tp_50-tp_10)>=16)",
"parameters": {}
},
"is_valid_t0": {
"expression": "(tp_0_est>47000) & (tp_0_est<55000)",
"parameters": {}
},
"is_valid_tmax": {
"expression": "(tp_max>47000) & (tp_max<120000)",
"parameters": {}
}
},
"aggregations": {
"aggr1": {
"bit0": "is_valid_rt",
"bit1": "is_valid_t0",
"bit2": "is_valid_tmax"
},
"aggr2": {
"bit0": "is_valid_t0",
"bit1": "is_valid_tmax"
}
}
}
47 changes: 47 additions & 0 deletions tests/hit/test_build_hit.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,53 @@ def test_outputs_specification(dsp_test_file, tmptestdir):
assert list(obj.keys()) == ["calE", "AoE", "A_max"]


def test_aggregation_outputs(dsp_test_file, tmptestdir):
outfile = f"{tmptestdir}/LDQTA_r117_20200110T105115Z_cal_geds_hit.lh5"

build_hit(
dsp_test_file,
outfile=outfile,
hit_config=f"{config_dir}/aggregations-hit-config.json",
wo_mode="overwrite",
)

sto = LH5Store()
obj, _ = sto.read_object("/geds/hit", outfile)
assert list(obj.keys()) == [
"is_valid_rt",
"is_valid_t0",
"is_valid_tmax",
"aggr1",
"aggr2",
]

df = store.load_dfs(
outfile,
["is_valid_rt", "is_valid_t0", "is_valid_tmax", "aggr1", "aggr2"],
"geds/hit/",
)

# aggr1 consists of 3 bits --> max number can be 7, aggr2 consists of 2 bits so max number can be 3
assert not (df["aggr1"] > 7).any()
assert not (df["aggr2"] > 3).any()

def get_bit(x, n):
"""bit numbering from right to left, starting with bit 0"""
return x & (1 << n) != 0

df["bit0_check"] = df.apply(lambda row: get_bit(row["aggr1"], 0), axis=1)
are_identical = df["bit0_check"].equals(df.is_valid_rt)
assert are_identical

df["bit1_check"] = df.apply(lambda row: get_bit(row["aggr1"], 1), axis=1)
are_identical = df["bit1_check"].equals(df.is_valid_t0)
assert are_identical

df["bit2_check"] = df.apply(lambda row: get_bit(row["aggr1"], 2), axis=1)
are_identical = df["bit2_check"].equals(df.is_valid_tmax)
assert are_identical


def test_build_hit_spms_basic(dsp_test_file_spm, tmptestdir):
out_file = f"{tmptestdir}/L200-comm-20211130-phy-spms_hit.lh5"
build_hit(
Expand Down

0 comments on commit daa915f

Please sign in to comment.