refactoring dgram recycling (#146)
* minor edits to allow dgram recycling

* Update design.py

cleanup dgram code

* add option to mask target features with mlm

* add option to unbias mlm

* Update modules.py
sokrypton authored Jun 6, 2023
1 parent ff0b0b5 commit 4f0c3ce
Showing 9 changed files with 113 additions and 51 deletions.
4 changes: 2 additions & 2 deletions colabdesign/af/alphafold/model/config.py
@@ -308,7 +308,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
         'subbatch_size': 4,
         'use_remat': False,
         'zero_init': True,
-        'use_dgram': False
+        'use_dgram_pred': False,
       },
       'heads': {
         'distogram': {
@@ -537,7 +537,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
         'subbatch_size': 4,
         'use_remat': False,
         'zero_init': True,
-        'use_dgram': False
+        'use_dgram_pred': False,
       },
       'heads': {
         'distogram': {
35 changes: 18 additions & 17 deletions colabdesign/af/alphafold/model/modules.py
@@ -150,20 +150,28 @@ def __init__(self, config, name='alphafold'):
   def __call__(self, batch, **kwargs):
     """Run the AlphaFold model."""
     impl = AlphaFoldIteration(self.config, self.global_config)
 
-    def get_prev(ret):
+    def get_prev(ret, use_dgram=False):
       new_prev = {
         'prev_msa_first_row': ret['representations']['msa_first_row'],
         'prev_pair': ret['representations']['pair'],
-        'prev_pos': ret['structure_module']['final_atom_positions']
       }
-      if self.global_config.use_dgram:
-        new_prev['prev_dgram'] = ret["distogram"]["logits"]
+      if use_dgram:
+        if self.global_config.use_dgram_pred:
+          dgram = jax.nn.softmax(ret["distogram"]["logits"])
+          dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
+          new_prev['prev_dgram'] = dgram @ dgram_map
+        else:
+          pos = ret['structure_module']['final_atom_positions']
+          prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None)
+          new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15)
+      else:
+        new_prev['prev_pos'] = ret['structure_module']['final_atom_positions']
 
       return new_prev
 
     prev = batch.pop("prev")
     ret = impl(batch={**batch, **prev})
-    ret["prev"] = get_prev(ret)
+    ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev)
     return ret
 
 class TemplatePairStack(hk.Module):
@@ -1383,6 +1391,7 @@ def __call__(self, batch, safe_key=None):
       msa_feat = batch['msa_feat'].astype(dtype)
       target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[1,1]])
       preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat)
+      preprocess_1d = jnp.where(target_feat.sum(-1,keepdims=True) == 0, 0, preprocess_1d)
       preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat)
       msa_activations = preprocess_1d[None] + preprocess_msa
@@ -1397,19 +1406,11 @@ def __call__(self, batch, safe_key=None):
       # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
       # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
 
-      if gc.use_dgram:
-        # use predicted distogram input (from Sergey)
-        dgram = jax.nn.softmax(batch["prev_dgram"])
-        dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
-        dgram = dgram @ dgram_map
-
+      if "prev_dgram" in batch:
+        dgram = batch["prev_dgram"]
       else:
-        # use predicted position input
         prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
-        if c.backprop_dgram:
-          dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos)
-        else:
-          dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos)
+        dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos)
       dgram = dgram.astype(dtype)
       pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram)
30 changes: 23 additions & 7 deletions colabdesign/af/alphafold/model/modules_multimer.py
@@ -178,22 +178,34 @@ def __call__(
     assert isinstance(batch, dict)
     num_res = batch['aatype'].shape[0]
 
-    def get_prev(ret):
+    def get_prev(ret, use_dgram=False):
       new_prev = {
-        'prev_pos': ret['structure_module']['final_atom_positions'],
         'prev_msa_first_row': ret['representations']['msa_first_row'],
         'prev_pair': ret['representations']['pair'],
       }
+      if use_dgram:
+        if self.global_config.use_dgram_pred:
+          dgram = jax.nn.softmax(ret["distogram"]["logits"])
+          dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
+          new_prev['prev_dgram'] = dgram @ dgram_map
+        else:
+          pos = ret['structure_module']['final_atom_positions']
+          prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None)
+          new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15)
+      else:
+        new_prev['prev_pos'] = ret['structure_module']['final_atom_positions']
+
       return new_prev
 
     def apply_network(prev, safe_key):
       recycled_batch = {**batch, **prev}
       return impl(
         batch=recycled_batch,
         safe_key=safe_key)
 
-    ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
-    ret["prev"] = get_prev(ret)
+    prev = batch.pop("prev")
+    ret = apply_network(prev=prev, safe_key=safe_key)
+    ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev)
 
     if not return_representations:
       del ret['representations']
@@ -315,8 +327,12 @@ def __call__(self, batch, safe_key=None):
       mask_2d = mask_2d.astype(dtype)
 
       if c.recycle_pos:
-        prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
-        dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
+        if "prev_dgram" in batch:
+          dgram = batch["prev_dgram"]
+        else:
+          prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
+          dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
+
         dgram = dgram.astype(dtype)
         pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram)

51 changes: 36 additions & 15 deletions colabdesign/af/design.py
@@ -3,6 +3,7 @@
 import jax.numpy as jnp
 import numpy as np
 from colabdesign.af.alphafold.common import residue_constants
+from colabdesign.af.utils import dgram_from_positions
 from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str, to_float, softmax, categorical, to_list, copy_missing
 
 ####################################################
Expand Down Expand Up @@ -158,24 +159,42 @@ def _recycle(self, model_params, num_recycles=None, backprop=True):
       else:
         L = self._inputs["residue_index"].shape[0]
 
-      # intialize previous
+      # intialize previous inputs
       if "prev" not in self._inputs or a["clear_prev"]:
         prev = {'prev_msa_first_row': np.zeros([L,256]),
-                'prev_pair': np.zeros([L,L,128])}
+                'prev_pair':          np.zeros([L,L,128])}
 
-        if a["use_initial_guess"] and "batch" in self._inputs:
-          prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"]
-        else:
-          prev["prev_pos"] = np.zeros([L,37,3])
-
-        if a["use_dgram"]:
-          # TODO: add support for initial_guess + use_dgram
-          prev["prev_dgram"] = np.zeros([L,L,64])
-
-        if a["use_initial_atom_pos"]:
-          if "batch" in self._inputs:
-            self._inputs["initial_atom_pos"] = self._inputs["batch"]["all_atom_positions"]
-          else:
-            self._inputs["initial_atom_pos"] = np.zeros([L,37,3])
+        # initialize coordinates
+        # TODO: add support for the 'partial' protocol
+        if "batch" in self._inputs:
+          ini_seq = self._inputs["batch"]["aatype"]
+          ini_pos = self._inputs["batch"]["all_atom_positions"]
+
+          # via evoformer
+          if a["use_initial_guess"]:
+            # via distogram or positions
+            if a["use_dgram"] or a["use_dgram_pred"]:
+              prev["prev_dgram"] = dgram_from_positions(positions=ini_pos,
+                seq=ini_seq, num_bins=15, min_bin=3.25, max_bin=20.75)
+            else:
+              prev["prev_pos"] = ini_pos
+          else:
+            if a["use_dgram"] or a["use_dgram_pred"]:
+              prev["prev_dgram"] = np.zeros([L,L,15])
+            else:
+              prev["prev_pos"] = np.zeros([L,37,3])
+
+          # via structure module
+          if a["use_initial_atom_pos"]:
+            self._inputs["initial_atom_pos"] = ini_pos
+
+        else:
+          # if batch not defined, initialize with zeros
+          if a["use_dgram"] or a["use_dgram_pred"]:
+            prev["prev_dgram"] = np.zeros([L,L,15])
+          else:
+            prev["prev_pos"] = np.zeros([L,37,3])
+          if a["use_initial_atom_pos"]:
+            self._inputs["initial_atom_pos"] = np.zeros([L,37,3])
 
       self._inputs["prev"] = prev
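
A hedged sketch of how these flags might be driven from user code (mk_afdesign_model, prep_inputs, and design_3stage are existing entry points in colabdesign; setting the flags directly on model._args and the PDB filename are illustrative assumptions, not a documented recipe):

    from colabdesign import mk_afdesign_model

    model = mk_afdesign_model(protocol="fixbb")
    model._args["use_dgram"] = True          # recycle a 15-bin distogram instead of coordinates
    model._args["use_initial_guess"] = True  # seed the first recycle from the input batch
    model.prep_inputs(pdb_filename="target.pdb", chain="A")  # hypothetical input
    model.design_3stage()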
@@ -196,9 +215,11 @@ def _recycle(self, model_params, num_recycles=None, backprop=True):
       else:
         aux = self._single(model_params, backprop)
       grad.append(jax.tree_map(lambda x:x*m, aux["grad"]))
+
+      # update previous inputs
       self._inputs["prev"] = aux["prev"]
       if a["use_initial_atom_pos"]:
-        self._inputs["initial_atom_pos"] = aux["prev"]["prev_pos"]
+        self._inputs["initial_atom_pos"] = aux["atom_positions"]
 
     aux["grad"] = jax.tree_map(lambda *x: np.stack(x).sum(0), *grad)

4 changes: 3 additions & 1 deletion colabdesign/af/inputs.py
@@ -106,7 +106,7 @@ def _update_template(self, inputs, key):
         inputs[k] = inputs[k].at[...,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][...,5:]))
         inputs[k] = jnp.where(rm[:,None],0,inputs[k])
 
-def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None):
+def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None, mask_target=False):
   '''update the sequence features'''
   if seq_1hot is None: seq_1hot = seq["pseudo"]
   if seq_pssm is None: seq_pssm = seq["pseudo"]
@@ -122,6 +122,8 @@ def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None):
     Y = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X)
     msa_feat = jnp.where(mlm[...,None],Y,msa_feat)
     seq["pseudo"] = jnp.where(mlm[...,None],X[:seq["pseudo"].shape[-1]],seq["pseudo"])
+    if mask_target:
+      target_feat = jnp.where(mlm[0,:,None],0,target_feat)
 
   inputs.update({"msa_feat":msa_feat, "target_feat":target_feat})
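
With mask_target=True, positions hidden by the MLM mask are also zeroed in target_feat, so (together with the preprocess_1d change in modules.py above) masked positions carry no sequence signal at all. A tiny demo of the jnp.where pattern, with hypothetical shapes:

    import jax.numpy as jnp

    mlm = jnp.array([[True, False, True]])  # (num_seq=1, L=3) mask
    target_feat = jnp.ones((3, 4))          # (L, features)
    masked = jnp.where(mlm[0, :, None], 0, target_feat)
    print(masked[:, 0])                     # [0. 1. 0.] -> masked positions zeroed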

5 changes: 4 additions & 1 deletion colabdesign/af/loss.py
@@ -541,8 +541,11 @@ def get_seq_ent_loss(inputs):
   ent = (ent * mask).sum() / (mask.sum() + 1e-8)
   return {"seq_ent":ent.mean()}
 
-def get_mlm_loss(outputs, mask, truth=None):
+def get_mlm_loss(outputs, mask, truth=None, unbias=False):
   x = outputs["masked_msa"]["logits"][...,:20]
+  if unbias:
+    x_mean = (x * mask[...,None]).sum((0,1)) / (mask.sum() + 1e-8)
+    x = x - x_mean
   if truth is None: truth = jax.nn.softmax(x)
   ent = -(truth[...,:20] * jax.nn.log_softmax(x)).sum(-1)
   ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8)
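
The unbias option recenters the logits by their mask-weighted per-amino-acid mean before the cross-entropy. A per-position shift would cancel inside log_softmax, but subtracting a class-wise mean changes the predicted distribution, removing a global composition bias. A small self-contained check with hypothetical logits:

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (1, 5, 20))  # (num_seq, L, 20) logits
    mask = jnp.array([[1., 1., 0., 1., 1.]])                  # positions in the loss

    x_mean = (x * mask[..., None]).sum((0, 1)) / (mask.sum() + 1e-8)  # (20,) per-class mean
    x_unbiased = x - x_mean

    # class-wise shifts do not cancel in softmax, unlike per-position constants
    print(jnp.abs(jax.nn.softmax(x) - jax.nn.softmax(x_unbiased)).max() > 0)  # True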
14 changes: 9 additions & 5 deletions colabdesign/af/model.py
@@ -33,11 +33,13 @@ def __init__(self,
     self.protocol = protocol
     self._num = kwargs.pop("num_seq",1)
     self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True,
-                  "recycle_mode":"last", "use_mlm": False, "realign": True,
+                  "recycle_mode":"last",
+                  "use_mlm": False, "mask_target":False, "unbias_mlm": False,
+                  "realign": True,
                   "debug":debug, "repeat":False, "homooligomer":False, "copies":1,
                   "optimizer":"sgd", "best_metric":"loss",
                   "traj_iter":1, "traj_max":10000,
-                  "clear_prev": True, "use_dgram":False,
+                  "clear_prev": True, "use_dgram":False, "use_dgram_pred":False,
                   "shuffle_first":True, "use_remat":True,
                   "alphabet_size":20,
                   "use_initial_guess":False, "use_initial_atom_pos":False}
@@ -101,7 +103,7 @@ def __init__(self,
       num_recycles = self.opt["num_recycles"]
       self._cfg.model.num_recycle = num_recycles
       self._cfg.model.global_config.use_remat = self._args["use_remat"]
-      self._cfg.model.global_config.use_dgram = self._args["use_dgram"]
+      self._cfg.model.global_config.use_dgram_pred = self._args["use_dgram_pred"]
       self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"]
 
     # load model_params
@@ -159,7 +161,7 @@ def _model(params, model_params, inputs, key):
       if a["use_mlm"]:
         shape = seq["pseudo"].shape[:2]
         mlm = jax.random.bernoulli(key(),opt["mlm_dropout"],shape)
-        update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm)
+        update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm, mask_target=a["mask_target"])
       else:
         update_seq(seq, inputs, seq_pssm=pssm)

@@ -221,8 +223,10 @@ def _model(params, model_params, inputs, key):
       # experimental masked-language-modeling
       if a["use_mlm"]:
         aux["mlm"] = outputs["masked_msa"]["logits"]
+        aux["mlm_mask"] = mlm
         mask = jnp.where(inputs["seq_mask"],mlm,0)
-        aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"]))
+        aux["losses"].update(get_mlm_loss(outputs, mask=mask,
+                                          truth=seq["pssm"], unbias=a["unbias_mlm"]))
 
       # run user defined callbacks
       for c in ["loss","post"]:
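
A hedged sketch of turning the new MLM options on (mlm_dropout already exists in opt per the code above; whether these keys can be passed as constructor kwargs is untested here, so they are set on _args directly):

    model = mk_afdesign_model(protocol="hallucination")
    model._args.update({"use_mlm": True, "mask_target": True, "unbias_mlm": True})
    model.opt["mlm_dropout"] = 0.15  # fraction of positions masked each step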
19 changes: 17 additions & 2 deletions colabdesign/af/utils.py
@@ -8,7 +8,7 @@
 from colabdesign.shared.utils import update_dict, Key
 from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb
 from colabdesign.shared.protein import renum_pdb_str
-from colabdesign.af.alphafold.common import protein
+from colabdesign.af.alphafold.common import protein, residue_constants
 
 ####################################################
 # AF_UTILS - various utils (save, plot, etc)
@@ -186,4 +186,19 @@ def plot_current_pdb(self, show_sidechains=False, show_mainchains=False,
     - color=["pLDDT","chain","rainbow"]
     '''
     self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color,
-                  color_HP=color_HP, size=size, animate=animate, get_best=False)
+                  color_HP=color_HP, size=size, animate=animate, get_best=False)
+
+def dgram_from_positions(positions, seq=None, num_bins=39, min_bin=3.25, max_bin=50.75):
+  if seq is None:
+    atoms = {k:positions[...,residue_constants.atom_order[k],:] for k in ["N","CA","C"]}
+    c = _np_get_cb(**atoms, use_jax=False)
+  else:
+    ca = positions[...,residue_constants.atom_order["CA"],:]
+    cb = positions[...,residue_constants.atom_order["CB"],:]
+    is_gly = seq==residue_constants.restype_order["G"]
+    c = np.where(is_gly[:,None],ca,cb)
+  dist = np.sqrt(np.square(c[None,:] - c[:,None]).sum(-1,keepdims=True))
+  lower_breaks = np.linspace(min_bin, max_bin, num_bins)
+  upper_breaks = np.append(lower_breaks[1:],1e8)
+  return ((dist > lower_breaks) * (dist < upper_breaks)).astype(float)
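
Quick usage sketch for the new helper (the random atom37 coordinates and residue indices are hypothetical):

    import numpy as np

    L = 8
    positions = np.random.randn(L, 37, 3) * 10  # hypothetical atom37 coordinates
    aatype = np.random.randint(0, 20, size=L)   # hypothetical residue-type indices

    dgram = dgram_from_positions(positions, seq=aatype,
                                 num_bins=15, min_bin=3.25, max_bin=20.75)
    print(dgram.shape)  # (8, 8, 15); each (i, j) row one-hot encodes the
                        # CB-CB distance bin (CA for glycine; all zeros below min_bin)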
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 setup(
   name='colabdesign',
-  version='1.1.2-beta0',
+  version='1.1.2-beta1',
   description='Making Protein Design accessible to all via Google Colab!',
   long_description="Making Protein Design accessible to all via Google Colab!",
   long_description_content_type='text/markdown',
