Skip to content

Commit

Permalink
more tests and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
edikedik committed Feb 27, 2024
1 parent e3551b0 commit 3eb3fae
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 41 deletions.
15 changes: 9 additions & 6 deletions lXtractor/collection/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ def _setup_ref_kws(
elif isinstance(ref_kw, abc.Mapping):
return [ref_kw for _ in range(num_refs)]
else:
raise TypeError("Invalid type for a reference arguments. Expected a")
raise TypeError(
f"Invalid type {type(ref_kw)} for a reference arguments. "
f"Expected a mapping or a sequence."
)

def _setup_references(self) -> list[PyHMMer]:
try:
Expand Down Expand Up @@ -316,7 +319,7 @@ def _fetch(self, ids: abc.Iterable[str], source: str) -> t.Any:
return res

def parse_inputs(self, inputs: abc.Iterable[t.Any]) -> abc.Iterator[_IT]:
yield from chain.from_iterable(map(self._parse_id, inputs))
yield from chain.from_iterable(map(self._parse_inp, inputs))

def run_batch(self, items: _ITL) -> lxc.ChainList[_CT]:
logger.info(f"Received batch of {len(items)} items.")
Expand Down Expand Up @@ -431,7 +434,7 @@ def resume_with(
yield from self._run(None, True)

@abstractmethod
def _parse_id(self, x: t.Any) -> abc.Iterator[_IT]:
def _parse_inp(self, x: t.Any) -> abc.Iterator[_IT]:
pass

@abstractmethod
Expand All @@ -454,7 +457,7 @@ class SeqCollectionConstructor(
def item_list_type(self) -> t.Type[SeqItemList]:
return SeqItemList

def _parse_id(self, x: t.Any) -> abc.Iterator[SeqItem]:
def _parse_inp(self, x: t.Any) -> abc.Iterator[SeqItem]:
if isinstance(x, SeqItem):
yield x
elif isinstance(x, str):
Expand All @@ -481,7 +484,7 @@ class StrCollectionConstructor(
def item_list_type(self) -> t.Type[StrItemList]:
return StrItemList

def _parse_id(self, x: t.Any) -> abc.Iterator[StrItem]:
def _parse_inp(self, x: t.Any) -> abc.Iterator[StrItem]:
match x:
case StrItem():
yield x
Expand Down Expand Up @@ -525,7 +528,7 @@ class MapCollectionConstructor(
def item_list_type(self) -> t.Type[MapItemList]:
return MapItemList

def _parse_id(self, x: t.Any) -> abc.Iterator[MapItem]:
def _parse_inp(self, x: t.Any) -> abc.Iterator[MapItem]:
match x:
case MapItem():
yield x
Expand Down
27 changes: 7 additions & 20 deletions lXtractor/collection/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import lXtractor.chain as lxc
from lXtractor.core.config import Config, DefaultConfig
from lXtractor.core.exceptions import MissingData
from lXtractor.core.exceptions import MissingData, FormatError
from lXtractor.ext import AlphaFold, PDB, UniProt, SIFTS

_RESOURCES = Path(__file__).parent.parent / "resources"
Expand All @@ -23,19 +23,6 @@
_CT = t.TypeVar("_CT", lxc.ChainSequence, lxc.ChainStructure, lxc.Chain)


def _parse_str_id(x: str) -> str | tuple[str, tuple[str, ...]]:
if ":" in x:
id_, chains = x.split(":", maxsplit=1)
return id_, tuple(chains.split(","))
return x


def _validate_seq_id(x: t.Any) -> str:
if not isinstance(x, str):
raise TypeError(f"Expected sequence ID to be of string type, got {type(x)}.")
return x


class ConstructorConfig(Config):
def __init__(
self,
Expand Down Expand Up @@ -198,6 +185,10 @@ def from_str(cls, inp: str) -> abc.Iterator[t.Self]:

@classmethod
def from_tuple(cls, inp: tuple[str, abc.Sequence[str]]) -> abc.Iterator[t.Self]:
if isinstance(inp[1], str):
raise FormatError(
f"Strings are disallowed as a second element in input {inp}"
)
for chain_id in inp[1]:
yield cls(inp[0], chain_id)

Expand Down Expand Up @@ -271,10 +262,6 @@ def as_strings(self):
def item_type(self) -> t.Type[_IT]:
pass

@classmethod
def from_chains(cls):
return cls()


class SeqItemList(ItemList[SeqItem]):
@property
Expand Down Expand Up @@ -323,13 +310,13 @@ def prep_for_init(self, paths: CollectionPaths):
for g, gg in self.iter_groups():
str_items = StrItemList(it.str_item for it in gg)
seq_path = paths.sequence_files / f"{g}.fasta"
yield seq_path, list(str_items.prep_for_init())
yield seq_path, list(str_items.prep_for_init(paths))

def as_strings(self):
for g, gg in self.iter_groups():
str_items = StrItemList(it.str_item for it in gg)
strs = ";".join(str_items.as_strings())
yield f"{g}:{strs}"
yield f"{g}=>{strs}"


_ITL = t.TypeVar("_ITL", SeqItemList, StrItemList, MapItemList)
Expand Down
Loading

0 comments on commit 3eb3fae

Please sign in to comment.