diff --git a/README.md b/README.md index 1d0752c..c88beaa 100644 --- a/README.md +++ b/README.md @@ -1 +1,90 @@ # pyRIF +Using Rotamer Interaction Fields from RIFGen/Dock in python + + +### Installation +`pip install pyRIF` + +### Pose Example +Moves your input pose into the region of the RIF +``` +import glob +import pyrosetta +from pyrosetta.rosetta.core.select.residue_selector import ChainSelector + +from pyrif import RotamerInteractionField + +pyrosetta.init() + + +# dictionary pointing to RIFGen outputs +L_AA_RIF = { + 'HDF5' : '/path/to/py_rif.h5', + 'rots' : '/path/to/rotamer_index_spec.txt', + 'target' : '/path/to/target.pdb.gz', +} + + +# residue selectors to select target and binder residues +target_selector = ChainSelector('A') +binder_selector = ChainSelector('B') + + +# create RIF object outside of loop +RIF = RotamerInteractionField( + L_AA_RIF_kwargs=L_AA_RIF, + residue_selector=binder_selector, + target_selector=target_selector, +) + +for PDB in glob.iglob('/path/to/pdbs/*.pdb'): + pose = pyrosetta.io.pose_from_pdb(PDB) + + # apply the RIF + STATUS, RIF_score, sequence_mapping = RIF.apply(pose) + + if STATUS: + print(f'pass, {RIF_SCORE}\n{sequence_mapping}') + # continue with remainder of protocol + else: + print('fail') + +``` + +### Numpy Example +Assumes your XYZs are already within the region of the RIF +``` +import glob +import numpy as np +import pyrosetta + +from pyrif import RotamerInteractionField +pyrosetta.init() + + +# dictionary pointing to RIFGen outputs +L_AA_RIF = { + 'HDF5' : '/path/to/py_rif.h5', + 'rots' : '/path/to/rotamer_index_spec.txt', + 'target' : '/path/to/target.pdb.gz', +} + + +# create RIF object outside of loop +RIF = RotamerInteractionField( + L_AA_RIF_kwargs=L_AA_RIF, +) + +binder_xyzs = np.random.rand(1000, 100, 3, 3)# [(1000 proteins), (100 residues), (N CA C), (X Y Z)] + +for i in range(binder_xyzs.shape[0]): + + STATUS, RIF_score, sequence_mapping = RIF.search_xyzs(binder_xyzs[i, :, :, :]) + + if STATUS: + print(f'pass, {RIF_SCORE}\n{sequence_mapping}') + # continue with remainder of protocol + else: + print('fail') + +``` diff --git a/example/example.py b/example/example.py index ce38052..819b989 100644 --- a/example/example.py +++ b/example/example.py @@ -85,8 +85,6 @@ if STATUS: print(f'rif score: {RIF_score}\n{sequence_mapping}') - -print('D_RIF ONLY') # RIF just with canonical D-AAs RIF = RotamerInteractionField( D_AA_RIF_kwargs=D_AA_RIF, diff --git a/pyRIF/pyRIF.py b/pyRIF/pyRIF.py index c51310d..722014d 100644 --- a/pyRIF/pyRIF.py +++ b/pyRIF/pyRIF.py @@ -28,9 +28,10 @@ class RotamerInteractionField(object): load_RIF_Table(py_rif.h5, rotamer_index_spec.txt) -> python dictionary of the RIF table RIF_stubs(Ns, CAs, Cs,) -> the (4,4) homogenous transforms of the corresponding cooridnates superimpose(target, moving) -> aligns moving onto target - get_hits(L_offsets, D_offsets) -> given L and D offsets, returns the score_list and irot list of hit. Run prepacking + get_irots(L_offsets, D_offsets) -> given L and D offsets, returns the score_list and irot list of hit. Run prepacking fast_pack() -> given lots of things, does voxel-based clash packing. returns a sequence map and total rif score - apply(pose) -> Returns MS_FAIL/SUCCEED if pose passes the minimum requirements + search_xyzs(xyzs) -> given numpy array of NCAC XYZs, returns True/False, RIF_score, sequence_mapping + apply(pose) -> given pose, gets XYZs, runs search_xyzs() ''' def __init__( self, @@ -135,7 +136,7 @@ def load_RIF_Table( line.strip().split('\t') ) - # initialize the things we need t obuild residues + # initialize the things we need to build residues chemical_manager = ChemicalManager.get_instance() ResidueTypeSet = chemical_manager.residue_type_set( 'fa_standard' ) @@ -170,7 +171,8 @@ def load_RIF_Table( # NeRF to get the DOFs frame_xyz = np.array(frame_xyzs)[np.newaxis, :, :] sc_xyz = np.array(side_chain_XYZs)[np.newaxis, :, :] - sc_dof = nerf.iNeRF(frame_xyz, sc_xyz) + res_xyz = np.concatenate((frame_xyz, sc_xyz), axis=1) + sc_dof = nerf.iNeRF(res_xyz) sc_dof_list.append(sc_dof) assert len(rotamer_lines) == len(sc_dof_list) @@ -274,7 +276,7 @@ def RIF_stubs(self, Ns, CAs, Cs): frames[:,:3,3] = t return frames - def get_hits( + def get_irots( self, L_offsets, D_offsets, @@ -356,8 +358,7 @@ def fast_pack( This is VERY hacky... ''' - - + ###### define 1 body and 2 body packing def _build_RIF_hit( hit_frame, @@ -384,11 +385,12 @@ def _build_RIF_hit( C_N_CA_frame[:,1,:] = hit_frame[:,0,:] C_N_CA_frame[:,2,:] = hit_frame[:,1,:] - SC_xyz = nerf.NeRF(C_N_CA_frame, rotamer_dofs) + SC_xyz = nerf.NeRF(rotamer_dofs, abcs=C_N_CA_frame) # clash check SC_xyzs against macrocycle_XYZ + # we only send [:,3:,:] as first 3 atoms are backbone. SC_bins = bin_coordinates( - SC_xyz, + SC_xyz[:,3:, :], pert_mag=pert_mag, shake_repeats=1000, resl=resl, @@ -425,17 +427,18 @@ def _two_body_clash( C_N_CA_frame_j[:,1,:] = hit_frame_j[:,0,:] C_N_CA_frame_j[:,2,:] = hit_frame_j[:,1,:] - i_xyz = nerf.NeRF(C_N_CA_frame_i, rotamer_dofs_i) - j_xyz = nerf.NeRF(C_N_CA_frame_j, rotamer_dofs_j) + i_xyz = nerf.NeRF(rotamer_dofs_i, abcs=C_N_CA_frame_i) + j_xyz = nerf.NeRF(rotamer_dofs_j, abcs=C_N_CA_frame_j) + # we only send [:,3:,:] as first 3 atoms are backbone. bins_i = bin_coordinates( - i_xyz, + i_xyz[:,3:,:], pert_mag=pert_mag, shake_repeats=1000, resl=resl, ) bins_j = bin_coordinates( - j_xyz, + j_xyz[:,3:,:], pert_mag=pert_mag, shake_repeats=1000, resl=resl, @@ -546,58 +549,29 @@ def _two_body_clash( # get the SCORE of the best interactions return TOTAL_RIF_SCORE, SEQ_MAP - def apply(self, pose): + def search_xyzs( + self, + xyzs, + L_idx=None, + D_idx=None, + binder_idx=None, + ): ''' - given a rosetta pose object, searches RIF - - if pass, returns True, rif_score, sequence_mapping - if fail, returns False, None, None - + Given numpy array of xyzss (Mres X 3 X 3) of [Mres x (N CA C) x (X Y Z)] + returns if there are hits that meet score thresholds, and the sequence mapping ''' - # 1. align pose to the RIF target pose - if self.L_AA_RIF['target'] is not None: - pose = self.superimpose(self.L_AA_RIF['target'] , pose) - elif self.D_AA_RIF['target'] is not None: - pose = self.superimpose(self.L_AA_RIF['target'] , pose) - else: - raise ValueError('RIF target pose not defined') + if L_idx is None: L_idx = np.arange(xyzs.shape[0]) + if D_idx is None: D_idx = np.array([])# assuming most people wont search D rifs + if binder_idx is None: binder_idx = np.arange(xyzs.shape[0]) - # 2. get indecies for L and D residues in the pose, and binder - L_idx = np.squeeze( - np.argwhere( - np.logical_and( - self.residue_selector.apply(pose), - self.L_selector.apply(pose) - ), - ), - axis=1, - ) - D_idx = np.squeeze( - np.argwhere( - np.logical_and( - self.residue_selector.apply(pose), - self.D_selector.apply(pose), - ), - ), - axis=1, - ) - binder_idx = np.squeeze( - np.argwhere( - self.residue_selector.apply(pose), - ), - axis=1, - ) - - # 3. get all of the xyzs of atoms in atom_selector in the pose - xyzs = np.array([[residue.atom(atom).xyz() for atom in self.atom_selector] for residue in pose.residues]) - # 4. get the hashed stubs of the stubs, based on L or D rif + # 1. get the hashed stubs of the stubs, based on L or D rif L_stubs = self.RIF_stubs(xyzs[L_idx,0,:], xyzs[L_idx,1,:], xyzs[L_idx,2,:]) D_stubs = self.RIF_stubs(xyzs[D_idx,0,:], xyzs[D_idx,1,:], xyzs[D_idx,2,:]) L_bins = self.L_AA_RIF['binner'].get_bin_index(L_stubs) D_bins = self.D_AA_RIF['binner'].get_bin_index(D_stubs) - # 5. look up if there are any hits for L and D RIFs + # 2. look up if there are any hits for L and D RIFs L_matching_keys = self.L_AA_RIF['rif_dict'].contains(L_bins) D_matching_keys = self.D_AA_RIF['rif_dict'].contains(D_bins) @@ -607,21 +581,20 @@ def apply(self, pose): D_keys = D_bins[D_matching_keys] D_offsets = self.D_AA_RIF['rif_dict'][D_bins] - # 6. test if we have reached the thresholds for number of hits at this point + # 3. test if we have reached the thresholds for number of hits at this point if np.count_nonzero(L_offsets) + np.count_nonzero(D_offsets) <= self.min_RIF_hits - 1: return False, None, None - # 7. get the specific irots for each hit positon - rotamer_list, score_list = self.get_hits(L_offsets, D_offsets) + # 4. get the specific irots for each hit positon + rotamer_list, score_list = self.get_irots(L_offsets, D_offsets) - # 8. get the L and D frames to rebuild the irots + # 5. get the L and D frames to rebuild the irots LD_frames = np.concatenate( (xyzs[L_idx,:,:], xyzs[D_idx,:,:]), axis=0, ) - # 9. pack + # 6. pack assert len(rotamer_list) == len(score_list) == LD_frames.shape[0] - total_rif_score, sequence_mapping = self.fast_pack( list_of_rotamers=rotamer_list, scores=score_list, @@ -635,17 +608,71 @@ def apply(self, pose): rotamer_clash_tolerance=self.fast_pack_params['rotamer_clash_tolerance'], ) - # 10. filter on total rif score & number of hits post pack + # 7. filter on total rif score & number of hits post pack if total_rif_score > self.min_RIF_score or len(sequence_mapping) < self.min_RIF_hits: return False, None, None - # 11. convert sequence mapping to residue index + # 8. convert sequence mapping to residue index LD_idx = np.concatenate((L_idx, D_idx), axis=0) for i in range(len(sequence_mapping)): sequence_mapping[i][0] = LD_idx[i] + 1# add 1 to get us to Rosetta 1-indexed arrays - # 12. return total_score, sequence_mapping + # 9. return total_score, sequence_mapping return True, total_rif_score, sequence_mapping + def apply(self, pose): + ''' + given a rosetta pose object, gets numpy arrays, searches RIF + + if pass, returns True, rif_score, sequence_mapping + if fail, returns False, None, None + + ''' + # 1. align pose to the RIF target pose + #if self.L_AA_RIF['target'] is not None: + # pose = self.superimpose(self.L_AA_RIF['target'] , pose) + #elif self.D_AA_RIF['target'] is not None: + # pose = self.superimpose(self.L_AA_RIF['target'] , pose) + #else: + # raise ValueError('RIF target pose not defined') + + # 2. get indecies for L and D residues in the pose, and binder + L_idx = np.squeeze( + np.argwhere( + np.logical_and( + self.residue_selector.apply(pose), + self.L_selector.apply(pose) + ), + ), + axis=1, + ) + D_idx = np.squeeze( + np.argwhere( + np.logical_and( + self.residue_selector.apply(pose), + self.D_selector.apply(pose), + ), + ), + axis=1, + ) + binder_idx = np.squeeze( + np.argwhere( + self.residue_selector.apply(pose), + ), + axis=1, + ) + + # 3. get all of the xyzs of atoms in atom_selector in the pose + xyzs = np.array([[residue.atom(atom).xyz() for atom in self.atom_selector] for residue in pose.residues]) + + # 4. search xyzs in RIF + return self.search_xyzs( + xyzs, + L_idx=L_idx, + D_idx=D_idx, + binder_idx=binder_idx, + ) + + def check_bins( bin_set1, bin_set2, @@ -662,7 +689,6 @@ def check_bins( It does not matter which order the sets are given """ - return np.sum(bin_set1.contains(bin_set2.items())) def bin_coordinates( @@ -681,9 +707,13 @@ def bin_coordinates( returns: a getpy Set containing the bins of the coordinates as byte8array3 """ + # 1. Add 0 vector into XYZ -> XYZ0 + expanded_xyzs = np.empty((xyz_array.shape[0], xyz_array.shape[1], 4)) + expanded_xyzs[:,:,:3] = xyz_array[...] + expanded_xyzs[:,:,-1] = 0.0 - # First. make some number of copies of the coords - shaken_coords = np.repeat(xyz_array, shake_repeats, axis=0) + # 2. make some number of copies of the coords + shaken_coords = np.repeat(expanded_xyzs, shake_repeats, axis=0) # Second. make some number of random perturbations as NxMx3 matrix random_perturbation = np.random.uniform( low=-pert_mag, @@ -691,12 +721,13 @@ def bin_coordinates( size=(shake_repeats*shaken_coords.shape[1]*3) ).reshape(shake_repeats, shaken_coords.shape[1], 3) - # Third. apply the random perturbations to each "trajectory" - shaken_coords += random_perturbation - # Fourth. round the coordinates as view as int, flatten on the first axis - rounded_coords = np.around(shaken_coords / resl, decimals=0).astype(np.int64).flatten().view(gp.types['byte8array3']) - # Fifth. set the coordinates in a getpy set - coord_set = gp.Set(gp.types['byte8array3']) + # 3. apply the random perturbations to each "trajectory" + shaken_coords[:,:,:3] += random_perturbation + + # 4. round the coordinates as view as int, flatten on the first axis + rounded_coords = np.around(shaken_coords / resl, decimals=0).astype(np.int32).flatten().view(np.dtype('S16')) + # 5. set the coordinates in a getpy set + coord_set = gp.Set(np.dtype('S16')) coord_set.add(rounded_coords) return coord_set diff --git a/setup.py b/setup.py index 56d1c88..7097ad6 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description = 'Using Rotamer Interaction Fields from RIFGen/Dock in python', packages = ['pyRIF'], package_dir={'pyRIF' : 'pyRIF'}, - install_requires = ['pyrosetta', 'xbin', 'numpy', 'getpy', 'h5py'], + install_requires = ['pyrosetta', 'xbin', 'numpy', 'getpy', 'pynerf', 'h5py'], zip_safe = False, long_description=readme, long_description_content_type='text/markdown',