From 4d7976037315b7bf4d47f19b1d251c6be4929f03 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Fri, 19 May 2023 17:07:17 -0400 Subject: [PATCH 1/5] minor edits to allow dgram recycling --- colabdesign/af/alphafold/model/config.py | 6 +++-- colabdesign/af/alphafold/model/modules.py | 26 +++++++++---------- .../af/alphafold/model/modules_multimer.py | 20 +++++++++++--- colabdesign/af/design.py | 22 +++++++++------- colabdesign/af/model.py | 4 +-- colabdesign/af/utils.py | 25 +++++++++++++++++- 6 files changed, 73 insertions(+), 30 deletions(-) diff --git a/colabdesign/af/alphafold/model/config.py b/colabdesign/af/alphafold/model/config.py index f405d198..6af23762 100644 --- a/colabdesign/af/alphafold/model/config.py +++ b/colabdesign/af/alphafold/model/config.py @@ -308,7 +308,8 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'subbatch_size': 4, 'use_remat': False, 'zero_init': True, - 'use_dgram': False + 'use_dgram': False, + 'use_prev_dgram': False }, 'heads': { 'distogram': { @@ -537,7 +538,8 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'subbatch_size': 4, 'use_remat': False, 'zero_init': True, - 'use_dgram': False + 'use_dgram': False, + 'use_prev_dgram': False }, 'heads': { 'distogram': { diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py index 9b4c30c3..876416f2 100644 --- a/colabdesign/af/alphafold/model/modules.py +++ b/colabdesign/af/alphafold/model/modules.py @@ -155,10 +155,18 @@ def get_prev(ret): new_prev = { 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], - 'prev_pos': ret['structure_module']['final_atom_positions'] } if self.global_config.use_dgram: - new_prev['prev_dgram'] = ret["distogram"]["logits"] + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + elif self.global_config.use_prev_dgram: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + else: + new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] + return new_prev prev = batch.pop("prev") @@ -1397,19 +1405,11 @@ def __call__(self, batch, safe_key=None): # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 # Jumper et al. (2021) Suppl. Alg. 
32 "RecyclingEmbedder" - if gc.use_dgram: - # use predicted distogram input (from Sergey) - dgram = jax.nn.softmax(batch["prev_dgram"]) - dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) - dgram = dgram @ dgram_map - + if "prev_dgram" in batch: + dgram = batch["prev_dgram"] else: - # use predicted position input prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) - if c.backprop_dgram: - dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) - else: - dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) + dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) dgram = dgram.astype(dtype) pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) diff --git a/colabdesign/af/alphafold/model/modules_multimer.py b/colabdesign/af/alphafold/model/modules_multimer.py index 8822c6d8..6d16ea6b 100644 --- a/colabdesign/af/alphafold/model/modules_multimer.py +++ b/colabdesign/af/alphafold/model/modules_multimer.py @@ -180,10 +180,20 @@ def __call__( def get_prev(ret): new_prev = { - 'prev_pos': ret['structure_module']['final_atom_positions'], 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], } + if self.global_config.use_dgram: + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + elif self.global_config.use_prev_dgram: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + else: + new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] + return new_prev def apply_network(prev, safe_key): @@ -315,8 +325,12 @@ def __call__(self, batch, safe_key=None): mask_2d = mask_2d.astype(dtype) if c.recycle_pos: - prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) - dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos) + if "prev_dgram" in batch: + dgram = batch["prev_dgram"] + else: + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) + dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos) + dgram = dgram.astype(dtype) pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index 00331ba3..aaf0582a 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import numpy as np from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.utils import dgram_from_positions from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str, to_float, softmax, categorical, to_list, copy_missing #################################################### @@ -161,16 +162,19 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): # intialize previous if "prev" not in self._inputs or a["clear_prev"]: prev = {'prev_msa_first_row': np.zeros([L,256]), - 'prev_pair': np.zeros([L,L,128])} - + 'prev_pair': np.zeros([L,L,128])} if a["use_initial_guess"] and "batch" in self._inputs: - prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"] + if a["use_prev_dgram"]: + prev["prev_dgram"] = 
dgram_from_positions( + self._inputs["batch"]["all_atom_positions"], + min_bin=3.25, max_bin=20.75, num_bins=15) + else: + prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"] else: - prev["prev_pos"] = np.zeros([L,37,3]) - - if a["use_dgram"]: - # TODO: add support for initial_guess + use_dgram - prev["prev_dgram"] = np.zeros([L,L,64]) + if a["use_prev_dgram"]: + prev["prev_dgram"] = np.zeros([L,L,15]) + else: + prev["prev_pos"] = np.zeros([L,37,3]) if a["use_initial_atom_pos"]: if "batch" in self._inputs: @@ -198,7 +202,7 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): grad.append(jax.tree_map(lambda x:x*m, aux["grad"])) self._inputs["prev"] = aux["prev"] if a["use_initial_atom_pos"]: - self._inputs["initial_atom_pos"] = aux["prev"]["prev_pos"] + self._inputs["initial_atom_pos"] = aux["atom_positions"] aux["grad"] = jax.tree_map(lambda *x: np.stack(x).sum(0), *grad) diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index 834c98a3..f607883e 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -37,7 +37,7 @@ def __init__(self, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "optimizer":"sgd", "best_metric":"loss", "traj_iter":1, "traj_max":10000, - "clear_prev": True, "use_dgram":False, + "clear_prev": True, "use_prev_dgram":False, "shuffle_first":True, "use_remat":True, "alphabet_size":20, "use_initial_guess":False, "use_initial_atom_pos":False} @@ -101,7 +101,7 @@ def __init__(self, num_recycles = self.opt["num_recycles"] self._cfg.model.num_recycle = num_recycles self._cfg.model.global_config.use_remat = self._args["use_remat"] - self._cfg.model.global_config.use_dgram = self._args["use_dgram"] + self._cfg.model.global_config.use_prev_dgram = self._args["use_prev_dgram"] self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"] # load model_params diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py index ee5079e2..a968c194 100644 --- a/colabdesign/af/utils.py +++ b/colabdesign/af/utils.py @@ -186,4 +186,27 @@ def plot_current_pdb(self, show_sidechains=False, show_mainchains=False, - color=["pLDDT","chain","rainbow"] ''' self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, - color_HP=color_HP, size=size, animate=animate, get_best=False) \ No newline at end of file + color_HP=color_HP, size=size, animate=animate, get_best=False) + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + """Compute distogram from amino acid positions. + Arguments: + positions: [N_res, 3] Position coordinates. + num_bins: The number of bins in the distogram. + min_bin: The left edge of the first bin. + max_bin: The left edge of the final bin. The final bin catches + everything larger than `max_bin`. + Returns: + Distogram with the specified number of bins. 
+ """ + def squared_difference(x, y): + return np.square(x - y) + lower_breaks = np.linspace(min_bin, max_bin, num_bins) + lower_breaks = np.square(lower_breaks) + upper_breaks = np.concatenate([lower_breaks[1:],np.array([1e8])], axis=-1) + dist2 = np.sum( + squared_difference( + np.expand_dims(positions, axis=-2), + np.expand_dims(positions, axis=-3)), + axis=-1, keepdims=True) + return (dist2 > lower_breaks) * (dist2 < upper_breaks) From 10bb8b59035f1685d79dc6a3e7510ca8ca4bf195 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Wed, 24 May 2023 07:53:56 -0400 Subject: [PATCH 2/5] Update design.py cleanup dgram code --- colabdesign/af/alphafold/model/config.py | 6 +-- colabdesign/af/alphafold/model/modules.py | 22 +++++----- .../af/alphafold/model/modules_multimer.py | 26 ++++++------ colabdesign/af/design.py | 42 ++++++++++++------- colabdesign/af/model.py | 4 +- colabdesign/af/utils.py | 36 +++++++--------- setup.py | 2 +- 7 files changed, 71 insertions(+), 67 deletions(-) diff --git a/colabdesign/af/alphafold/model/config.py b/colabdesign/af/alphafold/model/config.py index 6af23762..32f001e7 100644 --- a/colabdesign/af/alphafold/model/config.py +++ b/colabdesign/af/alphafold/model/config.py @@ -308,8 +308,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'subbatch_size': 4, 'use_remat': False, 'zero_init': True, - 'use_dgram': False, - 'use_prev_dgram': False + 'use_dgram_pred': False, }, 'heads': { 'distogram': { @@ -538,8 +537,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'subbatch_size': 4, 'use_remat': False, 'zero_init': True, - 'use_dgram': False, - 'use_prev_dgram': False + 'use_dgram_pred': False, }, 'heads': { 'distogram': { diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py index 876416f2..f1562f3a 100644 --- a/colabdesign/af/alphafold/model/modules.py +++ b/colabdesign/af/alphafold/model/modules.py @@ -150,20 +150,20 @@ def __init__(self, config, name='alphafold'): def __call__(self, batch, **kwargs): """Run the AlphaFold model.""" impl = AlphaFoldIteration(self.config, self.global_config) - - def get_prev(ret): + def get_prev(ret, use_dgram=False): new_prev = { 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], } - if self.global_config.use_dgram: - dgram = jax.nn.softmax(ret["distogram"]["logits"]) - dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) - new_prev['prev_dgram'] = dgram @ dgram_map - elif self.global_config.use_prev_dgram: - pos = ret['structure_module']['final_atom_positions'] - prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None) - new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + if use_dgram: + if self.global_config.use_dgram_pred: + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + else: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) else: new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] @@ -171,7 +171,7 @@ def get_prev(ret): prev = batch.pop("prev") ret = impl(batch={**batch, **prev}) - ret["prev"] = get_prev(ret) + ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev) return ret 
class TemplatePairStack(hk.Module): diff --git a/colabdesign/af/alphafold/model/modules_multimer.py b/colabdesign/af/alphafold/model/modules_multimer.py index 6d16ea6b..88cd436f 100644 --- a/colabdesign/af/alphafold/model/modules_multimer.py +++ b/colabdesign/af/alphafold/model/modules_multimer.py @@ -178,19 +178,20 @@ def __call__( assert isinstance(batch, dict) num_res = batch['aatype'].shape[0] - def get_prev(ret): + def get_prev(ret, use_dgram=False): new_prev = { 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], } - if self.global_config.use_dgram: - dgram = jax.nn.softmax(ret["distogram"]["logits"]) - dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) - new_prev['prev_dgram'] = dgram @ dgram_map - elif self.global_config.use_prev_dgram: - pos = ret['structure_module']['final_atom_positions'] - prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None) - new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + if use_dgram: + if self.global_config.use_dgram_pred: + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + else: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) else: new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] @@ -201,9 +202,10 @@ def apply_network(prev, safe_key): return impl( batch=recycled_batch, safe_key=safe_key) - - ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) - ret["prev"] = get_prev(ret) + + prev = batch.pop("prev") + ret = apply_network(prev=prev, safe_key=safe_key) + ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev) if not return_representations: del ret['representations'] diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index aaf0582a..88367e8e 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -159,27 +159,37 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): else: L = self._inputs["residue_index"].shape[0] - # intialize previous + # intialize previous inputs if "prev" not in self._inputs or a["clear_prev"]: prev = {'prev_msa_first_row': np.zeros([L,256]), 'prev_pair': np.zeros([L,L,128])} - if a["use_initial_guess"] and "batch" in self._inputs: - if a["use_prev_dgram"]: - prev["prev_dgram"] = dgram_from_positions( - self._inputs["batch"]["all_atom_positions"], - min_bin=3.25, max_bin=20.75, num_bins=15) - else: - prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"] + + # initialize coordinates + # TODO: add support for the 'partial' protocol + if "batch" in self._inputs: + ini_seq = self._inputs["batch"]["aatype"] + ini_pos = self._inputs["batch"]["all_atom_positions"] + + # via evoformer + if a["use_initial_guess"]: + # via distogram or positions + if a["use_dgram"] or a["use_dgram_pred"]: + prev["prev_dgram"] = dgram_from_positions(positions=ini_pos, + seq=ini_seq, num_bins=15, min_bin=3.25, max_bin=20.75) + else: + prev["prev_pos"] = ini_pos + + # via structure module + if a["use_initial_atom_pos"]: + self._inputs["initial_atom_pos"] = ini_pos + else: - if a["use_prev_dgram"]: + # if batch not defined, initialize with zeros + if a["use_dgram"] or 
a["use_dgram_pred"]: prev["prev_dgram"] = np.zeros([L,L,15]) else: - prev["prev_pos"] = np.zeros([L,37,3]) - - if a["use_initial_atom_pos"]: - if "batch" in self._inputs: - self._inputs["initial_atom_pos"] = self._inputs["batch"]["all_atom_positions"] - else: + prev["prev_pos"] = np.zeros([L,37,3]) + if a["use_initial_atom_pos"]: self._inputs["initial_atom_pos"] = np.zeros([L,37,3]) self._inputs["prev"] = prev @@ -200,6 +210,8 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): else: aux = self._single(model_params, backprop) grad.append(jax.tree_map(lambda x:x*m, aux["grad"])) + + # update previous inputs self._inputs["prev"] = aux["prev"] if a["use_initial_atom_pos"]: self._inputs["initial_atom_pos"] = aux["atom_positions"] diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index f607883e..993c8fd8 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -37,7 +37,7 @@ def __init__(self, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "optimizer":"sgd", "best_metric":"loss", "traj_iter":1, "traj_max":10000, - "clear_prev": True, "use_prev_dgram":False, + "clear_prev": True, "use_dgram":False, "use_dgram_pred":False, "shuffle_first":True, "use_remat":True, "alphabet_size":20, "use_initial_guess":False, "use_initial_atom_pos":False} @@ -101,7 +101,7 @@ def __init__(self, num_recycles = self.opt["num_recycles"] self._cfg.model.num_recycle = num_recycles self._cfg.model.global_config.use_remat = self._args["use_remat"] - self._cfg.model.global_config.use_prev_dgram = self._args["use_prev_dgram"] + self._cfg.model.global_config.use_dgram_pred = self._args["use_dgram_pred"] self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"] # load model_params diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py index a968c194..504d3248 100644 --- a/colabdesign/af/utils.py +++ b/colabdesign/af/utils.py @@ -8,7 +8,7 @@ from colabdesign.shared.utils import update_dict, Key from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb from colabdesign.shared.protein import renum_pdb_str -from colabdesign.af.alphafold.common import protein +from colabdesign.af.alphafold.common import protein, residue_constants #################################################### # AF_UTILS - various utils (save, plot, etc) @@ -188,25 +188,17 @@ def plot_current_pdb(self, show_sidechains=False, show_mainchains=False, self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, color_HP=color_HP, size=size, animate=animate, get_best=False) -def dgram_from_positions(positions, num_bins, min_bin, max_bin): - """Compute distogram from amino acid positions. - Arguments: - positions: [N_res, 3] Position coordinates. - num_bins: The number of bins in the distogram. - min_bin: The left edge of the first bin. - max_bin: The left edge of the final bin. The final bin catches - everything larger than `max_bin`. - Returns: - Distogram with the specified number of bins. 
- """ - def squared_difference(x, y): - return np.square(x - y) +def dgram_from_positions(positions, seq=None, num_bins=39, min_bin=3.25, max_bin=50.75): + if seq is None: + atoms = {k:positions[...,residue_constants.atom_order[k],:] for k in ["N","CA","C"]} + c = _np_get_cb(**atoms, use_jax=False) + else: + ca = positions[...,residue_constants.atom_order["CA"],:] + cb = positions[...,residue_constants.atom_order["CB"],:] + is_gly = seq==residue_constants.restype_order["G"] + c = np.where(is_gly[:,None],ca,cb) + dist = np.sqrt(np.square(c[None,:] - c[:,None]).sum(-1,keepdims=True)) lower_breaks = np.linspace(min_bin, max_bin, num_bins) - lower_breaks = np.square(lower_breaks) - upper_breaks = np.concatenate([lower_breaks[1:],np.array([1e8])], axis=-1) - dist2 = np.sum( - squared_difference( - np.expand_dims(positions, axis=-2), - np.expand_dims(positions, axis=-3)), - axis=-1, keepdims=True) - return (dist2 > lower_breaks) * (dist2 < upper_breaks) + lower_breaks = lower_breaks + upper_breaks = np.append(lower_breaks[1:],1e8) + return ((dist > lower_breaks) * (dist < upper_breaks)).astype(float) diff --git a/setup.py b/setup.py index 5f692ef9..644ec673 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( name='colabdesign', - version='1.1.2-beta0', + version='1.1.2-beta1', description='Making Protein Design accessible to all via Google Colab!', long_description="Making Protein Design accessible to all via Google Colab!", long_description_content_type='text/markdown', From d618ea3ff86890cab3fa5be487d3a4c82b0dee34 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Thu, 1 Jun 2023 15:00:00 -0400 Subject: [PATCH 3/5] add option to mask target features with mlm --- colabdesign/af/design.py | 9 +++++++-- colabdesign/af/inputs.py | 4 +++- colabdesign/af/model.py | 5 +++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index 88367e8e..381a192b 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -177,8 +177,13 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): prev["prev_dgram"] = dgram_from_positions(positions=ini_pos, seq=ini_seq, num_bins=15, min_bin=3.25, max_bin=20.75) else: - prev["prev_pos"] = ini_pos - + prev["prev_pos"] = ini_pos + else: + if a["use_dgram"] or a["use_dgram_pred"]: + prev["prev_dgram"] = np.zeros([L,L,15]) + else: + prev["prev_pos"] = np.zeros([L,37,3]) + # via structure module if a["use_initial_atom_pos"]: self._inputs["initial_atom_pos"] = ini_pos diff --git a/colabdesign/af/inputs.py b/colabdesign/af/inputs.py index 36fdc5d8..b81dd85f 100644 --- a/colabdesign/af/inputs.py +++ b/colabdesign/af/inputs.py @@ -106,7 +106,7 @@ def _update_template(self, inputs, key): inputs[k] = inputs[k].at[...,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][...,5:])) inputs[k] = jnp.where(rm[:,None],0,inputs[k]) -def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None): +def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None, mask_target=False): '''update the sequence features''' if seq_1hot is None: seq_1hot = seq["pseudo"] if seq_pssm is None: seq_pssm = seq["pseudo"] @@ -122,6 +122,8 @@ def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None): Y = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X) msa_feat = jnp.where(mlm[...,None],Y,msa_feat) seq["pseudo"] = jnp.where(mlm[...,None],X[:seq["pseudo"].shape[-1]],seq["pseudo"]) + if mask_target: + target_feat = 
jnp.where(mlm[0,:,None],0,target_feat) inputs.update({"msa_feat":msa_feat, "target_feat":target_feat}) diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index 993c8fd8..10eec10a 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -33,7 +33,7 @@ def __init__(self, self.protocol = protocol self._num = kwargs.pop("num_seq",1) self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True, - "recycle_mode":"last", "use_mlm": False, "realign": True, + "recycle_mode":"last", "use_mlm": False, "mask_target":False, "realign": True, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "optimizer":"sgd", "best_metric":"loss", "traj_iter":1, "traj_max":10000, @@ -159,7 +159,7 @@ def _model(params, model_params, inputs, key): if a["use_mlm"]: shape = seq["pseudo"].shape[:2] mlm = jax.random.bernoulli(key(),opt["mlm_dropout"],shape) - update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm) + update_seq(seq, inputs, seq_pssm=pssm, mlm=mlm, mask_target=a["mask_target"]) else: update_seq(seq, inputs, seq_pssm=pssm) @@ -221,6 +221,7 @@ def _model(params, model_params, inputs, key): # experimental masked-language-modeling if a["use_mlm"]: aux["mlm"] = outputs["masked_msa"]["logits"] + aux["mlm_mask"] = mlm mask = jnp.where(inputs["seq_mask"],mlm,0) aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"])) From 986d08f58983190af00ce4e34cf818260b589391 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Mon, 5 Jun 2023 12:33:04 -0400 Subject: [PATCH 4/5] add option to unbias mlm --- colabdesign/af/loss.py | 5 ++++- colabdesign/af/model.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py index b0b463be..0100f30f 100644 --- a/colabdesign/af/loss.py +++ b/colabdesign/af/loss.py @@ -541,8 +541,11 @@ def get_seq_ent_loss(inputs): ent = (ent * mask).sum() / (mask.sum() + 1e-8) return {"seq_ent":ent.mean()} -def get_mlm_loss(outputs, mask, truth=None): +def get_mlm_loss(outputs, mask, truth=None, unbias=False): x = outputs["masked_msa"]["logits"][...,:20] + if unbias: + x_mean = (x * mask[...,None]).sum((0,1)) / (mask.sum() + 1e-8) + x = x - x_mean if truth is None: truth = jax.nn.softmax(x) ent = -(truth[...,:20] * jax.nn.log_softmax(x)).sum(-1) ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8) diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index 10eec10a..8b6be168 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -33,7 +33,9 @@ def __init__(self, self.protocol = protocol self._num = kwargs.pop("num_seq",1) self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True, - "recycle_mode":"last", "use_mlm": False, "mask_target":False, "realign": True, + "recycle_mode":"last", + "use_mlm": False, "mask_target":False, "unbias_mlm": False, + "realign": True, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "optimizer":"sgd", "best_metric":"loss", "traj_iter":1, "traj_max":10000, @@ -223,7 +225,8 @@ def _model(params, model_params, inputs, key): aux["mlm"] = outputs["masked_msa"]["logits"] aux["mlm_mask"] = mlm mask = jnp.where(inputs["seq_mask"],mlm,0) - aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"])) + aux["losses"].update(get_mlm_loss(outputs, mask=mask, + truth=seq["pssm"], unbias=a["unbias_mlm"])) # run user defined callbacks for c in ["loss","post"]: From e875a9df90295363a9e5de7f57f47ee88601feec Mon Sep 17 00:00:00 2001 From: Sergey O Date: Tue, 6 Jun 2023 
11:21:04 -0400 Subject: [PATCH 5/5] Update modules.py --- colabdesign/af/alphafold/model/modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py index f1562f3a..d0d7c631 100644 --- a/colabdesign/af/alphafold/model/modules.py +++ b/colabdesign/af/alphafold/model/modules.py @@ -1391,6 +1391,7 @@ def __call__(self, batch, safe_key=None): msa_feat = batch['msa_feat'].astype(dtype) target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[1,1]]) preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_1d = jnp.where(target_feat.sum(-1,keepdims=True) == 0, 0, preprocess_1d) preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat) msa_activations = preprocess_1d[None] + preprocess_msa
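
Note on the recycling-distogram conversion in this series: when the predicted distogram is recycled (`use_dgram` in PATCH 1/5, renamed `use_dgram_pred` in PATCH 2/5), `get_prev()` feeds the next recycle a 15-bin `prev_dgram` derived from the 64-bin distogram head instead of atom positions. The sketch below is illustrative only and not part of the patches; the random logits are placeholder data standing in for `ret["distogram"]["logits"]`.

```python
import jax
import jax.numpy as jnp

L = 8  # example sequence length
logits = jax.random.normal(jax.random.PRNGKey(0), (L, L, 64))  # stand-in for ret["distogram"]["logits"]

# Probabilities over the 64 distogram-head bins.
dgram = jax.nn.softmax(logits)                                           # [L, L, 64]

# Merge groups of four consecutive 64-bin columns into the 15 recycling bins;
# column 0 of the mapping is zeroed, as in the patch.
dgram_map = jax.nn.one_hot(
    jnp.repeat(jnp.append(0, jnp.arange(15)), 4), 15).at[:, 0].set(0)    # [64, 15]

prev_dgram = dgram @ dgram_map                                           # [L, L, 15]
print(prev_dgram.shape)                                                  # (8, 8, 15)
```

The result has the same 15-channel shape as `dgram_from_positions(..., num_bins=15, min_bin=3.25, max_bin=20.75)`, so the embedder's `prev_pos_linear` input is identical whether `prev_dgram` comes from the predicted distogram or from the previous structure's coordinates.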