From 1dedb1f86e1ae58829bc8ab74fe17810c84e37d3 Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Thu, 12 Dec 2024 11:57:32 +0100 Subject: [PATCH] Added support for multimers --- .../folding_backend/alphafold3_backend.py | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/alphapulldown/folding_backend/alphafold3_backend.py b/alphapulldown/folding_backend/alphafold3_backend.py index c5e3ff60..30678a6b 100644 --- a/alphapulldown/folding_backend/alphafold3_backend.py +++ b/alphapulldown/folding_backend/alphafold3_backend.py @@ -188,22 +188,24 @@ def _convert_to_fold_input( object_to_model: Union[MonomericObject, ChoppedObject, MultimericObject], random_seed: int, ) -> folding_input.Input: - """Convert a given object to AlphaFold3 fold input.""" - def msa_array_to_a3m(msa_array): - msa_sequences = [] - for i, msa_seq in enumerate(msa_array): - seq_str = ''.join([residue_constants.ID_TO_HHBLITS_AA.get(int(aa), 'X') for aa in msa_seq]) - msa_sequences.append(f'>sequence_{i}\n{seq_str}') - return '\n'.join(msa_sequences) + """Convert a given object to AlphaFold3 fold input. + + This function now supports MonomericObject, ChoppedObject, and MultimericObject. + For a MultimericObject, each interactor is processed as if it were a monomeric object. + """ def chain_id_generator(): chain_letters = string.ascii_uppercase + # Single-letter IDs for c in chain_letters: yield c + # Two-letter IDs for first_letter in chain_letters: for second_letter in chain_letters: yield first_letter + second_letter + chain_id_gen = chain_id_generator() + def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100-01-01') -> str: pdb_data = mmcif_string.splitlines() header_lines = [ @@ -215,15 +217,26 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100 pdb_data.insert(3, release_date_line) return "\n".join(pdb_data) - chain_id_gen = chain_id_generator() + def msa_array_to_a3m(msa_array): + """Converts MSA numpy array to A3M formatted string.""" + msa_sequences = [] + for i, msa_seq in enumerate(msa_array): + seq_str = ''.join([residue_constants.ID_TO_HHBLITS_AA.get(int(aa), 'X') for aa in msa_seq]) + msa_sequences.append(f'>sequence_{i}\n{seq_str}') + return '\n'.join(msa_sequences) - if isinstance(object_to_model, (MonomericObject, ChoppedObject)): - chain_id = next(chain_id_gen) - sequence = object_to_model.sequence - msa_array = object_to_model.feature_dict.get('msa') + def _monomeric_to_chain( + mono_obj: Union[MonomericObject, ChoppedObject], + chain_id: str + ) -> folding_input.ProteinChain: + """Converts a single MonomericObject or ChoppedObject into a ProteinChain.""" + sequence = mono_obj.sequence + feature_dict = mono_obj.feature_dict + + # Convert MSA arrays to A3M. + msa_array = feature_dict.get('msa') unpaired_msa = msa_array_to_a3m(msa_array) if msa_array is not None else "" - paired_msa = "" - feature_dict = object_to_model.feature_dict + paired_msa = "" # For this simplified logic, no paired MSA is handled here. # Process templates if present templates = [] @@ -270,8 +283,6 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100 query_to_template_map=query_to_template_map, ) ) - else: - templates = [] chain = folding_input.ProteinChain( id=chain_id, @@ -281,8 +292,20 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100 paired_msa=paired_msa, templates=templates, ) - chains = [chain] + return chain + # Main logic depending on object type + if isinstance(object_to_model, (MonomericObject, ChoppedObject)): + # Single chain + chain_id = next(chain_id_gen) + chains = [_monomeric_to_chain(object_to_model, chain_id)] + elif isinstance(object_to_model, MultimericObject): + # Multiple chains - each interactor is a MonomericObject + chains = [] + for interactor in object_to_model.interactors: + chain_id = next(chain_id_gen) + chain = _monomeric_to_chain(interactor, chain_id) + chains.append(chain) else: logging.error("Unsupported object type for folding input conversion.") raise TypeError("Unsupported object type for folding input conversion.") @@ -293,6 +316,7 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100 chains=chains, ) + def write_outputs( all_inference_results: Sequence[ResultsForSeed], output_dir: os.PathLike[str] | str,