diff --git a/lXtractor/collection/constructor.py b/lXtractor/collection/constructor.py index 478ad3f..543d5f5 100644 --- a/lXtractor/collection/constructor.py +++ b/lXtractor/collection/constructor.py @@ -211,7 +211,10 @@ def _setup_ref_kws( elif isinstance(ref_kw, abc.Mapping): return [ref_kw for _ in range(num_refs)] else: - raise TypeError("Invalid type for a reference arguments. Expected a") + raise TypeError( + f"Invalid type {type(ref_kw)} for a reference arguments. " + f"Expected a mapping or a sequence." + ) def _setup_references(self) -> list[PyHMMer]: try: @@ -316,7 +319,7 @@ def _fetch(self, ids: abc.Iterable[str], source: str) -> t.Any: return res def parse_inputs(self, inputs: abc.Iterable[t.Any]) -> abc.Iterator[_IT]: - yield from chain.from_iterable(map(self._parse_id, inputs)) + yield from chain.from_iterable(map(self._parse_inp, inputs)) def run_batch(self, items: _ITL) -> lxc.ChainList[_CT]: logger.info(f"Received batch of {len(items)} items.") @@ -431,7 +434,7 @@ def resume_with( yield from self._run(None, True) @abstractmethod - def _parse_id(self, x: t.Any) -> abc.Iterator[_IT]: + def _parse_inp(self, x: t.Any) -> abc.Iterator[_IT]: pass @abstractmethod @@ -454,7 +457,7 @@ class SeqCollectionConstructor( def item_list_type(self) -> t.Type[SeqItemList]: return SeqItemList - def _parse_id(self, x: t.Any) -> abc.Iterator[SeqItem]: + def _parse_inp(self, x: t.Any) -> abc.Iterator[SeqItem]: if isinstance(x, SeqItem): yield x elif isinstance(x, str): @@ -481,7 +484,7 @@ class StrCollectionConstructor( def item_list_type(self) -> t.Type[StrItemList]: return StrItemList - def _parse_id(self, x: t.Any) -> abc.Iterator[StrItem]: + def _parse_inp(self, x: t.Any) -> abc.Iterator[StrItem]: match x: case StrItem(): yield x @@ -525,7 +528,7 @@ class MapCollectionConstructor( def item_list_type(self) -> t.Type[MapItemList]: return MapItemList - def _parse_id(self, x: t.Any) -> abc.Iterator[MapItem]: + def _parse_inp(self, x: t.Any) -> abc.Iterator[MapItem]: match x: case MapItem(): yield x diff --git a/lXtractor/collection/support.py b/lXtractor/collection/support.py index 9de71d0..5c37ddc 100644 --- a/lXtractor/collection/support.py +++ b/lXtractor/collection/support.py @@ -13,7 +13,7 @@ import lXtractor.chain as lxc from lXtractor.core.config import Config, DefaultConfig -from lXtractor.core.exceptions import MissingData +from lXtractor.core.exceptions import MissingData, FormatError from lXtractor.ext import AlphaFold, PDB, UniProt, SIFTS _RESOURCES = Path(__file__).parent.parent / "resources" @@ -23,19 +23,6 @@ _CT = t.TypeVar("_CT", lxc.ChainSequence, lxc.ChainStructure, lxc.Chain) -def _parse_str_id(x: str) -> str | tuple[str, tuple[str, ...]]: - if ":" in x: - id_, chains = x.split(":", maxsplit=1) - return id_, tuple(chains.split(",")) - return x - - -def _validate_seq_id(x: t.Any) -> str: - if not isinstance(x, str): - raise TypeError(f"Expected sequence ID to be of string type, got {type(x)}.") - return x - - class ConstructorConfig(Config): def __init__( self, @@ -198,6 +185,10 @@ def from_str(cls, inp: str) -> abc.Iterator[t.Self]: @classmethod def from_tuple(cls, inp: tuple[str, abc.Sequence[str]]) -> abc.Iterator[t.Self]: + if isinstance(inp[1], str): + raise FormatError( + f"Strings are disallowed as a second element in input {inp}" + ) for chain_id in inp[1]: yield cls(inp[0], chain_id) @@ -271,10 +262,6 @@ def as_strings(self): def item_type(self) -> t.Type[_IT]: pass - @classmethod - def from_chains(cls): - return cls() - class SeqItemList(ItemList[SeqItem]): @property @@ -323,13 +310,13 @@ def prep_for_init(self, paths: CollectionPaths): for g, gg in self.iter_groups(): str_items = StrItemList(it.str_item for it in gg) seq_path = paths.sequence_files / f"{g}.fasta" - yield seq_path, list(str_items.prep_for_init()) + yield seq_path, list(str_items.prep_for_init(paths)) def as_strings(self): for g, gg in self.iter_groups(): str_items = StrItemList(it.str_item for it in gg) strs = ";".join(str_items.as_strings()) - yield f"{g}:{strs}" + yield f"{g}=>{strs}" _ITL = t.TypeVar("_ITL", SeqItemList, StrItemList, MapItemList) diff --git a/test/test_collection.py b/test/test_collection.py index 2b2d06b..f7d3edb 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -2,6 +2,7 @@ from pathlib import Path import pytest +from more_itertools import consume import lXtractor.chain as lxc from lXtractor.chain import ChainIO @@ -18,9 +19,12 @@ StrItem, SeqItem, MapItem, + SeqItemList, + StrItemList, + MapItemList, ) from lXtractor.core import Alignment -from lXtractor.core.exceptions import MissingData +from lXtractor.core.exceptions import MissingData, ConfigError, FormatError from lXtractor.ext import PyHMMer from lXtractor.variables import SeqEl from test.common import TestError, DATA, SEQUENCES @@ -326,6 +330,54 @@ def test_item_from_str(item_type, inp_str, exp_items): assert list(item_type.from_str(inp_str)) == exp_items +@pytest.mark.parametrize( + "item_type,inp,exp_items", + [ + (StrItem, ("1ABC", ["A", "B"]), [StrItem("1ABC", "A"), StrItem("1ABC", "B")]), + (MapItem, ("X", "1ABC:A"), [MapItem(SeqItem("X"), StrItem("1ABC", "A"))]), + ( + MapItem, + ("X", ["1ABC:A", "2ABC:A"]), + [ + MapItem(SeqItem("X"), StrItem("1ABC", "A")), + MapItem(SeqItem("X"), StrItem("2ABC", "A")), + ], + ), + ( + MapItem, + ("X", [("2ABC", ["A", "B"])]), + [ + MapItem(SeqItem("X"), StrItem("2ABC", "A")), + MapItem(SeqItem("X"), StrItem("2ABC", "B")), + ], + ), + ], +) +def test_item_from_tuple(item_type, inp, exp_items): + assert list(item_type.from_tuple(inp)) == exp_items + + +@pytest.mark.parametrize( + "itl,expected", + [ + (SeqItemList([SeqItem("S1")]), ["S1"]), + (StrItemList([StrItem("S1", "A"), StrItem("S1", "B")]), ["S1:A,B"]), + ( + MapItemList( + [ + MapItem(SeqItem("X"), StrItem("S1", "A")), + MapItem(SeqItem("X"), StrItem("S1", "B")), + MapItem(SeqItem("Y"), StrItem("S1", "A")), + ] + ), + ["X=>S1:A,B", "Y=>S1:A"], + ), + ], +) +def test_item_list_as_strings(itl, expected): + assert list(itl.as_strings()) == expected + + def test_constructor_config(): cfg = ConstructorConfig(source="SIFTS") assert "references" in cfg.list_fields() @@ -361,6 +413,7 @@ def make_config(base: Path, source, refs, ids, local=False): ids=ids, PDB_kwargs=dict(verbose=True), AF2_kwargs=dict(verbose=True), + write_batches=True, ) if source.lower() in ("af", "af2", "alphafold"): kws["str_fmt"] = "cif" @@ -378,15 +431,20 @@ def make_config(base: Path, source, refs, ids, local=False): ("PDB", StrCollectionConstructor), ("AF2", StrCollectionConstructor), ("SIFTS", MapCollectionConstructor), + ("INVALID", SeqCollectionConstructor), ], ) @pytest.mark.parametrize("references", [()]) def test_constructor_setup(inp, references, tmp_path): source, constr_type = inp config, dirs = make_config(tmp_path, source, references, ()) - constructor = constr_type(config) - assert all(x.exists() for x in dirs) - assert isinstance(constructor.collection, Collection) + if source == "INVALID": + with pytest.raises(ConfigError): + constr_type(config) + else: + constructor = constr_type(config) + assert all(x.exists() for x in dirs) + assert isinstance(constructor.collection, Collection) @pytest.mark.parametrize("source,const_type", [("UniProt", SeqCollectionConstructor)]) @@ -394,6 +452,8 @@ def test_constructor_setup(inp, references, tmp_path): "ref", [ DATA / "Pkinase.hmm", + DATA / "void.hmm", + DATA / "void.hmm.gz", SEQUENCES / "fasta" / "simple.fasta", PyHMMer(DATA / "Pkinase.hmm"), Alignment([("REF_SEQ", "KAL"), ("s2", "KKL")]), @@ -411,6 +471,12 @@ def test_setup_references(source, const_type, ref, tmp_path): if isinstance(ref, tuple) and ref[0] == "INVALID": with pytest.raises(TypeError): const_type(config) + elif isinstance(ref, Path) and ref.suffix not in (".hmm", ".fasta"): + with pytest.raises(NameError): + const_type(config) + elif isinstance(ref, Path) and not ref.exists(): + with pytest.raises(FileNotFoundError): + const_type(config) else: constructor = const_type(config) assert all(isinstance(x, PyHMMer) for x in constructor.references) @@ -448,11 +514,69 @@ def rename(x): assert all(x.name == "!" for x in chains.collapse_children()) +@pytest.mark.parametrize( + "ct,inputs,valid,exp_items", + [ + ( + SeqCollectionConstructor, + [SeqItem("S"), "s"], + True, + [SeqItem("S"), SeqItem("s")], + ), + ( + SeqCollectionConstructor, + [StrItem("S", "A")], + False, + [], + ), + ( + StrCollectionConstructor, + [StrItem("S", "A"), "s:A,B", ("ss", ["A"])], + True, + [ + StrItem("S", "A"), + StrItem("s", "A"), + StrItem("s", "B"), + StrItem("ss", "A"), + ], + ), + (StrCollectionConstructor, ["S"], False, []), + (StrCollectionConstructor, [("S", "A")], False, []), + ( + MapCollectionConstructor, + [MapItem(SeqItem("S"), StrItem("s", "A")), "X=>s:A", ("Y", "s:A")], + True, + [ + MapItem(SeqItem("S"), StrItem("s", "A")), + MapItem(SeqItem("X"), StrItem("s", "A")), + MapItem(SeqItem("Y"), StrItem("s", "A")), + ], + ), + (MapCollectionConstructor, [("S", "A:B", "C:D")], False, []), + ], +) +def test_inp_parsing(ct, inputs, valid, exp_items, tmp_path): + if ct is SeqCollectionConstructor: + source = "uniprot" + elif ct is StrCollectionConstructor: + source = "pdb" + else: + source = "sifts" + config, _ = make_config(tmp_path, source, [], ()) + constructor = ct(config) + if valid: + assert list(constructor.parse_inputs(inputs)) == exp_items + else: + with pytest.raises(FormatError): + list(constructor.parse_inputs(inputs)) + + PKP = DATA / "Pkinase.hmm" TEST_BATCHES = [ (SeqCollectionConstructor, "UniProt", ["P12931", "Q16644"], [PKP]), (StrCollectionConstructor, "PDB", ["2SRC:A", "2OIQ:A"], [PKP]), (StrCollectionConstructor, "AF", ["P12931", "Q16644"], [PKP]), + (MapCollectionConstructor, "SIFTS", ["P12931=>2SRC:A;2OIQ:A,B"], [PKP]), ] @@ -461,29 +585,43 @@ def rename(x): def test_run_batch(ct, source, ids, refs, local, tmp_path): config, _ = make_config(tmp_path, source, refs, (), local) + if ct is MapCollectionConstructor: + config["references_annotate_kw"] = dict(str_map_from="map_canonical") + constructor = ct(config) itl = constructor.item_list_type(constructor.parse_inputs(ids)) res = constructor.run_batch(itl) assert isinstance(res, lxc.ChainList) - assert len(res) == len(itl) - assert len(res.collapse_children()) == len(itl) - assert len(constructor.collection.get_ids()) == len(itl) * 2 + if ct is MapCollectionConstructor: + assert len(res) == len(ids) + assert len(res.collapse_children()) == len(ids) + # 1 seq and 1 seq child, 3 str and 3 str children + assert len(constructor.collection.get_ids()) == 8 + else: + assert len(res) == len(itl) + assert len(res.collapse_children()) == len(itl) + assert len(constructor.collection.get_ids()) == len(itl) * 2 + assert len(constructor.history) == 0 -def test_run_empty_batch(tmp_path): +@pytest.mark.parametrize( + "ct", [SeqCollectionConstructor, StrCollectionConstructor, MapCollectionConstructor] +) +def test_run_empty_batch(ct, tmp_path): config, _ = make_config(tmp_path, "", [PKP], (), True) - constructor = StrCollectionConstructor(config) + constructor = ct(config) res = constructor.run_batch(constructor.item_list_type()) assert isinstance(res, lxc.ChainList) assert len(res) == 0 -@pytest.mark.parametrize("ct,source,ids,refs", TEST_BATCHES) +@pytest.mark.parametrize("ct,source,ids,refs", TEST_BATCHES[:-1]) @pytest.mark.parametrize("local", [True]) def test_run(ct, source, ids, refs, local, tmp_path): config, dirs = make_config(tmp_path, source, refs, (), local) config["batch_size"] = 1 + config["keep_chains"] = True constructor = ct(config) for batch in constructor.run(constructor.parse_inputs(ids)): @@ -497,10 +635,12 @@ def test_run(ct, source, ids, refs, local, tmp_path): assert len(hist.items_missed()) == 0 assert len(hist.items_failed()) == 0 + assert isinstance(hist.join_chains(), lxc.ChainList) + assert len(constructor.collection.get_ids()) == len(ids) * 2 -@pytest.mark.parametrize("ct,source,ids,refs", TEST_BATCHES) +@pytest.mark.parametrize("ct,source,ids,refs", TEST_BATCHES[:-1]) @pytest.mark.parametrize("local", [True]) def test_fail_resume(ct, source, ids, refs, local, tmp_path): def bad_fn(_): @@ -529,8 +669,18 @@ def bad_fn(_): assert len(constructor.history) == len(items) assert len(constructor.collection.get_ids()) == len(items) * 2 - # reinitializing and rerunning doesn't cause an exception + # run to completion without stopping on failed batches + config["parent_callback"] = bad_fn constructor = ct(config) - batches = list(constructor.run(items)) - assert len(batches) == len(items) - assert len(constructor.history) == len(items) + consume(constructor.run(items, stop_on_batch_failure=False)) + assert len(constructor.history) == 2 + assert len(list(constructor.history.items_failed())) == 2 + + # However, the collection was previously constructed and remains unchanged + assert len(constructor.collection.get_ids()) == len(items) * 2 + + # # reinitializing and rerunning doesn't cause an exception + # constructor = ct(config) + # batches = list(constructor.run(items)) + # assert len(batches) == len(items) + # assert len(constructor.history) == len(items)