Skip to content

Commit

Permalink
Added support for multimers
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaMolod committed Dec 12, 2024
1 parent cbb7983 commit 1dedb1f
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions alphapulldown/folding_backend/alphafold3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,24 @@ def _convert_to_fold_input(
object_to_model: Union[MonomericObject, ChoppedObject, MultimericObject],
random_seed: int,
) -> folding_input.Input:
"""Convert a given object to AlphaFold3 fold input."""
def msa_array_to_a3m(msa_array):
    """Render an integer-encoded MSA as FASTA/A3M-style text.

    Every row of ``msa_array`` becomes one record: a ``>sequence_<row index>``
    header line followed by the row decoded letter by letter through
    ``residue_constants.ID_TO_HHBLITS_AA``. Ids that are missing from that
    table are written as ``X``. Records are separated by a single newline.
    """
    records = []
    row_index = 0
    for encoded_row in msa_array:
        letters = []
        for code in encoded_row:
            # int() normalises the element before the table lookup.
            letters.append(residue_constants.ID_TO_HHBLITS_AA.get(int(code), 'X'))
        decoded = ''.join(letters)
        records.append(f'>sequence_{row_index}\n{decoded}')
        row_index += 1
    return '\n'.join(records)
"""Convert a given object to AlphaFold3 fold input.
This function now supports MonomericObject, ChoppedObject, and MultimericObject.
For a MultimericObject, each interactor is processed as if it were a monomeric object.
"""

def chain_id_generator():
    """Yield unique chain IDs made of uppercase ASCII letters.

    IDs are produced shortest first: ``A`` .. ``Z``, then ``AA``, ``AB`` ..
    ``ZZ``, then ``AAA`` and so on. The first 702 IDs are the same as the
    one- and two-letter IDs produced previously; the sequence now continues
    with longer IDs instead of ending after ``ZZ``, so a bare ``next()`` on
    the generator never raises ``StopIteration``.

    Yields:
        str: The next unused chain ID.
    """
    # Imported locally so this helper does not rely on a module-level import.
    import itertools

    chain_letters = string.ascii_uppercase
    for id_length in itertools.count(1):
        # product() varies the last letter fastest, which reproduces the
        # ordering of the original nested first-letter/second-letter loops.
        for letters in itertools.product(chain_letters, repeat=id_length):
            yield ''.join(letters)

chain_id_gen = chain_id_generator()

def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100-01-01') -> str:
pdb_data = mmcif_string.splitlines()
header_lines = [
Expand All @@ -215,15 +217,26 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100
pdb_data.insert(3, release_date_line)
return "\n".join(pdb_data)

chain_id_gen = chain_id_generator()
def msa_array_to_a3m(msa_array):
    """Turn an integer-encoded MSA into an A3M formatted string.

    Each row of ``msa_array`` is decoded one element at a time via
    ``residue_constants.ID_TO_HHBLITS_AA`` (ids absent from the table become
    ``X``) and emitted under a ``>sequence_<row index>`` header. The records
    are joined with newlines and returned as one string.
    """
    return '\n'.join(
        '>sequence_{}\n{}'.format(
            row_index,
            ''.join(
                residue_constants.ID_TO_HHBLITS_AA.get(int(code), 'X')
                for code in encoded_row
            ),
        )
        for row_index, encoded_row in enumerate(msa_array)
    )

if isinstance(object_to_model, (MonomericObject, ChoppedObject)):
chain_id = next(chain_id_gen)
sequence = object_to_model.sequence
msa_array = object_to_model.feature_dict.get('msa')
def _monomeric_to_chain(
mono_obj: Union[MonomericObject, ChoppedObject],
chain_id: str
) -> folding_input.ProteinChain:
"""Converts a single MonomericObject or ChoppedObject into a ProteinChain."""
sequence = mono_obj.sequence
feature_dict = mono_obj.feature_dict

# Convert MSA arrays to A3M.
msa_array = feature_dict.get('msa')
unpaired_msa = msa_array_to_a3m(msa_array) if msa_array is not None else ""
paired_msa = ""
feature_dict = object_to_model.feature_dict
paired_msa = "" # For this simplified logic, no paired MSA is handled here.

# Process templates if present
templates = []
Expand Down Expand Up @@ -270,8 +283,6 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100
query_to_template_map=query_to_template_map,
)
)
else:
templates = []

chain = folding_input.ProteinChain(
id=chain_id,
Expand All @@ -281,8 +292,20 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100
paired_msa=paired_msa,
templates=templates,
)
chains = [chain]
return chain

# Main logic depending on object type
if isinstance(object_to_model, (MonomericObject, ChoppedObject)):
# Single chain
chain_id = next(chain_id_gen)
chains = [_monomeric_to_chain(object_to_model, chain_id)]
elif isinstance(object_to_model, MultimericObject):
# Multiple chains - each interactor is a MonomericObject
chains = []
for interactor in object_to_model.interactors:
chain_id = next(chain_id_gen)
chain = _monomeric_to_chain(interactor, chain_id)
chains.append(chain)
else:
logging.error("Unsupported object type for folding input conversion.")
raise TypeError("Unsupported object type for folding input conversion.")
Expand All @@ -293,6 +316,7 @@ def insert_release_date_into_mmcif(mmcif_string: str, revision_date: str = '2100
chains=chains,
)


def write_outputs(
all_inference_results: Sequence[ResultsForSeed],
output_dir: os.PathLike[str] | str,
Expand Down

0 comments on commit 1dedb1f

Please sign in to comment.