-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
phaseboost in, need to work on tests
- Loading branch information
Showing
7 changed files
with
329 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import structlog | ||
|
||
from meteor.iterative import iterative_tv_phase_retrieval | ||
from meteor.tv import tv_denoise_difference_map | ||
|
||
from .common import DiffmapArgParser, kweight_diffmap_according_to_mode | ||
|
||
log = structlog.get_logger() | ||
|
||
|
||
# TODO: test this | ||
TV_WEIGHTS_TO_SCAN_DEFAULT = [0.01] | ||
|
||
|
||
class IterativeTvArgParser(DiffmapArgParser): | ||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.add_argument( | ||
"-x", | ||
"--tv-weights-to-scan", | ||
nargs="+", | ||
type=float, | ||
default=TV_WEIGHTS_TO_SCAN_DEFAULT, | ||
help=( | ||
"Choose what TV weights to evaluate at every iteration. Can be a single float." | ||
f"Default: {TV_WEIGHTS_TO_SCAN_DEFAULT}." | ||
), | ||
) | ||
|
||
|
||
def main(command_line_arguments: list[str] | None = None) -> None: | ||
parser = IterativeTvArgParser( | ||
description=( | ||
"bla bla" # TODO | ||
Check failure on line 38 in meteor/scripts/compute_iterative_tv_map.py GitHub Actions / build (3.11)Ruff (TD004)
Check failure on line 38 in meteor/scripts/compute_iterative_tv_map.py GitHub Actions / build (3.11)Ruff (TD005)
|
||
) | ||
) | ||
args = parser.parse_args(command_line_arguments) | ||
parser.check_output_filepaths(args) | ||
mapset = parser.load_difference_maps(args) | ||
|
||
# First, find improved derivative phases | ||
new_derivative_map, it_tv_metadata = iterative_tv_phase_retrieval( | ||
mapset.derivative, | ||
mapset.native, | ||
tv_weights_to_scan=args.tv_weights_to_scan, | ||
) | ||
mapset.derivative = new_derivative_map | ||
|
||
diffmap, kparameter_used = kweight_diffmap_according_to_mode( | ||
kweight_mode=args.kweight_mode, kweight_parameter=args.kweight_parameter, mapset=mapset | ||
) | ||
|
||
# TODO: used fixed weight or golden method? | ||
final_map, final_tv_metadata = tv_denoise_difference_map( | ||
diffmap, full_output=True, weights_to_scan=args.tv_weights_to_scan | ||
) | ||
|
||
log.info("Writing output.", file=str(args.mtzout)) | ||
final_map.write_mtz(args.mtzout) | ||
|
||
# TODO: append it_tv_metadata | ||
# log.info("Writing metadata.", file=str(args.metadataout)) | ||
# metadata.k_parameter_used = kparameter_used | ||
# | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pytest | ||
import reciprocalspaceship as rs | ||
|
||
from meteor.rsmap import Map | ||
from meteor.scripts import compute_iterative_tv_map | ||
from meteor.scripts.common import WeightMode | ||
from meteor.tv import TvDenoiseResult | ||
from meteor.utils import filter_common_indices | ||
|
||
|
||
@pytest.mark.parametrize("kweight_mode", list(WeightMode)) | ||
@pytest.mark.parametrize("tv_weight_mode", list(WeightMode)) | ||
def test_script_produces_consistent_results( | ||
kweight_mode: WeightMode, | ||
tv_weight_mode: WeightMode, | ||
testing_pdb_file: Path, | ||
testing_mtz_file: Path, | ||
tmp_path: Path, | ||
) -> None: | ||
# for when WeightMode.fixed; these maximize negentropy in manual testing | ||
kweight_parameter = 0.05 | ||
|
||
output_mtz = tmp_path / "test-output.mtz" | ||
output_metadata = tmp_path / "test-output-metadata.csv" | ||
|
||
cli_args = [ | ||
str(testing_mtz_file), # derivative | ||
"--derivative-amplitude-column", | ||
"F_on", | ||
"--derivative-uncertainty-column", | ||
"SIGF_on", | ||
str(testing_mtz_file), # native | ||
"--native-amplitude-column", | ||
"F_off", | ||
"--native-uncertainty-column", | ||
"SIGF_off", | ||
"--structure", | ||
str(testing_pdb_file), | ||
"-o", | ||
str(output_mtz), | ||
"-m", | ||
str(output_metadata), | ||
"--kweight-mode", | ||
kweight_mode, | ||
"--kweight-parameter", | ||
str(kweight_parameter), | ||
"-x", | ||
"0.01", | ||
] | ||
|
||
compute_iterative_tv_map.main(cli_args) | ||
|
||
# TODO this simple load metadata won't work, load JSON then make object instead | ||
Check failure on line 56 in test/functional/test_compute_iterative_tv.py GitHub Actions / build (3.11)Ruff (TD004)
|
||
result_metadata = TvDenoiseResult.from_json_file(output_metadata) | ||
result_map = Map.read_mtz_file(output_mtz) | ||
|
||
# 1. make sure negentropy increased | ||
if kweight_mode == WeightMode.none and tv_weight_mode == WeightMode.none: | ||
np.testing.assert_allclose( | ||
result_metadata.optimal_negentropy, result_metadata.initial_negentropy | ||
) | ||
else: | ||
assert result_metadata.optimal_negentropy >= result_metadata.initial_negentropy | ||
|
||
# 2. make sure computed DF are close to those stored on disk | ||
reference_dataset = rs.read_mtz(str(testing_mtz_file)) | ||
reference_amplitudes = reference_dataset["F_itTV"] | ||
|
||
result_amplitudes, reference_amplitudes = filter_common_indices( | ||
result_map.amplitudes, reference_amplitudes | ||
) | ||
rho = np.corrcoef(result_amplitudes.to_numpy(), reference_amplitudes.to_numpy())[0, 1] | ||
|
||
# comparing a correlation coefficienct allows for a global scale factor change, but nothing else | ||
if (kweight_mode == WeightMode.none) or (tv_weight_mode == WeightMode.none): # noqa: PLR1714 | ||
assert rho > 0.50 | ||
else: | ||
assert rho > 0.98 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.