Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
update main function
remove get_output_sdf
  • Loading branch information
lpardey committed Jan 8, 2024
1 parent 859815a commit 09af0e1
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions rdock-utils/rdock_utils/sdrmsd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import logging
import math
import os
import sys

import numpy
Expand Down Expand Up @@ -230,7 +229,6 @@ def get_automorphism_rmsd(target: pybel.Molecule, molecule: pybel.Molecule, fit:
# Update result if the current mapping has a lower RMSD
if mapping_rmsd < result_rmsd:
result_rmsd = mapping_rmsd
fitted_result = False

# Additional fitting if fit=True
if fit:
Expand All @@ -239,7 +237,6 @@ def get_automorphism_rmsd(target: pybel.Molecule, molecule: pybel.Molecule, fit:
# Update result if the fitted RMSD is lower
if fitted_rmsd < result_rmsd:
result_rmsd = fitted_rmsd
fitted_result = fitted_pose

return (result_rmsd, fitted_pose) if fit else result_rmsd

Expand Down Expand Up @@ -276,13 +273,10 @@ def main(argv: list[str] | None = None) -> None:
crystal_pose = get_crystal_pose(reference_sdf)
crystal_atoms = len(crystal_pose.atoms)

# If outfname is defined, prepare an output SDF sink to write molecules
output_sdf = get_output_sdf(out)

# Find the RMSD between the crystal pose and each docked pose
docked_poses = pybel.readfile("sdf", input_sdf)

print_fit_message(fit)
display_fit_message(fit)

skipped = []
molecules_dict = {} # Save all poses with their dockid
Expand All @@ -300,23 +294,16 @@ def main(argv: list[str] | None = None) -> None:
handle_pose_matching(out, i, docked_pose, rmsd_result, threshold, molecules_dict, population, out_dict)

if out:
output_sdf = output_sdf = pybel.Outputfile("sdf", out, overwrite=True) # TODO: Check the second argument
process_and_save_selected_molecules(output_sdf, out_dict, population)

if skipped:
print(f"SKIPPED input molecules due to the number of atom mismatch: {skipped}", file=sys.stderr)


def get_output_sdf(out: bool) -> pybel.Outputfile | None:
if out:
output_sdf = pybel.Outputfile("sdf", out, overwrite=True) # TODO: Check the second argument
return output_sdf


def print_fit_message(fit: bool) -> None:
if fit:
print("POSE\tRMSD_FIT")
else:
print("POSE\tRMSD_NOFIT")
def display_fit_message(fit: bool) -> None:
message = "FIT" if fit else "NOFIT"
print(f"POSE\tRMSD_{message}")


def get_crystal_pose(reference_sdf: argparse.FileType) -> pybel.Molecule:
Expand All @@ -341,10 +328,10 @@ def calculate_rmsd(crystal: pybel.Molecule, docked_pose: pybel.Molecule, fit: bo
Perform RMSD calculations and update coordinates if required.
"""
if fit:
rmsd_result, fitted_result = get_automorphism_rmsd(crystal, docked_pose, fit=True)
rmsd_result, fitted_result = get_automorphism_rmsd(crystal, docked_pose, fit)
update_coordinates(docked_pose, fitted_result)
else:
rmsd_result = get_automorphism_rmsd(crystal, docked_pose, fit=False)
rmsd_result = get_automorphism_rmsd(crystal, docked_pose, fit)

return rmsd_result

Expand All @@ -363,9 +350,9 @@ def handle_pose_matching(
Function to handle pose matching and filtering based on 'threshold' parser argument.
"""
if threshold:
match, best_match_value = get_best_matching_pose(docked_pose, threshold, molecules_dict)
if match is not None:
print_matching_info(pose_index, match, population, best_match_value)
match_pose, best_match_value = get_best_matching_pose(docked_pose, threshold, molecules_dict)
if match_pose is not None:
print_matching_info(pose_index, match_pose, population, best_match_value)
else:
save_or_print_info(out, pose_index, docked_pose, result_rmsd, molecules_dict, population, out_dict)
else:
Expand All @@ -376,22 +363,22 @@ def handle_pose_matching(
def get_best_matching_pose(
docked_pose: pybel.Molecule, threshold: float, molecules_dict: dict[int, pybel.Molecule]
) -> tuple[float | None, float]:
match = None
match_pose = None
best_match_value = 999999.0

for did, prevmol in molecules_dict.items():
tmprmsd = get_automorphism_rmsd(prevmol, docked_pose)

if tmprmsd < threshold and tmprmsd < best_match_value:
best_match_value = tmprmsd
match = did
match_pose = did

return (match, best_match_value)
return (match_pose, best_match_value)


def print_matching_info(index: int, match: float, population: dict[float, int], best_match_value: float) -> None:
print(f"Pose {index} matches pose {match} with {best_match_value:.3f} RMSD", file=sys.stderr)
population[match] += 1
def print_matching_info(index: int, match_pose: float, population: dict[float, int], best_match_value: float) -> None:
print(f"Pose {index} matches pose {match_pose} with {best_match_value:.3f} RMSD", file=sys.stderr)
population[match_pose] += 1


def save_or_print_info(
Expand Down

0 comments on commit 09af0e1

Please sign in to comment.