Skip to content

Commit

Permalink
inference for column names
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Oct 24, 2024
1 parent e774f1c commit ce937c5
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 0 deletions.
78 changes: 78 additions & 0 deletions meteor/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""https://www.ccp4.ac.uk/html/mtzformat.html
https://www.globalphasing.com/buster/wiki/index.cgi?MTZcolumns
"""

from __future__ import annotations

import re
from typing import Final

# TODO: scour for PHENIX style, add

Check failure on line 10 in meteor/io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (FIX002)

meteor/io.py:10:3: FIX002 Line contains TODO, consider resolving the issue

OBSERVED_INTENSITY_LABELS: Final[list[str]] = [
"I", # generic
"IMEAN", # CCP4
"I-obs", # phenix
]

OBSERVED_AMPLITUDE_LABELS: Final[list[str]] = [
"F", # generic
"FP", # CCP4 & GLPh native
r"FPH\d", # CCP4 derivative
"F-obs", # phenix
]

OBSERVED_UNCERTAINTY_LABELS: Final[list[str]] = [
"SIGF", # generic
"SIGFP", # CCP4 & GLPh native
r"SIGFPH\d", # CCP4
]

COMPUTED_AMPLITUDE_LABELS: Final[list[str]] = ["FC"]

COMPUTED_PHASE_LABELS: Final[list[str]] = ["PHIC"]


class AmbiguousMtzLabelError(ValueError): ...


def _infer_mtz_label(labels_to_search: list[str], labels_to_look_for: list[str]) -> str:
# the next line consumes ["FOO", "BAR", "BAZ"] and produces regex strings like "^(FOO|BAR|BAZ)$"
regex = re.compile(f"^({'|'.join(labels_to_look_for)})$")
matches = [regex.match(label) for label in labels_to_search if regex.match(label) is not None]

if len(matches) == 0:
msg = "cannot infer MTZ column name; "
msg += f"cannot find any of {labels_to_look_for} in {labels_to_search}"
raise AmbiguousMtzLabelError(msg)
if len(matches) > 1:
msg = "cannot infer MTZ column name; "
msg += f">1 instance of {labels_to_look_for} in {labels_to_search}"
raise AmbiguousMtzLabelError(msg)

[match] = matches
if match is None:
msg = "`None` not filtered during regex matching"
raise RuntimeError(msg)

return match.group(0)


def find_observed_intensity_label(mtz_column_labels: list[str]) -> str:
return _infer_mtz_label(mtz_column_labels, OBSERVED_INTENSITY_LABELS)


def find_observed_amplitude_label(mtz_column_labels: list[str]) -> str:
return _infer_mtz_label(mtz_column_labels, OBSERVED_AMPLITUDE_LABELS)


def find_observed_uncertainty_label(mtz_column_labels: list[str]) -> str:
return _infer_mtz_label(mtz_column_labels, OBSERVED_UNCERTAINTY_LABELS)


def find_computed_amplitude_label(mtz_column_labels: list[str]) -> str:
return _infer_mtz_label(mtz_column_labels, COMPUTED_AMPLITUDE_LABELS)


def find_computed_phase_label(mtz_column_labels: list[str]) -> str:
return _infer_mtz_label(mtz_column_labels, COMPUTED_PHASE_LABELS)
102 changes: 102 additions & 0 deletions test/unit/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from __future__ import annotations

import pytest
from typing import Callable

from meteor import io

FIND_LABEL_FUNC_TYPE = Callable[[list[str]], str]

Check failure on line 8 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (I001)

test/unit/test_io.py:1:1: I001 Import block is un-sorted or un-formatted


OBSERVED_INTENSITY_CASES = [
([], "raise"),
(["F"], "raise"),
(["I", "F"], "I"),
(["IMEAN", "F"], "IMEAN"),
(["I-obs", "F"], "I-obs"),
(["I", "IMEAN"], "raise"),
]


OBSERVED_AMPLITUDE_CASES = [
([], "raise"),
(["IMEAN"], "raise"),
(["I", "F"], "F"),
(["IMEAN", "FP"], "FP"),
(["I-obs", "FPH0"], "FPH0"),
(["I-obs", "FPH5"], "FPH5"),
(["I-obs", "FPH51"], "raise"),
(["FP", "FPH1"], "raise"),
]


OBSERVED_UNCERTAINTY_CASES = [
([], "raise"),
(["F"], "raise"),
(["SIGF", "F"], "SIGF"),
(["SIGFP", "F"], "SIGFP"),
(["I-obs", "SIGFPH0"], "SIGFPH0"),
(["I-obs", "SIGFPH1"], "SIGFPH1"),
(["I-obs", "SIGFPH10"], "raise"),
(["SIGFPH1", "SIGFPH2"], "raise"),
]


COMPUTED_AMPLITUDE_CASES = [
([], "raise"),
(["F"], "raise"),
(["I", "F"], "raise"),
(["FC", "F"], "FC"),
(["FC", "FC"], "raise"),
]


COMPUTED_PHASE_CASES = [
([], "raise"),
(["F"], "raise"),
(["PHIC", "F"], "PHIC"),
(["PHIC", "PHIC"], "raise"),
]


def test_infer_mtz_label() -> None:
to_search = ["FOO", "BAR", "BAZ"]
assert io._infer_mtz_label(to_search, ["FOO"]) == "FOO"
assert io._infer_mtz_label(to_search, ["BAR"]) == "BAR"
with pytest.raises(io.AmbiguousMtzLabelError):
_ = io._infer_mtz_label(to_search, [])
with pytest.raises(io.AmbiguousMtzLabelError):
_ = io._infer_mtz_label(to_search, ["FOO", "BAR"])


def validate_find_label_result(function: FIND_LABEL_FUNC_TYPE, labels: list[str], expected_result: str) -> None:
if expected_result == "raise":
with pytest.raises(io.AmbiguousMtzLabelError):
_ = function(labels)
else:
assert function(labels) == expected_result


@pytest.mark.parametrize("labels,expected_result", OBSERVED_INTENSITY_CASES)

Check failure on line 80 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (PT006)

test/unit/test_io.py:80:26: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
def test_find_observed_intensity_label(labels: list[str], expected_result: str) -> None:
validate_find_label_result(io.find_observed_intensity_label, labels, expected_result)


@pytest.mark.parametrize("labels,expected_result", OBSERVED_AMPLITUDE_CASES)

Check failure on line 85 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (PT006)

test/unit/test_io.py:85:26: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
def test_find_observed_amplitude_label(labels: list[str], expected_result: str) -> None:
validate_find_label_result(io.find_observed_amplitude_label, labels, expected_result)


@pytest.mark.parametrize("labels,expected_result", OBSERVED_UNCERTAINTY_CASES)

Check failure on line 90 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (PT006)

test/unit/test_io.py:90:26: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
def test_find_observed_uncertainty_label(labels: list[str], expected_result: str) -> None:
validate_find_label_result(io.find_observed_uncertainty_label, labels, expected_result)


@pytest.mark.parametrize("labels,expected_result", COMPUTED_AMPLITUDE_CASES)

Check failure on line 95 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (PT006)

test/unit/test_io.py:95:26: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
def test_find_computed_amplitude_label(labels: list[str], expected_result: str) -> None:
validate_find_label_result(io.find_computed_amplitude_label, labels, expected_result)


@pytest.mark.parametrize("labels,expected_result", COMPUTED_PHASE_CASES)

Check failure on line 100 in test/unit/test_io.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (PT006)

test/unit/test_io.py:100:26: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
def test_find_computed_phase_label(labels: list[str], expected_result: str) -> None:
validate_find_label_result(io.find_computed_phase_label, labels, expected_result)

0 comments on commit ce937c5

Please sign in to comment.