diff --git a/colabdesign/af/alphafold/model/config.py b/colabdesign/af/alphafold/model/config.py
index f405d198..32f001e7 100644
--- a/colabdesign/af/alphafold/model/config.py
+++ b/colabdesign/af/alphafold/model/config.py
@@ -308,7 +308,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
         'subbatch_size': 4,
         'use_remat': False,
         'zero_init': True,
-        'use_dgram': False
+        'use_dgram_pred': False,
     },
     'heads': {
         'distogram': {
@@ -537,7 +537,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
         'subbatch_size': 4,
         'use_remat': False,
         'zero_init': True,
-        'use_dgram': False
+        'use_dgram_pred': False,
     },
     'heads': {
         'distogram': {
diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py
index 9b4c30c3..d0d7c631 100644
--- a/colabdesign/af/alphafold/model/modules.py
+++ b/colabdesign/af/alphafold/model/modules.py
@@ -150,20 +150,28 @@ def __init__(self, config, name='alphafold'):
   def __call__(self, batch, **kwargs):
     """Run the AlphaFold model."""
     impl = AlphaFoldIteration(self.config, self.global_config)
-
-    def get_prev(ret):
+    def get_prev(ret, use_dgram=False):
       new_prev = {
         'prev_msa_first_row': ret['representations']['msa_first_row'],
         'prev_pair': ret['representations']['pair'],
-        'prev_pos': ret['structure_module']['final_atom_positions']
       }
-      if self.global_config.use_dgram:
-        new_prev['prev_dgram'] = ret["distogram"]["logits"]
+      if use_dgram:
+        if self.global_config.use_dgram_pred:
+          dgram = jax.nn.softmax(ret["distogram"]["logits"])
+          dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
+          new_prev['prev_dgram'] = dgram @ dgram_map
+        else:
+          pos = ret['structure_module']['final_atom_positions']
+          prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None)
+          new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15)
+      else:
+        new_prev['prev_pos'] = ret['structure_module']['final_atom_positions']
+
       return new_prev
 
     prev = batch.pop("prev")
     ret = impl(batch={**batch, **prev})
-    ret["prev"] = get_prev(ret)
+    ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev)
     return ret
 
 class TemplatePairStack(hk.Module):
@@ -1383,6 +1391,7 @@ def __call__(self, batch, safe_key=None):
     msa_feat = batch['msa_feat'].astype(dtype)
     target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[1,1]])
     preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat)
+    preprocess_1d = jnp.where(target_feat.sum(-1,keepdims=True) == 0, 0, preprocess_1d)
     preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat)
     msa_activations = preprocess_1d[None] + preprocess_msa
 
@@ -1397,19 +1406,11 @@
 
     # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
     # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
-    if gc.use_dgram:
-      # use predicted distogram input (from Sergey)
-      dgram = jax.nn.softmax(batch["prev_dgram"])
-      dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
-      dgram = dgram @ dgram_map
-
+    if "prev_dgram" in batch:
+      dgram = batch["prev_dgram"]
     else:
-      # use predicted position input
       prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
-      if c.backprop_dgram:
-        dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos)
-      else:
-        dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos)
+      dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos)
     dgram = dgram.astype(dtype)
     pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram)
 
diff --git a/colabdesign/af/alphafold/model/modules_multimer.py b/colabdesign/af/alphafold/model/modules_multimer.py
index 8822c6d8..88cd436f 100644
--- a/colabdesign/af/alphafold/model/modules_multimer.py
+++ b/colabdesign/af/alphafold/model/modules_multimer.py
@@ -178,12 +178,23 @@ def __call__(
     assert isinstance(batch, dict)
     num_res = batch['aatype'].shape[0]
 
-    def get_prev(ret):
+    def get_prev(ret, use_dgram=False):
       new_prev = {
-        'prev_pos': ret['structure_module']['final_atom_positions'],
         'prev_msa_first_row': ret['representations']['msa_first_row'],
         'prev_pair': ret['representations']['pair'],
       }
+      if use_dgram:
+        if self.global_config.use_dgram_pred:
+          dgram = jax.nn.softmax(ret["distogram"]["logits"])
+          dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
+          new_prev['prev_dgram'] = dgram @ dgram_map
+        else:
+          pos = ret['structure_module']['final_atom_positions']
+          prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None)
+          new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15)
+      else:
+        new_prev['prev_pos'] = ret['structure_module']['final_atom_positions']
+
       return new_prev
 
     def apply_network(prev, safe_key):
@@ -191,9 +202,10 @@
       return impl(
         batch=recycled_batch,
         safe_key=safe_key)
-
-    ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
-    ret["prev"] = get_prev(ret)
+
+    prev = batch.pop("prev")
+    ret = apply_network(prev=prev, safe_key=safe_key)
+    ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev)
 
     if not return_representations:
       del ret['representations']
@@ -315,8 +327,12 @@ def __call__(self, batch, safe_key=None):
     mask_2d = mask_2d.astype(dtype)
 
     if c.recycle_pos:
-      prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
-      dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
+      if "prev_dgram" in batch:
+        dgram = batch["prev_dgram"]
+      else:
+        prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
+        dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
+
       dgram = dgram.astype(dtype)
       pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram)
 
diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py
index 00331ba3..381a192b 100644
--- a/colabdesign/af/design.py
+++ b/colabdesign/af/design.py
@@ -3,6 +3,7 @@
 import jax.numpy as jnp
 import numpy as np
 from colabdesign.af.alphafold.common import residue_constants
+from colabdesign.af.utils import dgram_from_positions
 from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str, to_float, softmax, categorical, to_list, copy_missing
 
 ####################################################
@@ -158,24 +159,42 @@ def _recycle(self, model_params, num_recycles=None, backprop=True):
     else:
       L = self._inputs["residue_index"].shape[0]
 
-      # intialize previous
+      # intialize previous inputs
       if "prev" not in self._inputs or a["clear_prev"]:
         prev = {'prev_msa_first_row': np.zeros([L,256]),
-                'prev_pair': np.zeros([L,L,128])}
-
-        if a["use_initial_guess"] and "batch" in self._inputs:
-          prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"]
+                'prev_pair': np.zeros([L,L,128])}
+
+        # initialize coordinates
+        # TODO: add support for the 'partial' protocol
+        if "batch" in self._inputs:
+          ini_seq = self._inputs["batch"]["aatype"]
+          ini_pos = self._inputs["batch"]["all_atom_positions"]
+
+          # via evoformer
+          if a["use_initial_guess"]:
+            # via distogram or positions
+            if a["use_dgram"] or a["use_dgram_pred"]:
+              prev["prev_dgram"] = dgram_from_positions(positions=ini_pos,
+                seq=ini_seq, num_bins=15, min_bin=3.25, max_bin=20.75)
+            else:
+              prev["prev_pos"] = ini_pos
+          else:
+            if a["use_dgram"] or a["use_dgram_pred"]:
+              prev["prev_dgram"] = np.zeros([L,L,15])
+            else:
+              prev["prev_pos"] = np.zeros([L,37,3])
+
+          # via structure module
+          if a["use_initial_atom_pos"]:
+            self._inputs["initial_atom_pos"] = ini_pos
+
         else:
-          prev["prev_pos"] = np.zeros([L,37,3])
-
-        if a["use_dgram"]:
-          # TODO: add support for initial_guess + use_dgram
-          prev["prev_dgram"] = np.zeros([L,L,64])
-
-        if a["use_initial_atom_pos"]:
-          if "batch" in self._inputs:
-            self._inputs["initial_atom_pos"] = self._inputs["batch"]["all_atom_positions"]
+          # if batch not defined, initialize with zeros
+          if a["use_dgram"] or a["use_dgram_pred"]:
+            prev["prev_dgram"] = np.zeros([L,L,15])
           else:
+            prev["prev_pos"] = np.zeros([L,37,3])
+          if a["use_initial_atom_pos"]:
             self._inputs["initial_atom_pos"] = np.zeros([L,37,3])
 
       self._inputs["prev"] = prev
@@ -196,9 +215,11 @@ def _recycle(self, model_params, num_recycles=None, backprop=True):
         else:
           aux = self._single(model_params, backprop)
         grad.append(jax.tree_map(lambda x:x*m, aux["grad"]))
+
+        # update previous inputs
         self._inputs["prev"] = aux["prev"]
         if a["use_initial_atom_pos"]:
-          self._inputs["initial_atom_pos"] = aux["prev"]["prev_pos"]
+          self._inputs["initial_atom_pos"] = aux["atom_positions"]
 
       aux["grad"] = jax.tree_map(lambda *x: np.stack(x).sum(0), *grad)
diff --git a/colabdesign/af/inputs.py b/colabdesign/af/inputs.py
index 36fdc5d8..b81dd85f 100644
--- a/colabdesign/af/inputs.py
+++ b/colabdesign/af/inputs.py
@@ -106,7 +106,7 @@ def _update_template(self, inputs, key):
         inputs[k] = inputs[k].at[...,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][...,5:]))
       inputs[k] = jnp.where(rm[:,None],0,inputs[k])
 
-def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None):
+def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None, mask_target=False):
   '''update the sequence features'''
   if seq_1hot is None: seq_1hot = seq["pseudo"]
   if seq_pssm is None: seq_pssm = seq["pseudo"]
@@ -122,6 +122,8 @@
     Y = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X)
     msa_feat = jnp.where(mlm[...,None],Y,msa_feat)
     seq["pseudo"] = jnp.where(mlm[...,None],X[:seq["pseudo"].shape[-1]],seq["pseudo"])
+    if mask_target:
+      target_feat = jnp.where(mlm[0,:,None],0,target_feat)
 
   inputs.update({"msa_feat":msa_feat, "target_feat":target_feat})
 
diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py
index b0b463be..0100f30f 100644
--- a/colabdesign/af/loss.py
+++ b/colabdesign/af/loss.py
@@ -541,8 +541,11 @@ def get_seq_ent_loss(inputs):
   ent = (ent * mask).sum() / (mask.sum() + 1e-8)
   return {"seq_ent":ent.mean()}
 
-def get_mlm_loss(outputs, mask, truth=None):
+def get_mlm_loss(outputs, mask, truth=None, unbias=False):
   x = outputs["masked_msa"]["logits"][...,:20]
+  if unbias:
+    x_mean = (x * mask[...,None]).sum((0,1)) / (mask.sum() + 1e-8)
+    x = x - x_mean
   if truth is None: truth = jax.nn.softmax(x)
   ent = -(truth[...,:20] * jax.nn.log_softmax(x)).sum(-1)
   ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8)
diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py
index 834c98a3..8b6be168 100644
--- a/colabdesign/af/model.py
+++ b/colabdesign/af/model.py
@@ -33,11 +33,13 @@ def __init__(self,
     self.protocol = protocol
     self._num = kwargs.pop("num_seq",1)
     self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True,
-                  "recycle_mode":"last", "use_mlm": False, "realign": True,
+                  "recycle_mode":"last",
+                  "use_mlm": False, "mask_target":False, "unbias_mlm": False,
+                  "realign": True,
                   "debug":debug, "repeat":False, "homooligomer":False, "copies":1,
                   "optimizer":"sgd", "best_metric":"loss",
                   "traj_iter":1, "traj_max":10000,
-                  "clear_prev": True, "use_dgram":False,
+                  "clear_prev": True, "use_dgram":False, "use_dgram_pred":False,
                   "shuffle_first":True, "use_remat":True, "alphabet_size":20,
                   "use_initial_guess":False, "use_initial_atom_pos":False}
 
@@ -101,7 +103,7 @@
       num_recycles = self.opt["num_recycles"]
     self._cfg.model.num_recycle = num_recycles
     self._cfg.model.global_config.use_remat = self._args["use_remat"]
-    self._cfg.model.global_config.use_dgram = self._args["use_dgram"]
+    self._cfg.model.global_config.use_dgram_pred = self._args["use_dgram_pred"]
     self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"]
 
     # load model_params
@@ -159,7 +161,7 @@ def _model(params, model_params, inputs, key):
       if a["use_mlm"]:
         shape = seq["pseudo"].shape[:2]
         mlm = jax.random.bernoulli(key(),opt["mlm_dropout"],shape)
-        update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm)
+        update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm, mask_target=a["mask_target"])
       else:
         update_seq(seq, inputs, seq_pssm=pssm)
 
@@ -221,8 +223,10 @@ def _model(params, model_params, inputs, key):
       # experimental masked-language-modeling
       if a["use_mlm"]:
         aux["mlm"] = outputs["masked_msa"]["logits"]
+        aux["mlm_mask"] = mlm
         mask = jnp.where(inputs["seq_mask"],mlm,0)
-        aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"]))
+        aux["losses"].update(get_mlm_loss(outputs, mask=mask,
+                                          truth=seq["pssm"], unbias=a["unbias_mlm"]))
 
       # run user defined callbacks
       for c in ["loss","post"]:
diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py
index ee5079e2..504d3248 100644
--- a/colabdesign/af/utils.py
+++ b/colabdesign/af/utils.py
@@ -8,7 +8,7 @@
 from colabdesign.shared.utils import update_dict, Key
 from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb
 from colabdesign.shared.protein import renum_pdb_str
-from colabdesign.af.alphafold.common import protein
+from colabdesign.af.alphafold.common import protein, residue_constants
 
 ####################################################
 # AF_UTILS - various utils (save, plot, etc)
@@ -186,4 +186,19 @@ def plot_current_pdb(self, show_sidechains=False, show_mainchains=False,
     - color=["pLDDT","chain","rainbow"]
     '''
     self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color,
-                  color_HP=color_HP, size=size, animate=animate, get_best=False)
\ No newline at end of file
+                  color_HP=color_HP, size=size, animate=animate, get_best=False)
+
+def dgram_from_positions(positions, seq=None, num_bins=39, min_bin=3.25, max_bin=50.75):
+  if seq is None:
+    atoms = {k:positions[...,residue_constants.atom_order[k],:] for k in ["N","CA","C"]}
+    c = _np_get_cb(**atoms, use_jax=False)
+  else:
+    ca = positions[...,residue_constants.atom_order["CA"],:]
+    cb = positions[...,residue_constants.atom_order["CB"],:]
+    is_gly = seq==residue_constants.restype_order["G"]
+    c = np.where(is_gly[:,None],ca,cb)
+  dist = np.sqrt(np.square(c[None,:] - c[:,None]).sum(-1,keepdims=True))
+  lower_breaks = np.linspace(min_bin, max_bin, num_bins)
+  lower_breaks = lower_breaks
+  upper_breaks = np.append(lower_breaks[1:],1e8)
+  return ((dist > lower_breaks) * (dist < upper_breaks)).astype(float)
diff --git a/setup.py b/setup.py
index 5f692ef9..644ec673 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 setup(
   name='colabdesign',
-  version='1.1.2-beta0',
+  version='1.1.2-beta1',
   description='Making Protein Design accessible to all via Google Colab!',
   long_description="Making Protein Design accessible to all via Google Colab!",
   long_description_content_type='text/markdown',
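
Usage sketch (not part of the patch). Assuming the standard ColabDesign entry point mk_afdesign_model forwards extra keyword arguments into self._args (as the _args dict in model.py above suggests), and using a placeholder PDB file name, the distogram-based recycling added in this diff might be enabled like this:

# Hedged example: "1ABC.pdb" is a placeholder input file; use_dgram and
# use_initial_guess are the _args flags introduced above, assumed to be
# passed through mk_afdesign_model's **kwargs.
from colabdesign import mk_afdesign_model

af_model = mk_afdesign_model(protocol="fixbb",
                             use_initial_guess=True,  # seed "prev" from the input batch
                             use_dgram=True)          # recycle a distogram instead of raw coordinates
af_model.prep_inputs(pdb_filename="1ABC.pdb", chain="A")
af_model.design_3stage()

With use_dgram_pred=True instead, get_prev() feeds back the softmaxed distogram-head prediction rather than a distogram computed from the predicted coordinates.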