Skip to content

Commit

Permalink
v2.3.4 - fix memory leaks (attempt 2) (#6)
Browse files Browse the repository at this point in the history
* fix memory leaks

various edits to fix memory leaks
memory leak fix

* v2.3.4 - fix memory leaks

another attempt to fix memory leaks!

* Update config.py

* bugfix - num-ensemble
  • Loading branch information
sokrypton authored Feb 18, 2023
1 parent 8f50ccb commit 41807ea
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 49 deletions.
12 changes: 12 additions & 0 deletions alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@
import jax.numpy as jnp
import jax
import numpy as np
from alphafold.common import residue_constants
import scipy.special

def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
  """Computes the masked RMS change in pairwise CA distances.

  Early stopping criteria based on criteria used in
  AF2Complex: https://www.nature.com/articles/s41467-022-29394-2

  Args:
    prev_pos: atom positions from the previous iteration, indexed as
      [residue, atom, ...] (presumably [L, 37, 3] -- confirm against caller).
    current_pos: atom positions from the current iteration, same layout as
      `prev_pos`.
    mask: per-residue mask; its outer product with itself weights each
      residue pair.
    use_jnp: if True compute with `jax.numpy`, otherwise with `numpy`.

  Returns:
    Square root of (masked mean squared difference between the two CA
    distance matrices, plus 1e-8).
  """
  xp = jnp if use_jnp else np
  ca = residue_constants.atom_order['CA']

  def pairwise_dist(coords):
    # All-against-all Euclidean distances between rows of `coords`.
    delta = coords[:, None] - coords[None, :]
    return xp.sqrt((delta ** 2).sum(-1))

  pair_mask = mask[:, None] * mask[None, :]
  change = pairwise_dist(prev_pos[:, ca]) - pairwise_dist(current_pos[:, ca])
  masked_sq = xp.square(change) * pair_mask
  # The 1e-8 is added under the square root, after the masked mean is taken.
  return xp.sqrt(masked_sq.sum() / pair_mask.sum() + 1e-8)


def compute_plddt(logits, use_jnp=False):
"""Computes per-residue pLDDT from logits.
Args:
Expand Down
5 changes: 2 additions & 3 deletions alphafold/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""

Expand Down Expand Up @@ -378,7 +377,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
},
'global_config': {
'bfloat16': True,
'bfloat16_output': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': False,
'subbatch_size': 4,
Expand Down Expand Up @@ -616,7 +615,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
},
'global_config': {
'bfloat16': True,
'bfloat16_output': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
Expand Down
57 changes: 25 additions & 32 deletions alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,30 @@ def predict(self,
L = aatype.shape[1]

# initialize
def z(shape, dtype=np.float32): return np.zeros(shape, dtype=dtype)
dtype = jnp.bfloat16 if self.config.model.global_config.bfloat16 else np.float32
prev = {'prev_msa_first_row': z([L,256], dtype),
'prev_pair': z([L,L,128], dtype),
'prev_pos': z([L,37,3])}

zeros = lambda shape: np.zeros(shape, dtype=np.float16)
prev = {'prev_msa_first_row': zeros([L,256]),
'prev_pair': zeros([L,L,128]),
'prev_pos': zeros([L,37,3])}

def run(key, feat, prev):
outputs = jax.tree_map(lambda x:np.asarray(x),
self.apply(self.params, key, {**feat, "prev":prev}))
prev = outputs.pop("prev")
return outputs, prev
def _jnp_to_np(x):
for k, v in x.items():
if isinstance(v, dict):
x[k] = _jnp_to_np(v)
else:
x[k] = np.asarray(v,np.float16)
return x
result = _jnp_to_np(self.apply(self.params, key, {**feat, "prev":prev}))
prev = result.pop("prev")
return result, prev


# initialize random key
key = jax.random.PRNGKey(random_seed)

# iterate through recycles
stop = False
for r in range(num_iters):

for r in range(num_iters):
# grab subset of features
if self.multimer_mode:
sub_feat = feat
Expand All @@ -180,30 +185,18 @@ def run(key, feat, prev):
result, prev = run(sub_key, sub_feat, prev)

if return_representations:
result["representations"] = {"pair": prev["prev_pair"].astype(np.float32),
"single": prev["prev_msa_first_row"].astype(np.float32)}
# decide when to stop
tol = self.config.model.recycle_early_stop_tolerance
sco = self.config.model.stop_at_score
if result["ranking_confidence"] > sco:
stop = True
if not stop and tol > 0:
ca_idx = residue_constants.atom_order['CA']
pos = result["structure_module"]["final_atom_positions"][:,ca_idx]
dist = np.sqrt(np.square(pos[:,None]-pos[None,:]).sum(-1))
if r > 0:
sq_diff = np.square(dist - prev_dist)
seq_mask = feat["seq_mask"] if self.multimer_mode else feat["seq_mask"][0]
mask_2d = seq_mask[:,None] * seq_mask[None,:]
result["diff"] = np.sqrt((sq_diff * mask_2d).sum()/mask_2d.sum())
if result["diff"] < tol:
stop = True
prev_dist = dist

result["representations"] = {"pair": prev["prev_pair"],
"single": prev["prev_msa_first_row"]}

# callback
if callback is not None: callback(result, r)

if stop: break
# decide when to stop
if result["ranking_confidence"] > self.config.model.stop_at_score:
break
if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance:
break

logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result))
return result, r
46 changes: 35 additions & 11 deletions alphafold/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,19 @@ def get_prev(ret):
prev = batch.pop("prev",None)
if batch["aatype"].ndim == 2:
batch = jax.tree_map(lambda x:x[0], batch)

# initialize
if prev is None:
L = batch["aatype"].shape[0]
dtype = jnp.bfloat16 if self.global_config.bfloat16 else jnp.float32
prev = {'prev_msa_first_row': jnp.zeros([L,256], dtype=dtype),
'prev_pair': jnp.zeros([L,L,128],dtype=dtype),

L = num_residues
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
'prev_pair': jnp.zeros([L,L,128]),
'prev_pos': jnp.zeros([L,37,3])}
else:
for k,v in prev.items():
if v.dtype == jnp.float16:
prev[k] = v.astype(jnp.float32)


ret = impl(batch={**batch, **prev}, is_training=is_training)
ret["prev"] = get_prev(ret)
Expand All @@ -200,7 +207,13 @@ def get_prev(ret):
mask=batch["seq_mask"],
rank_by=self.config.rank_by,
use_jnp=True))


ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
ret["prev"]["prev_pos"],
batch["seq_mask"],
use_jnp=True)

return ret

class AlphaFoldIteration(hk.Module):
Expand Down Expand Up @@ -426,13 +439,18 @@ def slice_recycle_idx(x):
ensemble_representations=ensemble_representations)

emb_config = self.config.embeddings_and_evoformer
prev = batch.pop("prev", None)
# initialize
prev = batch.pop("prev", None)
if prev is None:
L = num_residues
dtype = jnp.bfloat16 if self.global_config.bfloat16 else jnp.float32
prev = {'prev_msa_first_row': jnp.zeros([L,256], dtype=dtype),
'prev_pair': jnp.zeros([L,L,128],dtype=dtype),
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
'prev_pair': jnp.zeros([L,L,128]),
'prev_pos': jnp.zeros([L,37,3])}
else:
for k,v in prev.items():
if v.dtype == jnp.float16:
prev[k] = v.astype(jnp.float32)


ret = do_call(prev=prev, recycle_idx=0)
ret["prev"] = get_prev(ret)
Expand All @@ -446,9 +464,15 @@ def slice_recycle_idx(x):
# add confidence metrics
ret.update(confidence.get_confidence_metrics(
prediction_result=ret,
mask=batch["seq_mask"],
mask=batch["seq_mask"][0],
rank_by=self.config.rank_by,
use_jnp=True))
use_jnp=True))

ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
ret["prev"]["prev_pos"],
batch["seq_mask"][0],
use_jnp=True)

return ret

Expand Down
22 changes: 20 additions & 2 deletions alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,20 @@ def apply_network(prev, safe_key):
batch=recycled_batch,
is_training=is_training,
safe_key=safe_key)

ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)

# initialize
prev = batch.pop("prev", None)
if prev is None:
L = num_residues
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
'prev_pair': jnp.zeros([L,L,128]),
'prev_pos': jnp.zeros([L,37,3])}
else:
for k,v in prev.items():
if v.dtype == jnp.float16:
prev[k] = v.astype(jnp.float32)

ret = apply_network(prev=prev, safe_key=safe_key)
ret["prev"] = get_prev(ret)

if not return_representations:
Expand All @@ -456,6 +468,12 @@ def apply_network(prev, safe_key):
rank_by=self.config.rank_by,
use_jnp=True))

ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
ret["prev"]["prev_pos"],
batch["seq_mask"],
use_jnp=True)

return ret

class EmbeddingsAndEvoformer(hk.Module):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

setup(
name='alphafold-colabfold',
version='2.3.3',
version='2.3.4',
long_description_content_type='text/markdown',
description='An implementation of the inference pipeline of AlphaFold v2.3.1. '
'This is a completely new model that was entered as AlphaFold2 in CASP14 '
Expand Down

0 comments on commit 41807ea

Please sign in to comment.