From 2cfa16575c2dcd5a8add9802f286b5c6a4e28173 Mon Sep 17 00:00:00 2001 From: ggutierrez <94693768+ggutierrez-sunbright@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:20:22 +0100 Subject: [PATCH] Issue #79/migrate sdmodify (#80) * add selectable output to sdfield * fix sdfilter output * add sdmodify * add sdmodify basic tests --- rdock-utils/pyproject.toml | 1 + rdock-utils/rdock_utils/common/SDFParser.py | 7 +++- rdock-utils/rdock_utils/sdfield.py | 37 +++++++++++++--- rdock-utils/rdock_utils/sdfilter/main.py | 4 +- rdock-utils/rdock_utils/sdmodify.py | 42 +++++++++++++++++++ rdock-utils/tests/sdmodify/__init__.py | 0 rdock-utils/tests/sdmodify/conftest.py | 5 +++ .../tests/sdmodify/test_integration.py | 37 ++++++++++++++++ 8 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 rdock-utils/rdock_utils/sdmodify.py create mode 100644 rdock-utils/tests/sdmodify/__init__.py create mode 100644 rdock-utils/tests/sdmodify/conftest.py create mode 100644 rdock-utils/tests/sdmodify/test_integration.py diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index 12575343..5a77585b 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -16,6 +16,7 @@ sdrmsd = "rdock_utils.sdrmsd.main:main" sdtether = "rdock_utils.sdtether.main:main" sdtether_old = "rdock_utils.sdtether_original:main" sdfilter = "rdock_utils.sdfilter.main:main" +sdmodify = "rdock_utils.sdmodify:main" [project.urls] Repository = "https://github.com/CBDD/rDock.git" diff --git a/rdock-utils/rdock_utils/common/SDFParser.py b/rdock-utils/rdock_utils/common/SDFParser.py index 2409b74c..5fba46bc 100644 --- a/rdock-utils/rdock_utils/common/SDFParser.py +++ b/rdock-utils/rdock_utils/common/SDFParser.py @@ -74,7 +74,7 @@ def write(self, dest: TextIO) -> None: dest.writelines(self.lines) for field_name, field_value in self.data.items(): dest.write(self.str_field(field_name, field_value)) - dest.write("$$$$") + dest.write("$$$$\n") def get_field(self, field_name: str) -> str | None: if field_name.startswith("_TITLE"): @@ -84,6 +84,11 @@ def get_field(self, field_name: str) -> str | None: return None return self.data.get(field_name, None) + def set_title(self, title: str, line_index: int = 0) -> None: + if line_index > 2: + raise ValueError("line index must be 0, 1, or 2") + self.lines[line_index] = title + ("" if title.endswith("\n") else "\n") + @property def title(self) -> str: return self.lines[0].strip() diff --git a/rdock-utils/rdock_utils/sdfield.py b/rdock-utils/rdock_utils/sdfield.py index 218b413d..29b42252 100644 --- a/rdock-utils/rdock_utils/sdfield.py +++ b/rdock-utils/rdock_utils/sdfield.py @@ -1,6 +1,9 @@ # Standard Library import argparse +import sys +from dataclasses import dataclass from logging import getLogger +from typing import Generator, TextIO # Local imports from .common import inputs_generator, read_molecules_from_all_inputs @@ -12,21 +15,43 @@ def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Adding fields to SD files") parser.add_argument("fieldname", type=str, help="name of the field to be added") parser.add_argument("value", type=str, help="value of the field to be added") - infile_help = "input file[s] to be processed. if not provided, stdin is used." - parser.add_argument("infile", type=str, nargs="*", help=infile_help) + infiles_help = "input file[s] to be processed. if not provided, stdin is used." + parser.add_argument("infiles", type=str, nargs="*", help=infiles_help) outfile_help = "output file. if not provided, stdout is used." parser.add_argument("-o", "--outfile", default=None, type=str, help=outfile_help) return parser -def main(argv: list[str] | None = None) -> None: +@dataclass +class SDFieldConfig: + fieldname: str + value: str + infiles: list[str] + outfile: str | None + + def get_outfile(self) -> Generator[TextIO, None, None]: + if self.outfile: + with open(self.outfile, "w") as f: + yield f + else: + yield sys.stdout + + +def get_config(argv: list[str] | None = None) -> SDFieldConfig: parser = get_parser() args = parser.parse_args(argv) - inputs = inputs_generator(args.infile) + return SDFieldConfig(fieldname=args.fieldname, value=args.value, infiles=args.infiles, outfile=args.outfile) + + +def main(argv: list[str] | None = None) -> None: + config = get_config(argv) + inputs = inputs_generator(config.infiles) + output_gen = config.get_outfile() + output = next(output_gen) for molecule in read_molecules_from_all_inputs(inputs): - molecule.data[args.fieldname] = args.value - print(repr(molecule)) + molecule.data[config.fieldname] = config.value + molecule.write(output) if __name__ == "__main__": diff --git a/rdock-utils/rdock_utils/sdfilter/main.py b/rdock-utils/rdock_utils/sdfilter/main.py index 50634e01..7318c91c 100644 --- a/rdock-utils/rdock_utils/sdfilter/main.py +++ b/rdock-utils/rdock_utils/sdfilter/main.py @@ -1,3 +1,5 @@ +import sys + from rdock_utils.common import inputs_generator, read_molecules_from_all_inputs from .filter import ExpressionContext, create_filters, molecules_with_context @@ -12,7 +14,7 @@ def main(argv: list[str] | None = None) -> None: molecules = molecules_with_context(read_molecules_from_all_inputs(inputs), context) for molecule in molecules: if any(filter.evaluate(molecule) for filter in filters): - print(repr(molecule)) + molecule.write(sys.stdout) if __name__ == "__main__": diff --git a/rdock-utils/rdock_utils/sdmodify.py b/rdock-utils/rdock_utils/sdmodify.py new file mode 100644 index 00000000..16c06ccf --- /dev/null +++ b/rdock-utils/rdock_utils/sdmodify.py @@ -0,0 +1,42 @@ +import argparse +import logging +import sys +from dataclasses import dataclass + +from rdock_utils.common import inputs_generator, read_molecules_from_all_inputs + +logger = logging.getLogger("sdmodify") + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Set the first title line equal to a given data field") + parser.add_argument( + "-f", "--field", type=str, required=True, help="Data field to set the first title line equal to" + ) + infile_help = "input file[s] to be processed. if not provided, stdin is used." + parser.add_argument("infiles", type=str, nargs="*", help=infile_help) + return parser + + +@dataclass +class SDModifyConfig: + field: str + infiles: list[str] + + +def get_config(argv: list[str] | None = None) -> SDModifyConfig: + parser = get_parser() + args = parser.parse_args(argv) + return SDModifyConfig(field=args.field, infiles=args.infiles) + + +def main(argv: list[str] | None = None) -> None: + config = get_config(argv) + inputs = inputs_generator(config.infiles) + for index, mol in enumerate(read_molecules_from_all_inputs(inputs), start=1): + value = mol.get_field(config.field) if config.field != "_REC" else str(index) + if value is None: + logger.warning(f"field {config.field} not found in molecule {mol.title}, skipping...") + else: + mol.set_title(value) + mol.write(sys.stdout) diff --git a/rdock-utils/tests/sdmodify/__init__.py b/rdock-utils/tests/sdmodify/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rdock-utils/tests/sdmodify/conftest.py b/rdock-utils/tests/sdmodify/conftest.py new file mode 100644 index 00000000..eb38e215 --- /dev/null +++ b/rdock-utils/tests/sdmodify/conftest.py @@ -0,0 +1,5 @@ +from ..conftest import FIXTURES_FOLDER + +# reuse sdfilter fixture +SDFILTER_FIXTURES_FOLDER = FIXTURES_FOLDER / "sdfilter" +INPUT_FILE = str(SDFILTER_FIXTURES_FOLDER / "input.sdf") diff --git a/rdock-utils/tests/sdmodify/test_integration.py b/rdock-utils/tests/sdmodify/test_integration.py new file mode 100644 index 00000000..644eeebf --- /dev/null +++ b/rdock-utils/tests/sdmodify/test_integration.py @@ -0,0 +1,37 @@ +from io import StringIO + +import pytest + +from rdock_utils.common import read_molecules +from rdock_utils.sdmodify import main + +from .conftest import INPUT_FILE + + +def test_do_nothing(): + with pytest.raises(SystemExit): + main() + + +@pytest.mark.parametrize( + "args, expected_titles", + [ + pytest.param( + ["-f", "test_field", INPUT_FILE], + (["0.0", "0.0", "2.0", "3.0", "4.0", "0.0"]), + id="molecule field filter", + ), + pytest.param( + ["-f", "_REC", INPUT_FILE], + list(map(str, range(1, 7))), + id="molecule field filter with output file", + ), + ], +) +def test_basic_run(args: list[str], expected_titles: list[str], capsys: pytest.CaptureFixture): + main(args) + captured = capsys.readouterr() + input = StringIO(captured.out) + molecules = read_molecules(input) + titles = [m.title for m in molecules] + assert titles == expected_titles