diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 598940707..e5ea99af0 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -17,8 +17,20 @@ import jax.numpy as jnp import jax import numpy as np +from alphafold.common import residue_constants import scipy.special +def compute_tol(prev_pos, current_pos, mask, use_jnp=False): + # Early stopping criteria based on criteria used in + # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2 + _np = jnp if use_jnp else np + dist = lambda x:_np.sqrt(((x[:,None] - x[None,:])**2).sum(-1)) + ca_idx = residue_constants.atom_order['CA'] + sq_diff = _np.square(dist(prev_pos[:,ca_idx])-dist(current_pos[:,ca_idx])) + mask_2d = mask[:,None] * mask[None,:] + return _np.sqrt((sq_diff * mask_2d).sum()/mask_2d.sum() + 1e-8) + + def compute_plddt(logits, use_jnp=False): """Computes per-residue pLDDT from logits. Args: diff --git a/alphafold/model/config.py b/alphafold/model/config.py index f25cef7f4..f2b35a86a 100644 --- a/alphafold/model/config.py +++ b/alphafold/model/config.py @@ -22,7 +22,6 @@ NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES - def model_config(name: str) -> ml_collections.ConfigDict: """Get the ConfigDict of a CASP14 model.""" @@ -378,7 +377,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: }, 'global_config': { 'bfloat16': True, - 'bfloat16_output': True, + 'bfloat16_output': False, 'deterministic': False, 'multimer_mode': False, 'subbatch_size': 4, @@ -616,7 +615,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: }, 'global_config': { 'bfloat16': True, - 'bfloat16_output': True, + 'bfloat16_output': False, 'deterministic': False, 'multimer_mode': True, 'subbatch_size': 4, diff --git a/alphafold/model/model.py b/alphafold/model/model.py index aa309f2e9..88e90f1f4 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -148,25 +148,30 @@ def predict(self, L = aatype.shape[1] # initialize - def z(shape, dtype=np.float32): return np.zeros(shape, dtype=dtype) - dtype = jnp.bfloat16 if self.config.model.global_config.bfloat16 else np.float32 - prev = {'prev_msa_first_row': z([L,256], dtype), - 'prev_pair': z([L,L,128], dtype), - 'prev_pos': z([L,37,3])} + + zeros = lambda shape: np.zeros(shape, dtype=np.float16) + prev = {'prev_msa_first_row': zeros([L,256]), + 'prev_pair': zeros([L,L,128]), + 'prev_pos': zeros([L,37,3])} def run(key, feat, prev): - outputs = jax.tree_map(lambda x:np.asarray(x), - self.apply(self.params, key, {**feat, "prev":prev})) - prev = outputs.pop("prev") - return outputs, prev + def _jnp_to_np(x): + for k, v in x.items(): + if isinstance(v, dict): + x[k] = _jnp_to_np(v) + else: + x[k] = np.asarray(v,np.float16) + return x + result = _jnp_to_np(self.apply(self.params, key, {**feat, "prev":prev})) + prev = result.pop("prev") + return result, prev + # initialize random key key = jax.random.PRNGKey(random_seed) # iterate through recyckes - stop = False - for r in range(num_iters): - + for r in range(num_iters): # grab subset of features if self.multimer_mode: sub_feat = feat @@ -180,30 +185,18 @@ def run(key, feat, prev): result, prev = run(sub_key, sub_feat, prev) if return_representations: - result["representations"] = {"pair": prev["prev_pair"].astype(np.float32), - "single": prev["prev_msa_first_row"].astype(np.float32)} - # decide when to stop - tol = self.config.model.recycle_early_stop_tolerance - sco = self.config.model.stop_at_score - if result["ranking_confidence"] > sco: - stop = True - if not stop and tol > 0: - ca_idx = residue_constants.atom_order['CA'] - pos = result["structure_module"]["final_atom_positions"][:,ca_idx] - dist = np.sqrt(np.square(pos[:,None]-pos[None,:]).sum(-1)) - if r > 0: - sq_diff = np.square(dist - prev_dist) - seq_mask = feat["seq_mask"] if self.multimer_mode else feat["seq_mask"][0] - mask_2d = seq_mask[:,None] * seq_mask[None,:] - result["diff"] = np.sqrt((sq_diff * mask_2d).sum()/mask_2d.sum()) - if result["diff"] < tol: - stop = True - prev_dist = dist + result["representations"] = {"pair": prev["prev_pair"], + "single": prev["prev_msa_first_row"]} + # callback if callback is not None: callback(result, r) - if stop: break + # decide when to stop + if result["ranking_confidence"] > self.config.model.stop_at_score: + break + if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance: + break logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) return result, r \ No newline at end of file diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 51ac23c66..73f9cdd96 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -182,12 +182,19 @@ def get_prev(ret): prev = batch.pop("prev",None) if batch["aatype"].ndim == 2: batch = jax.tree_map(lambda x:x[0], batch) + + # initialize if prev is None: - L = batch["aatype"].shape[0] - dtype = jnp.bfloat16 if self.global_config.bfloat16 else jnp.float32 - prev = {'prev_msa_first_row': jnp.zeros([L,256], dtype=dtype), - 'prev_pair': jnp.zeros([L,L,128],dtype=dtype), + + L = num_residues + prev = {'prev_msa_first_row': jnp.zeros([L,256]), + 'prev_pair': jnp.zeros([L,L,128]), 'prev_pos': jnp.zeros([L,37,3])} + else: + for k,v in prev.items(): + if v.dtype == jnp.float16: + prev[k] = v.astype(jnp.float32) + ret = impl(batch={**batch, **prev}, is_training=is_training) ret["prev"] = get_prev(ret) @@ -200,7 +207,13 @@ def get_prev(ret): mask=batch["seq_mask"], rank_by=self.config.rank_by, use_jnp=True)) - + + ret["tol"] = confidence.compute_tol( + prev["prev_pos"], + ret["prev"]["prev_pos"], + batch["seq_mask"], + use_jnp=True) + return ret class AlphaFoldIteration(hk.Module): @@ -426,13 +439,18 @@ def slice_recycle_idx(x): ensemble_representations=ensemble_representations) emb_config = self.config.embeddings_and_evoformer - prev = batch.pop("prev", None) + # initialize + prev = batch.pop("prev", None) if prev is None: L = num_residues - dtype = jnp.bfloat16 if self.global_config.bfloat16 else jnp.float32 - prev = {'prev_msa_first_row': jnp.zeros([L,256], dtype=dtype), - 'prev_pair': jnp.zeros([L,L,128],dtype=dtype), + prev = {'prev_msa_first_row': jnp.zeros([L,256]), + 'prev_pair': jnp.zeros([L,L,128]), 'prev_pos': jnp.zeros([L,37,3])} + else: + for k,v in prev.items(): + if v.dtype == jnp.float16: + prev[k] = v.astype(jnp.float32) + ret = do_call(prev=prev, recycle_idx=0) ret["prev"] = get_prev(ret) @@ -446,9 +464,15 @@ def slice_recycle_idx(x): # add confidence metrics ret.update(confidence.get_confidence_metrics( prediction_result=ret, - mask=batch["seq_mask"], + mask=batch["seq_mask"][0], rank_by=self.config.rank_by, - use_jnp=True)) + use_jnp=True)) + + ret["tol"] = confidence.compute_tol( + prev["prev_pos"], + ret["prev"]["prev_pos"], + batch["seq_mask"][0], + use_jnp=True) return ret diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py index b0746a801..7cd8a6fd5 100644 --- a/alphafold/model/modules_multimer.py +++ b/alphafold/model/modules_multimer.py @@ -442,8 +442,20 @@ def apply_network(prev, safe_key): batch=recycled_batch, is_training=is_training, safe_key=safe_key) - - ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) + + # initialize + prev = batch.pop("prev", None) + if prev is None: + L = num_residues + prev = {'prev_msa_first_row': jnp.zeros([L,256]), + 'prev_pair': jnp.zeros([L,L,128]), + 'prev_pos': jnp.zeros([L,37,3])} + else: + for k,v in prev.items(): + if v.dtype == jnp.float16: + prev[k] = v.astype(jnp.float32) + + ret = apply_network(prev=prev, safe_key=safe_key) ret["prev"] = get_prev(ret) if not return_representations: @@ -456,6 +468,12 @@ def apply_network(prev, safe_key): rank_by=self.config.rank_by, use_jnp=True)) + ret["tol"] = confidence.compute_tol( + prev["prev_pos"], + ret["prev"]["prev_pos"], + batch["seq_mask"], + use_jnp=True) + return ret class EmbeddingsAndEvoformer(hk.Module): diff --git a/setup.py b/setup.py index 72f4d3c32..e5314e031 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name='alphafold-colabfold', - version='2.3.3', + version='2.3.4', long_description_content_type='text/markdown', description='An implementation of the inference pipeline of AlphaFold v2.3.1. ' 'This is a completely new model that was entered as AlphaFold2 in CASP14 '