From 13f3e72a4a25c76942a6ff6526cb7ee0b1cd702c Mon Sep 17 00:00:00 2001 From: Sergey O Date: Tue, 25 Oct 2022 11:52:19 -0400 Subject: [PATCH 1/4] v1.1.0 (#94) - minor edits * Update README.md * Update README.md * typo * typo * Update README.md * adding initial code from Shihao * cleanup (#91) * mpnn: moving decoding_order logic from sample.py/score.py to model.py * mpnn: moving notebook logic to model.py * mpnn: Update proteinmpnn_in_jax.ipynb to avoid recompiling if shapes are the same * increasing default traj_max from 500 to 10000 * adding option to specify custom traj for animate() * Update model.py * adding pandas dataframes :D * typo * Update utils.py * adding option to display any design * Update proteinmpnn_in_jax.ipynb * removing broken code * minor edits to make mpnn standalone * minor edits to avoid recompiling mpnn when rm_aa is updated * cleanup; removing overly complicated option(s) * Update README.md * Update README.md * Update README.md * adding check for blank fix_pos * adding logic to skip alternative amino acids during pdb parsing * model._best moved to model._tmp["best"] * Update README.md --- af/examples/peptide_binder_design.ipynb | 2 +- colabdesign/af/design.py | 2 +- colabdesign/af/model.py | 2 +- colabdesign/af/prep.py | 4 +- colabdesign/af/utils.py | 16 +-- colabdesign/mpnn/model.py | 44 ++++--- colabdesign/mpnn/score.py | 2 +- colabdesign/shared/protein.py | 13 +- esm_msa/README.md | 2 +- mpnn/README.md | 15 ++- mpnn/examples/proteinmpnn_in_jax.ipynb | 161 +++++++++++++----------- 11 files changed, 161 insertions(+), 102 deletions(-) diff --git a/af/examples/peptide_binder_design.ipynb b/af/examples/peptide_binder_design.ipynb index 63fa9d13..40ae1c2c 100644 --- a/af/examples/peptide_binder_design.ipynb +++ b/af/examples/peptide_binder_design.ipynb @@ -296,7 +296,7 @@ "cell_type": "code", "source": [ "# log\n", - "model._best[\"aux\"][\"log\"]" + "model._tmp[\"best\"][\"aux\"][\"log\"]" ], "metadata": { "id": "1SGmdJKLNKvb" diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index 00a60980..07127587 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -245,7 +245,7 @@ def _save_results(self, aux=None, save_best=False, if best_metric is None: best_metric = self._args["best_metric"] metric = float(aux["log"][best_metric]) - if self._args["best_metric"] in ["plddt","ptm","i_ptm","seqid"] or metric_higher_better: + if self._args["best_metric"] in ["plddt","ptm","i_ptm","seqid","composite"] or metric_higher_better: metric = -metric if "metric" not in self._tmp["best"] or metric < self._tmp["best"]["metric"]: self._tmp["best"]["aux"] = aux diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index 9769cb07..96a42ec6 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -30,7 +30,7 @@ def __init__(self, protocol="fixbb", num_seq=1, use_multimer=False, use_mlm=False, pre_callback=None, post_callback=None, pre_design_callback=None, post_design_callback=None, - loss_callback=None, traj_iter=1, traj_max=500, debug=False, data_dir="."): + loss_callback=None, traj_iter=1, traj_max=10000, debug=False, data_dir="."): assert protocol in ["fixbb","hallucination","binder","partial"] assert recycle_mode in ["average","first","last","sample","add_prev","backprop"] diff --git a/colabdesign/af/prep.py b/colabdesign/af/prep.py index 1129e77e..0049fb0e 100644 --- a/colabdesign/af/prep.py +++ b/colabdesign/af/prep.py @@ -69,7 +69,7 @@ def _prep_fixbb(self, pdb_filename, chain=None, res_idx = self._pdb["residue_index"] # get [pos]itions of interests - if fix_pos is not None: + if fix_pos is not None and fix_pos != "": self._pos_info = prep_pos(fix_pos, **self._pdb["idx"]) self.opt["fix_pos"] = self._pos_info["pos"] @@ -336,7 +336,7 @@ def _prep_partial(self, pdb_filename, chain=None, length=None, self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) self._wt_aatype_sub = self._wt_aatype - elif fix_pos is not None: + elif fix_pos is not None and fix_pos != "": sub_fix_pos = [] sub_i = [] pos = self.opt["pos"].tolist() diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py index dc333c10..c73a2c14 100644 --- a/colabdesign/af/utils.py +++ b/colabdesign/af/utils.py @@ -95,7 +95,7 @@ def to_pdb_str(x, n=None): #------------------------------------- # plotting functions #------------------------------------- - def animate(self, s=0, e=None, dpi=100, get_best=True, aux=None, color_by="plddt"): + def animate(self, s=0, e=None, dpi=100, get_best=True, traj=None, aux=None, color_by="plddt"): ''' animate the trajectory - use [s]tart and [e]nd to define range to be animated @@ -104,27 +104,29 @@ def animate(self, s=0, e=None, dpi=100, get_best=True, aux=None, color_by="plddt ''' if aux is None: aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux - aux = aux["all"] - + aux = aux["all"] if self.protocol in ["fixbb","binder"]: pos_ref = self._inputs["batch"]["all_atom_positions"][:,1].copy() pos_ref[(pos_ref == 0).any(-1)] = np.nan else: pos_ref = aux["atom_positions"][0,:,1,:] - sub_traj = {k:v[s:e] for k,v in self._tmp["traj"].items()} - + + if traj is None: traj = self._tmp["traj"] + sub_traj = {k:v[s:e] for k,v in traj.items()} + align_xyz = self.protocol == "hallucination" return make_animation(**sub_traj, pos_ref=pos_ref, length=self._lengths, color_by=color_by, align_xyz=align_xyz, dpi=dpi) def plot_pdb(self, show_sidechains=False, show_mainchains=False, color="pLDDT", color_HP=False, size=(800,480), animate=False, - get_best=True, aux=None): + get_best=True, aux=None, pdb_str=None): ''' use py3Dmol to plot pdb coordinates - color=["pLDDT","chain","rainbow"] ''' - pdb_str = self.save_pdb(get_best=get_best, aux=aux) + if pdb_str is None: + pdb_str = self.save_pdb(get_best=get_best, aux=aux) view = show_pdb(pdb_str, show_sidechains=show_sidechains, show_mainchains=show_mainchains, diff --git a/colabdesign/mpnn/model.py b/colabdesign/mpnn/model.py index eeb152c5..820ac9fc 100644 --- a/colabdesign/mpnn/model.py +++ b/colabdesign/mpnn/model.py @@ -12,15 +12,16 @@ from colabdesign.shared.prep import prep_pos from colabdesign.shared.utils import Key, copy_dict -from colabdesign.shared.model import design_model, soft_seq # borrow some stuff from AfDesign -from colabdesign.af.prep import prep_pdb, order_aa +from colabdesign.af.prep import prep_pdb from colabdesign.af.alphafold.common import protein, residue_constants +aa_order = residue_constants.restype_order +order_aa = {b:a for a,b in aa_order.items()} from scipy.special import softmax, log_softmax -class mk_mpnn_model(design_model): +class mk_mpnn_model(): def __init__(self, model_name="v_48_020", backbone_noise=0.0, dropout=0.0, seed=None, verbose=False): @@ -43,7 +44,6 @@ def __init__(self, model_name="v_48_020", self.set_seed(seed) self._num = 1 - self._params = {} self._inputs = {} self._tied_lengths = False @@ -56,21 +56,25 @@ def prep_inputs(self, pdb_filename=None, chain=None, homooligomer=False, atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"]) chain_idx = np.concatenate([[n]*l for n,l in enumerate(pdb["lengths"])]) self._lengths = pdb["lengths"] - self._len = sum(self._lengths) + L = sum(self._lengths) self._inputs = {"X": pdb["batch"]["all_atom_positions"][:,atom_idx], "mask": pdb["batch"]["all_atom_mask"][:,1], "S": pdb["batch"]["aatype"], "residue_idx": pdb["residue_index"], "chain_idx": chain_idx, - "lengths": np.array(self._lengths)} + "lengths": np.array(self._lengths), + "bias": np.zeros((L,20))} - self.set_seq(self._inputs["S"], rm_aa=rm_aa) + + if rm_aa is not None: + for aa in rm_aa.split(","): + self._inputs["bias"][...,aa_order[aa]] -= 1e6 if fix_pos is not None: p = prep_pos(fix_pos, **pdb["idx"])["pos"] if inverse: - p = np.delete(np.arange(self._len),p) + p = np.delete(np.arange(L),p) self._inputs["fix_pos"] = p self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20] @@ -78,6 +82,9 @@ def prep_inputs(self, pdb_filename=None, chain=None, homooligomer=False, assert min(self._lengths) == max(self._lengths) self._tied_lengths = True self._len = self._lengths[0] + else: + self._tied_lengths = False + self._len = sum(self._lengths) self.pdb = pdb @@ -112,6 +119,8 @@ def get_af_inputs(self, af): if af._args["homooligomer"]: assert min(self._lengths) == max(self._lengths) self._tied_lengths = True + else: + self._tied_lengths = False def sample(self, num=1, batch=1, temperature=0.1, rescore=False, **kwargs): '''sample sequence''' @@ -140,9 +149,10 @@ def _get_seq(self, O): ''' one_hot to amino acid sequence ''' def split_seq(seq): if len(self._lengths) > 1: - return "".join(np.insert(list(seq),np.cumsum(self._lengths[:-1]),"/")) - else: - return seq + seq = "".join(np.insert(list(seq),np.cumsum(self._lengths[:-1]),"/")) + if self._tied_lengths: + seq = seq.split("/")[0] + return seq seqs, S = [], O["S"].argmax(-1) if S.ndim == 1: S = [S] for s in S: @@ -177,8 +187,9 @@ def score(self, seq=None, **kwargs): '''score sequence''' I = copy_dict(self._inputs) if seq is not None: - self.set_seq(seq) - I["S"] = self._params["seq"][0] + if self._tied_lengths and len(seq) == self._lengths[0]: + seq = seq * len(self._lengths) + I["S"] = np.array([aa_order.get(aa,-1) for aa in seq]) I.update(kwargs) key = I.pop("key",self.key()) O = jax.tree_map(np.array, self._score(**I, key=key)) @@ -190,9 +201,14 @@ def get_logits(self, **kwargs): return self.score(**kwargs)["logits"] def get_unconditional_logits(self, **kwargs): - kwargs["decoding_order"] = np.full(self._len,-1) + L = self._inputs["X"].shape[0] + kwargs["ar_mask"] = np.zeros((L,L)) return self.score(**kwargs)["logits"] + def set_seed(self, seed=None): + np.random.seed(seed=seed) + self.key = Key(seed=seed).get + def _setup(self): def _score(X, mask, residue_idx, chain_idx, key, **kwargs): I = {'X': X, diff --git a/colabdesign/mpnn/score.py b/colabdesign/mpnn/score.py index e26f24c7..48cd200e 100644 --- a/colabdesign/mpnn/score.py +++ b/colabdesign/mpnn/score.py @@ -25,7 +25,7 @@ def score(self, I): h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx) h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) - if I.get("S",None) is None: + if "S" not in I: ########################################## # unconditional_probs ########################################## diff --git a/colabdesign/shared/protein.py b/colabdesign/shared/protein.py index 1bef67c7..7431a2a1 100644 --- a/colabdesign/shared/protein.py +++ b/colabdesign/shared/protein.py @@ -24,6 +24,7 @@ def pdb_to_string(pdb_file): modres = {**MODRES} lines = [] + seen = [] for line in open(pdb_file,"rb"): line = line.decode("utf-8","ignore").rstrip() if line[:6] == "MODRES": @@ -36,7 +37,17 @@ def pdb_to_string(pdb_file): if k in modres: line = "ATOM "+line[6:17]+modres[k]+line[20:] if line[:4] == "ATOM": - lines.append(line) + chain = line[21:22] + atom = line[12:12+4].strip() + resi = line[17:17+3] + resn = line[22:22+5].strip() + if resn[-1].isalpha(): # alternative atom + resn = resn[:-1] + line = line[:26]+" "+line[27:] + key = f"{chain}_{resn}_{resi}_{atom}" + if key not in seen: # skip alternative placements + lines.append(line) + seen.append(key) return "\n".join(lines) def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1): diff --git a/esm_msa/README.md b/esm_msa/README.md index 133c96c9..d2949122 100644 --- a/esm_msa/README.md +++ b/esm_msa/README.md @@ -7,5 +7,5 @@ ### Contributors: - Shihao Feng [@JeffSHF](https://github.com/JeffSHF) -- Weikun.Wu [@guyujun](https://github.com/guyujun) (from [Levinthal.bio](http://levinthal.bio/en/)) +- Lin Xu - Facebook Research - [original Pytorch code](https://github.com/facebookresearch/esm) diff --git a/mpnn/README.md b/mpnn/README.md index e91db2ba..35e6fe9d 100644 --- a/mpnn/README.md +++ b/mpnn/README.md @@ -46,9 +46,18 @@ mpnn_model.prep_inputs(pdb_filename="tmp.pdb", chain="A,B", fix_pos="A") mpnn_model.prep_inputs(pdb_filename="tmp.pdb", rm_aa="C") ``` #### I want more control! -You can modify the bias matrix directly! The bias matrix is a (length, 20) matrix. Using large negative/positive values in the bias matrix is how we prevent certain amino acids from being sampled (rm_aa) and fix certain positions (fix_pos). For reference, the alphabet used: `ARNDCQEGHILKMFPSTWYV`. +You can modify the bias matrix directly! The bias matrix is a (length, 21) matrix. Using large negative/positive values in the bias matrix is how we prevent certain amino acids from being sampled (rm_aa) and fix certain positions (fix_pos). For reference, the alphabet used: `ARNDCQEGHILKMFPSTWYV`. + +For example, to add alanine bias to the first position, do: +```python +from colabdesign.mpnn.model import aa_order +mpnn_model.prep_inputs(pdb_filename="tmp.pdb") +mpnn_model._inputs["bias"][0,aa_order["A"]] = 1.0 +``` +For example, if you want to add a hydrophilic bias to all positions, you can do: ```python -mpnn_model._inputs["bias"][:,0] = 1e8 +for k in "DEHKNQRSTWY": + mpnn_model._inputs["bias"][:,aa_order[k]] = 1.39 ``` #### How about tied sampling for homo-oligomeric complexes? ```python @@ -70,7 +79,7 @@ for n,S in enumerate(samples["S"]): ``` ### Contributors: -- Sergey Ovchinnikov [@sokrypton](https://github.com/sokrypton) - Shihao Feng [@JeffSHF](https://github.com/JeffSHF) +- Sergey Ovchinnikov [@sokrypton](https://github.com/sokrypton) - Simon Kozlov [@sim0nsays](https://github.com/sim0nsays) - Justas Dauparas [@dauparas](https://github.com/dauparas) - [original pytorch code](https://github.com/dauparas/ProteinMPNN) diff --git a/mpnn/examples/proteinmpnn_in_jax.ipynb b/mpnn/examples/proteinmpnn_in_jax.ipynb index ba5cc820..8254146d 100644 --- a/mpnn/examples/proteinmpnn_in_jax.ipynb +++ b/mpnn/examples/proteinmpnn_in_jax.ipynb @@ -61,13 +61,21 @@ " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", "\n", "from colabdesign.mpnn import mk_mpnn_model, clear_mem\n", + "from colabdesign.shared.protein import pdb_to_string\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from IPython.display import HTML\n", + "import pandas as pd\n", + "import tqdm.notebook\n", + "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", + "\n", "from google.colab import files\n", + "from google.colab import data_table\n", + "data_table.enable_dataframe_formatter()\n", + "\n", "\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", @@ -127,16 +135,13 @@ "if rm_aa == \"\": rm_aa = None\n", "\n", "pdb_path = get_pdb(pdb)\n", - "\n", - "mpnn_args = [pdb_path, chains, homooligomer, fix_pos, inverse, rm_aa]\n", - "if \"mpnn_args_current\" not in dir() or mpnn_args != mpnn_args_current:\n", + "if \"mpnn_model\" not in dir():\n", " mpnn_model = mk_mpnn_model(model_name)\n", - " mpnn_model.prep_inputs(pdb_filename=pdb_path,\n", - " chain=chains, homooligomer=homooligomer,\n", - " fix_pos=fix_pos, inverse=inverse,\n", - " rm_aa=rm_aa, verbose=True)\n", - " mpnn_args_current = [x for x in mpnn_args]\n", "\n", + "mpnn_model.prep_inputs(pdb_filename=pdb_path,\n", + " chain=chains, homooligomer=homooligomer,\n", + " fix_pos=fix_pos, inverse=inverse,\n", + " rm_aa=rm_aa, verbose=True)\n", "out = mpnn_model.sample(num=num_seqs//32, batch=32,\n", " temperature=sampling_temp,\n", " rescore=homooligomer)\n", @@ -145,7 +150,12 @@ " for n in range(num_seqs):\n", " line = f'>score:{out[\"score\"][n]:.3f}_seqid:{out[\"seqid\"][n]:.3f}\\n{out[\"seq\"][n]}'\n", " fasta.write(line+\"\\n\")\n", - " print(line)" + "\n", + "labels = [\"score\",\"seqid\",\"seq\"]\n", + "data = [[out[k][n] for k in labels] for n in range(num_seqs)]\n", + "\n", + "df = pd.DataFrame(data, columns=labels)\n", + "data_table.DataTable(df.round(3))" ], "metadata": { "cellView": "form", @@ -216,15 +226,8 @@ "num_models = 1 #@param [\"1\",\"2\",\"3\",\"4\",\"5\"] {type:\"raw\"}\n", "num_recycles = 1 #@param [\"0\",\"1\",\"2\",\"3\"] {type:\"raw\"}\n", "use_multimer = False #@param {type:\"boolean\"}\n", - "#@markdown ###AF2Rank Options (WIP)\n", - "use_AF2Rank = False #@param {type:\"boolean\"}\n", - "#@markdown - AF2Rank uses native structure as input template and assess the \n", - "#@markdown agreement between sequence and structure using AlphaFold's confidence metrics.\n", - "#@markdown - The \"composite\" metric is defined as pLDDT * pTMscore. (WIP: TMscore between input/output not yet implemented.)\n", + "use_templates = False #@param {type:\"boolean\"}\n", "rm_template_interchain = False #@param {type:\"boolean\"}\n", - "#@markdown - Remove interface template info. (Recommended for evaluating redesigned interfaces).\n", - "constrain_fix_pos = False #@param {type:\"boolean\"}\n", - "#@markdown - constrain fixed position (aka do not remove template sequence/sidechain on for fixed positions)\n", "if not os.path.isdir(\"params\"):\n", " os.system(\"mkdir params\")\n", " os.system(\"curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\")\n", @@ -235,46 +238,49 @@ "\n", "from colabdesign.af import mk_af_model\n", "af_args = [pdb_path, chains, homooligomer,\n", - " use_multimer, use_AF2Rank]\n", - "\n", - "if \"af_args_current\" not in dir() or af_args != af_arg_current:\n", + " use_multimer, use_templates]\n", + "if \"af_arg_current\" not in dir() or af_args != af_arg_current:\n", " af_model = mk_af_model(use_multimer=use_multimer,\n", - " use_templates=use_AF2Rank,\n", + " use_templates=use_templates,\n", " best_metric=\"dgram_cce\")\n", " af_model.prep_inputs(pdb_path,chains,homooligomer=homooligomer)\n", " af_arg_current = [x for x in af_args]\n", "\n", "af_model.restart()\n", - "if use_AF2Rank:\n", - " af_model.set_opt(\"template\", rm_ic=rm_template_interchain)\n", - " if constrain_fix_pos and \"fix_pos\" in mpnn_model._inputs:\n", - " p = mpnn_model._inputs[\"fix_pos\"]\n", - " af_model._inputs[\"rm_template_seq\"][p] = False\n", - " af_model._inputs[\"rm_template_sc\"][p] = False\n", - " else:\n", - " af_model._inputs[\"rm_template_seq\"][:] = True\n", - " af_model._inputs[\"rm_template_sc\"][:] = True\n", - "\n", - "for S in out[\"S\"]:\n", - " seq = S[:af_model._len].argmax(-1)\n", - " af_model.predict(seq=seq,\n", - " num_recycles=num_recycles,\n", - " num_models=num_models,\n", - " verbose=False)\n", - " (rmsd, ptm, plddt) = (af_model.aux[\"log\"][k] for k in [\"rmsd\",\"ptm\",\"plddt\"])\n", - " if use_AF2Rank:\n", - " af_model.aux[\"log\"][\"composite\"] = ptm * plddt\n", - " af_model._save_results(save_best=True,\n", - " best_metric=\"composite\",\n", - " metric_higher_better=True)\n", - " else:\n", - " af_model._save_results(save_best=True)\n", + "af_model.set_opt(\"template\", rm_ic=rm_template_interchain)\n", "\n", - " af_model._k += 1\n", - " af_model.save_current_pdb(f\"all_pdb/ptm{ptm:.3f}_plddt{plddt:.3f}_rmsd{rmsd:.3f}_n{af_model._k}.pdb\")\n", + "with tqdm.notebook.tqdm(total=out[\"S\"].shape[0], bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " for n,S in enumerate(out[\"S\"]):\n", + " seq = S[:af_model._len].argmax(-1)\n", + " af_model.predict(seq=seq,\n", + " num_recycles=num_recycles,\n", + " num_models=num_models,\n", + " verbose=False)\n", + " (rmsd, ptm, plddt) = (af_model.aux[\"log\"][k] for k in [\"rmsd\",\"ptm\",\"plddt\"])\n", + " af_model.aux[\"log\"][\"composite\"] = ptm * plddt\n", + " af_model._save_results(save_best=True, verbose=False)\n", + " af_model.save_current_pdb(f\"all_pdb/n{n}.pdb\")\n", + " af_model._k += 1\n", + " pbar.update(1)\n", "\n", "af_model.save_pdb(f\"best.pdb\")\n", - "#@markdown Note: designed pdbs are saved to `all_pdb/`\n" + "\n", + "data = []\n", + "labels = [\"dgram_cce\",\"plddt\",\"ptm\",\"i_ptm\",\"rmsd\",\"composite\",\"mpnn\",\"seqid\",\"seq\"]\n", + "for n,af in enumerate(af_model._tmp[\"log\"]):\n", + " data.append([af[\"dgram_cce\"],\n", + " af[\"plddt\"],\n", + " af[\"ptm\"],\n", + " af[\"i_ptm\"],\n", + " af[\"rmsd\"],\n", + " af[\"composite\"],\n", + " out[\"score\"][n],\n", + " out[\"seqid\"][n],\n", + " out[\"seq\"][n]])\n", + "\n", + "df = pd.DataFrame(data, columns=labels)\n", + "data_table.DataTable(df.sort_values(\"dgram_cce\").round(3))\n", + "#@markdown Note: designed pdbs are saved to `all_pdb/`" ], "metadata": { "cellView": "form", @@ -286,31 +292,26 @@ { "cell_type": "code", "source": [ - "#@title animate\n", - "color_by = \"plddt\" #@param [\"chain\", \"plddt\", \"rainbow\"]\n", - "dpi = 100 #@param {type:\"integer\"}\n", - "HTML(af_model.animate(color_by=color_by, dpi=dpi))" - ], - "metadata": { - "cellView": "form", - "id": "74LAQHZGTZCH" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@title display best protein {run: \"auto\"}\n", + "#@title display protein (optional) {run: \"auto\"}\n", + "show_best = True #@param {type:\"boolean\"}\n", + "show_idx = 0 #@param {type:\"integer\"}\n", + "#@markdown - Enter index of protein to show, if `show_best` is disabled.\n", + "#@markdown - Note: these are NOT sorted and correspond to \n", + "#@markdown the index in pandas dataframe above.\n", "color = \"pLDDT\" #@param [\"chain\", \"pLDDT\", \"rainbow\"]\n", "show_sidechains = False #@param {type:\"boolean\"}\n", "show_mainchains = False #@param {type:\"boolean\"}\n", "color_HP = False #@param {type:\"boolean\"}\n", - "animate = False #@param {type:\"boolean\"}\n", + "animate = True #@param {type:\"boolean\"}\n", "#@markdown - if `num_models` > 1, will iterate through the models when `animate` is enabled.\n", + "if not show_best:\n", + " pdb_str = pdb_to_string(f\"all_pdb/n{show_idx}.pdb\")\n", + "else:\n", + " pdb_str = None\n", "af_model.plot_pdb(show_sidechains=show_sidechains,\n", " show_mainchains=show_mainchains,\n", - " color=color, color_HP=color_HP, animate=animate)" + " color=color, color_HP=color_HP,\n", + " animate=animate, pdb_str=pdb_str)" ], "metadata": { "cellView": "form", @@ -322,12 +323,32 @@ { "cell_type": "code", "source": [ - "# get stats about best sequence\n", - "print(af_model.get_seq())\n", - "af_model._tmp[\"best\"][\"aux\"][\"log\"]" + "#@title animate (optional)\n", + "#@markdown Note: animation frames are sorted worst to best design\n", + "def sort_traj(self, metric=\"dgram_cce\"):\n", + " if metric in [\"plddt\",\"ptm\",\"i_ptm\",\"seqid\",\"composite\"]:\n", + " metric_higher_better = True\n", + " else:\n", + " metric_higher_better = False\n", + " num = len(self._tmp[\"traj\"][\"seq\"])\n", + " log = self._tmp[\"log\"][-num:]\n", + " if metric in log[0]:\n", + " n = np.array([x[metric] for x in log]).argsort()\n", + " if metric_higher_better: n = n[::-1]\n", + " sub_traj = {k:[v[m] for m in n] for k,v in self._tmp[\"traj\"].items()}\n", + " return sub_traj\n", + " else:\n", + " return None\n", + "\n", + "sub_traj= sort_traj(af_model)\n", + "\n", + "color_by = \"plddt\" #@param [\"chain\", \"plddt\", \"rainbow\"]\n", + "dpi = 100 #@param {type:\"integer\"}\n", + "HTML(af_model.animate(traj={k:v[::-1] for k,v in sub_traj.items()}, color_by=color_by, dpi=dpi))\n" ], "metadata": { - "id": "ZVCC7jCQW66r" + "cellView": "form", + "id": "74LAQHZGTZCH" }, "execution_count": null, "outputs": [] From 63680fd431c088848fdd88dc403eccde24f8c442 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Wed, 16 Nov 2022 19:30:00 -0500 Subject: [PATCH 2/4] V1.1.0 (#101) * speedup alphafold weight download --- af/design.ipynb | 50 ++++++++++--------- af/examples/afdesign_hotspot_test.ipynb | 45 ++++++++--------- af/examples/disulfide_design.ipynb | 41 +++++++-------- af/examples/hallucination.ipynb | 43 ++++++++-------- af/examples/hallucination_custom_loss.ipynb | 41 +++++++-------- .../partial_hallucination_rewire.ipynb | 47 +++++++++-------- af/examples/peptide_binder_design.ipynb | 46 ++++++++--------- af/examples/use_esm_1b_bias.ipynb | 34 ++++++------- mpnn/README.md | 2 +- mpnn/examples/proteinmpnn_in_jax.ipynb | 48 ++++++++++++++---- 10 files changed, 203 insertions(+), 194 deletions(-) diff --git a/af/design.ipynb b/af/design.ipynb index a61195c0..06f6078d 100644 --- a/af/design.ipynb +++ b/af/design.ipynb @@ -33,29 +33,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", " # get code\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", " # for debugging\n", - " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", " # download params\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -65,16 +56,20 @@ "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { @@ -423,6 +418,15 @@ }, "execution_count": null, "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "rTGKbhsI0t8k" + }, + "execution_count": null, + "outputs": [] } ], "metadata": { diff --git a/af/examples/afdesign_hotspot_test.ipynb b/af/examples/afdesign_hotspot_test.ipynb index 47522fca..4415995c 100644 --- a/af/examples/afdesign_hotspot_test.ipynb +++ b/af/examples/afdesign_hotspot_test.ipynb @@ -28,25 +28,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", + " # get code\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", + " # for debugging\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " # download params\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -56,17 +51,21 @@ "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" - ] + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" + ] }, { "cell_type": "code", diff --git a/af/examples/disulfide_design.ipynb b/af/examples/disulfide_design.ipynb index 96ce9b52..dd14f3ed 100644 --- a/af/examples/disulfide_design.ipynb +++ b/af/examples/disulfide_design.ipynb @@ -29,29 +29,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", " # get code\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", " # for debugging\n", - " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", " # download params\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -61,16 +52,20 @@ "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { diff --git a/af/examples/hallucination.ipynb b/af/examples/hallucination.ipynb index b02b635e..fc21fe43 100644 --- a/af/examples/hallucination.ipynb +++ b/af/examples/hallucination.ipynb @@ -33,25 +33,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "XVWobK8Tsgju" - }, - "outputs": [], - "source": [ - "#@title #import libraries\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", + " # get code\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", + " # for debugging\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " # download params\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -61,16 +56,20 @@ "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { diff --git a/af/examples/hallucination_custom_loss.ipynb b/af/examples/hallucination_custom_loss.ipynb index b4ab1882..654413b6 100644 --- a/af/examples/hallucination_custom_loss.ipynb +++ b/af/examples/hallucination_custom_loss.ipynb @@ -29,29 +29,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", " # get code\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", " # for debugging\n", - " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", " # download params\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -61,16 +52,20 @@ "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { diff --git a/af/examples/partial_hallucination_rewire.ipynb b/af/examples/partial_hallucination_rewire.ipynb index e657b20e..be77158a 100644 --- a/af/examples/partial_hallucination_rewire.ipynb +++ b/af/examples/partial_hallucination_rewire.ipynb @@ -28,44 +28,43 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + "#@title setup\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", + " # get code\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", + " # for debugging\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " # download params\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", - "import os, re\n", - "from colabdesign.af import mk_afdesign_model, clear_mem\n", + "import os\n", + "from colabdesign import mk_afdesign_model, clear_mem\n", "from IPython.display import HTML\n", "from google.colab import files\n", "import numpy as np\n", "\n", - "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { diff --git a/af/examples/peptide_binder_design.ipynb b/af/examples/peptide_binder_design.ipynb index 40ae1c2c..1b2ff67d 100644 --- a/af/examples/peptide_binder_design.ipynb +++ b/af/examples/peptide_binder_design.ipynb @@ -33,26 +33,19 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", - " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + "#@title **setup**\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", + " # get code\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", + " # for debugging\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " # download params\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -72,13 +65,14 @@ " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", + " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", + " return f\"{pdb_code}.pdb\"\n", " else:\n", - " if len(pdb_code) == 4:\n", - " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"\n", - " else:\n", - " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", - " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ] }, { diff --git a/af/examples/use_esm_1b_bias.ipynb b/af/examples/use_esm_1b_bias.ipynb index e5dc5e69..faa12ea3 100644 --- a/af/examples/use_esm_1b_bias.ipynb +++ b/af/examples/use_esm_1b_bias.ipynb @@ -195,25 +195,20 @@ }, "outputs": [], "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Vt7G_nbNeSQ3", - "cellView": "form" - }, - "outputs": [], - "source": [ - "#@title import libraries\n", + "#@title setup afdesign\n", + "%%time\n", + "import os\n", + "if not os.path.isdir(\"params\"):\n", + " # get code\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\")\n", + " # for debugging\n", + " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " # download params\n", + " os.system(\"mkdir params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", + "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", @@ -224,6 +219,7 @@ "import numpy as np\n", "\n" ] + }, { "cell_type": "code", diff --git a/mpnn/README.md b/mpnn/README.md index 35e6fe9d..64d802a9 100644 --- a/mpnn/README.md +++ b/mpnn/README.md @@ -57,7 +57,7 @@ mpnn_model._inputs["bias"][0,aa_order["A"]] = 1.0 For example, if you want to add a hydrophilic bias to all positions, you can do: ```python for k in "DEHKNQRSTWY": - mpnn_model._inputs["bias"][:,aa_order[k]] = 1.39 + mpnn_model._inputs["bias"][:,aa_order[k]] += 1.39 ``` #### How about tied sampling for homo-oligomeric complexes? ```python diff --git a/mpnn/examples/proteinmpnn_in_jax.ipynb b/mpnn/examples/proteinmpnn_in_jax.ipynb index 8254146d..d1751edd 100644 --- a/mpnn/examples/proteinmpnn_in_jax.ipynb +++ b/mpnn/examples/proteinmpnn_in_jax.ipynb @@ -76,16 +76,20 @@ "from google.colab import data_table\n", "data_table.enable_dataframe_formatter()\n", "\n", - "\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", - " else:\n", + " elif os.path.isfile(pdb_code):\n", + " return pdb_code\n", + " elif len(pdb_code) == 4:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" + " return f\"{pdb_code}.pdb\"\n", + " else:\n", + " os.system(f\"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb\")\n", + " return f\"AF-{pdb_code}-F1-model_v3.pdb\"" ], "metadata": { "cellView": "form", @@ -103,6 +107,8 @@ "import warnings, os, re\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", + "os.system(\"mkdir -p output\")\n", + "\n", "# USER OPTIONS\n", "#@markdown #### ProteinMPNN options\n", "model_name = \"v_48_020\" #@param [\"v_48_002\", \"v_48_010\", \"v_48_020\", \"v_48_030\"]\n", @@ -155,6 +161,7 @@ "data = [[out[k][n] for k in labels] for n in range(num_seqs)]\n", "\n", "df = pd.DataFrame(data, columns=labels)\n", + "df.to_csv('output/mpnn_results.csv')\n", "data_table.DataTable(df.round(3))" ], "metadata": { @@ -199,6 +206,7 @@ " pdb_labels = None\n", "\n", "pssm = softmax(logits,-1)\n", + "np.savetxt(\"output/pssm.txt\",pssm)\n", "\n", "fig = px.imshow(np.array(pssm).T,\n", " labels=dict(x=\"positions\", y=\"amino acids\", color=\"probability\"),\n", @@ -230,11 +238,15 @@ "rm_template_interchain = False #@param {type:\"boolean\"}\n", "if not os.path.isdir(\"params\"):\n", " os.system(\"mkdir params\")\n", - " os.system(\"curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\")\n", + " os.system(\"apt-get install aria2 -qq\")\n", + " os.system(\"aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\")\n", + " os.system(\"tar -xf alphafold_params_2022-03-02.tar -C params\")\n", "\n", "# where pdb files will be save:\n", - "if not os.path.isdir(\"all_pdb\"): os.system(\"mkdir all_pdb\")\n", - "else: os.system(\"rm all_pdb/*\")\n", + "if not os.path.isdir(\"output/all_pdb\"):\n", + " os.system(\"mkdir output/all_pdb\")\n", + "else:\n", + " os.system(\"rm output/all_pdb/*\")\n", "\n", "from colabdesign.af import mk_af_model\n", "af_args = [pdb_path, chains, homooligomer,\n", @@ -259,11 +271,11 @@ " (rmsd, ptm, plddt) = (af_model.aux[\"log\"][k] for k in [\"rmsd\",\"ptm\",\"plddt\"])\n", " af_model.aux[\"log\"][\"composite\"] = ptm * plddt\n", " af_model._save_results(save_best=True, verbose=False)\n", - " af_model.save_current_pdb(f\"all_pdb/n{n}.pdb\")\n", + " af_model.save_current_pdb(f\"output/all_pdb/n{n}.pdb\")\n", " af_model._k += 1\n", " pbar.update(1)\n", "\n", - "af_model.save_pdb(f\"best.pdb\")\n", + "af_model.save_pdb(f\"output/best.pdb\")\n", "\n", "data = []\n", "labels = [\"dgram_cce\",\"plddt\",\"ptm\",\"i_ptm\",\"rmsd\",\"composite\",\"mpnn\",\"seqid\",\"seq\"]\n", @@ -279,8 +291,9 @@ " out[\"seq\"][n]])\n", "\n", "df = pd.DataFrame(data, columns=labels)\n", + "df.to_csv('output/alphafold_results.csv')\n", "data_table.DataTable(df.sort_values(\"dgram_cce\").round(3))\n", - "#@markdown Note: designed pdbs are saved to `all_pdb/`" + "#@markdown Note: designed pdbs are saved to `output/all_pdb/`" ], "metadata": { "cellView": "form", @@ -289,6 +302,21 @@ "execution_count": null, "outputs": [] }, + { + "cell_type": "code", + "source": [ + "#@title download predictions (optional)\n", + "from google.colab import files\n", + "os.system(f\"zip -r output.zip output/\")\n", + "files.download(f'output.zip')" + ], + "metadata": { + "cellView": "form", + "id": "ZOtuzwwUAgHj" + }, + "execution_count": null, + "outputs": [] + }, { "cell_type": "code", "source": [ @@ -305,7 +333,7 @@ "animate = True #@param {type:\"boolean\"}\n", "#@markdown - if `num_models` > 1, will iterate through the models when `animate` is enabled.\n", "if not show_best:\n", - " pdb_str = pdb_to_string(f\"all_pdb/n{show_idx}.pdb\")\n", + " pdb_str = pdb_to_string(f\"output/all_pdb/n{show_idx}.pdb\")\n", "else:\n", " pdb_str = None\n", "af_model.plot_pdb(show_sidechains=show_sidechains,\n", From 6ee5f32deaae315fff9707ed2a3307f00a8e0a7f Mon Sep 17 00:00:00 2001 From: Sergey O Date: Thu, 17 Nov 2022 16:38:44 -0500 Subject: [PATCH 3/4] updating plddt and pae loss to be consistent --- colabdesign/af/loss.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py index dd43a430..feec6d38 100644 --- a/colabdesign/af/loss.py +++ b/colabdesign/af/loss.py @@ -229,14 +229,12 @@ def get_exp_res_loss(outputs, mask_1d=None): return mask_loss(p, mask_1d) def get_plddt_loss(outputs, mask_1d=None): - p = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) - p = (p * jnp.arange(p.shape[-1])[::-1]).mean(-1) + p = 1 - get_plddt(outputs) return mask_loss(p, mask_1d) def get_pae_loss(outputs, mask_1d=None, mask_1b=None, mask_2d=None): - p = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"]) - p = (p * jnp.arange(p.shape[-1])).mean(-1) - p = (p + p.T)/2 + p = 1 - (get_pae(outputs) / 31.0) + p = (p + p.T) / 2 L = p.shape[0] if mask_1d is None: mask_1d = jnp.ones(L) if mask_1b is None: mask_1b = jnp.ones(L) From 674a4302701080047f18f776df917f3132b0f373 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Thu, 17 Nov 2022 16:46:48 -0500 Subject: [PATCH 4/4] Update loss.py --- colabdesign/af/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py index feec6d38..75beed39 100644 --- a/colabdesign/af/loss.py +++ b/colabdesign/af/loss.py @@ -233,7 +233,7 @@ def get_plddt_loss(outputs, mask_1d=None): return mask_loss(p, mask_1d) def get_pae_loss(outputs, mask_1d=None, mask_1b=None, mask_2d=None): - p = 1 - (get_pae(outputs) / 31.0) + p = get_pae(outputs) / 31.0 p = (p + p.T) / 2 L = p.shape[0] if mask_1d is None: mask_1d = jnp.ones(L)