Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

map object and major refactor #24

Merged
merged 51 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
3fff55b
Map object draft
tjlane Oct 13, 2024
68d3b3d
its alive
tjlane Oct 14, 2024
8e3fa37
test progress
tjlane Oct 14, 2024
02f4ec9
back in business, need to allow inplace
tjlane Oct 14, 2024
da03e58
edit ruff to allow inplace
tjlane Oct 14, 2024
ecce171
progress
tjlane Oct 14, 2024
a81ac7e
fix dataset mismatch
tjlane Oct 14, 2024
3493ca5
loose ends but wires working
tjlane Oct 14, 2024
59ae2e6
much better model
tjlane Oct 14, 2024
91ceb75
implementation progress, still need to lint
tjlane Oct 14, 2024
1b4ac92
lint
tjlane Oct 14, 2024
31de739
todo elaboration
tjlane Oct 14, 2024
7097fed
tag docstring as TODO
tjlane Oct 14, 2024
13de697
getting closer
tjlane Oct 14, 2024
83f2c2f
more todo
tjlane Oct 14, 2024
cac0e9c
lots of type checking and improvements
tjlane Oct 15, 2024
7cc1082
progress
tjlane Oct 15, 2024
c40b9c1
cleaning lint
tjlane Oct 15, 2024
c5e02cb
interface cleaner
tjlane Oct 15, 2024
3281a81
diffmap tests passing
tjlane Oct 15, 2024
a21b7f3
tv tests passing
tjlane Oct 15, 2024
b5aef48
missed file
tjlane Oct 15, 2024
063d966
more tests passing
tjlane Oct 15, 2024
d5431fb
almost green
tjlane Oct 15, 2024
ea732f6
green
tjlane Oct 15, 2024
529fd91
stricter lint
tjlane Oct 16, 2024
0d701ef
test coverage
tjlane Oct 16, 2024
10ba2ca
more test coverage
tjlane Oct 16, 2024
3d6b24e
test coverage
tjlane Oct 16, 2024
fe24233
documenting missing tests and issues
tjlane Oct 16, 2024
f8b431d
docstring for rsmap
tjlane Oct 16, 2024
ff4526c
we need to gemmi...
tjlane Oct 16, 2024
9479e62
updates
tjlane Oct 16, 2024
ea509d9
remove unused code
tjlane Oct 16, 2024
ba39162
structurefactor consistency with rs
tjlane Oct 16, 2024
806f2c1
lint
tjlane Oct 16, 2024
a3e1d7a
test for set_common_crystallographic_metadata
tjlane Oct 16, 2024
37983c3
test_assert_is_map
tjlane Oct 16, 2024
6a4ef6b
type fix test
tjlane Oct 16, 2024
37525b4
remove unneeded test
tjlane Oct 16, 2024
7a9a6fe
to from gemmi and mtz tests
tjlane Oct 16, 2024
29d3a02
two missing type hints
tjlane Oct 16, 2024
6c53649
audit
tjlane Oct 16, 2024
5075e7d
return dHKL
tjlane Oct 16, 2024
bca2d81
mypy and lint
tjlane Oct 16, 2024
505dd8a
allow reset index, fixes issues
tjlane Oct 16, 2024
8128252
last helpful comment
tjlane Oct 16, 2024
0183bf4
test coverage and last fixes
tjlane Oct 17, 2024
569d37e
test coverage
tjlane Oct 17, 2024
e3cf823
test coverage
tjlane Oct 17, 2024
05ad10b
last test and suppress caught warning
tjlane Oct 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 88 additions & 227 deletions meteor/diffmaps.py

Large diffs are not rendered by default.

110 changes: 32 additions & 78 deletions meteor/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

import numpy as np
import pandas as pd
import reciprocalspaceship as rs

from .rsmap import Map
from .tv import TvDenoiseResult, tv_denoise_difference_map
from .utils import (
average_phase_diff_in_degrees,
canonicalize_amplitudes,
complex_array_to_rs_dataseries,
rs_dataseries_to_complex_array,
)

DEFAULT_TV_WEIGHTS_TO_SCAN = [0.001, 0.01, 0.1, 1.0]


def _project_derivative_on_experimental_set(
*,
Expand Down Expand Up @@ -68,7 +68,7 @@ def _complex_derivative_from_iterative_tv(

Parameters
----------
complex_native: np.ndarray
native: np.ndarray
The complex native structure factors, usually experimental amplitudes and calculated phases

initial_complex_derivative : np.ndarray
Expand All @@ -82,10 +82,10 @@ def _complex_derivative_from_iterative_tv(

convergance_tolerance: float
If the change in the estimated derivative SFs drops below this value (phase, per-component)
then return
then return. Default 1e-4.

max_iterations: int
If this number of iterations is reached, stop early.
If this number of iterations is reached, stop early. Default 1000.

Returns
-------
Expand Down Expand Up @@ -125,7 +125,7 @@ def _complex_derivative_from_iterative_tv(
"tv_weight": tv_metadata.optimal_lambda,
"negentropy_after_tv": tv_metadata.optimal_negentropy,
"average_phase_change": phase_change,
}
},
)

if num_iterations > max_iterations:
Expand All @@ -135,16 +135,13 @@ def _complex_derivative_from_iterative_tv(


def iterative_tv_phase_retrieval(
input_dataset: rs.DataSet,
initial_derivative: Map,
native: Map,
*,
native_amplitude_column: str = "F",
derivative_amplitude_column: str = "Fh",
calculated_phase_column: str = "PHIC",
output_derivative_phase_column: str = "PHICh",
convergence_tolerance: float = 1e-3,
max_iterations: int = 100,
tv_weights_to_scan: list[float] | None = None,
) -> tuple[rs.DataSet, pd.DataFrame]:
convergence_tolerance: float = 1e-4,
max_iterations: int = 1000,
tv_weights_to_scan: list[float] = DEFAULT_TV_WEIGHTS_TO_SCAN,
) -> tuple[Map, pd.DataFrame]:
"""
Here is a brief pseudocode sketch of the alogrithm. Structure factors F below are complex unless
explicitly annotated |*|.
Expand All @@ -168,107 +165,64 @@ def iterative_tv_phase_retrieval(

Parameters
----------
input_dataset : rs.DataSet
The input dataset containing the native and derivative amplitude columns, as well as
the calculated phase column.

native_amplitude_column : str, optional
Column name in `input_dataset` representing the amplitudes of the native (dark) structure
factors, by default "F".
initial_derivative: Map
the derivative amplitudes, and initial guess for the phases

derivative_amplitude_column : str, optional
Column name in `input_dataset` representing the amplitudes of the derivative (light)
structure factors, by default "Fh".

calculated_phase_column : str, optional
Column name in `input_dataset` representing the phases of the native (dark) structure
factors, by default "PHIC".

output_derivative_phase_column : str, optional
Column name where the estimated derivative phases will be stored in the output dataset,
by default "PHICh".
native: Map
the native amplitudes, phases

convergance_tolerance: float
If the change in the estimated derivative SFs drops below this value (phase, per-component)
then return
then return. Default 1e-4.

max_iterations: int
If this number of iterations is reached, stop early.
If this number of iterations is reached, stop early. Default 1000.

tv_weights_to_scan : list[float], optional
A list of TV regularization weights (λ values) to be scanned for optimal results,
by default [0.001, 0.01, 0.1, 1.0].

Returns
-------
output_dataset: rs.DataSet
output_map: Map
The estimated derivative phases, along with the input amplitudes and input computed phases.

metadata: pd.DataFrame
Information about the algorithm run as a function of iteration. For each step, includes:
the tv_weight used, the negentropy (after the TV step), and the average phase change in
degrees.
"""

# clean TV denoising interface that is crystallographically intelligent
# maintains state for the HKL index, spacegroup, and cell information
if tv_weights_to_scan is None:
tv_weights_to_scan = [0.001, 0.01, 0.1, 1.0]

def tv_denoise_closure(difference: np.ndarray) -> tuple[np.ndarray, TvDenoiseResult]:
delta_amp, delta_phase = complex_array_to_rs_dataseries(
difference, index=input_dataset.index
)

# these two names are only used inside this closure
delta_amp.name = "DF_for_tv_closure"
delta_phase.name = "DPHI_for_tv_closure"

diffmap = rs.concat([delta_amp, delta_phase], axis=1)
diffmap.cell = input_dataset.cell
diffmap.spacegroup = input_dataset.spacegroup
diffmap = Map.from_structurefactor(difference, index=native.index)
diffmap.cell = native.cell
diffmap.spacegroup = native.spacegroup

denoised_map_coefficients, tv_metadata = tv_denoise_difference_map(
denoised_map, tv_metadata = tv_denoise_difference_map(
diffmap,
difference_map_amplitude_column=delta_amp.name,
difference_map_phase_column=delta_phase.name,
lambda_values_to_scan=tv_weights_to_scan,
full_output=True,
)

denoised_difference = rs_dataseries_to_complex_array(
denoised_map_coefficients[delta_amp.name], denoised_map_coefficients[delta_phase.name]
)

return denoised_difference, tv_metadata

# convert the native and derivative datasets to complex arrays
native = rs_dataseries_to_complex_array(
input_dataset[native_amplitude_column], input_dataset[calculated_phase_column]
)
initial_derivative = rs_dataseries_to_complex_array(
input_dataset[derivative_amplitude_column], input_dataset[calculated_phase_column]
)
return denoised_map.complex_amplitudes, tv_metadata

# estimate the derivative phases using the iterative TV algorithm
it_tv_complex_derivative, metadata = _complex_derivative_from_iterative_tv(
native=native,
initial_derivative=initial_derivative,
native=native.complex_amplitudes,
initial_derivative=initial_derivative.complex_amplitudes,
tv_denoise_function=tv_denoise_closure,
convergence_tolerance=convergence_tolerance,
max_iterations=max_iterations,
)
_, derivative_phases = complex_array_to_rs_dataseries(
it_tv_complex_derivative, input_dataset.index
it_tv_complex_derivative,
index=initial_derivative.index,
)

# combine the determined derivative phases with the input to generate a complete output
output_dataset = input_dataset.copy()
output_dataset[output_derivative_phase_column] = derivative_phases.astype(rs.PhaseDtype())
canonicalize_amplitudes(
output_dataset,
amplitude_label=derivative_amplitude_column,
phase_label=output_derivative_phase_column,
inplace=True,
)
output_dataset = initial_derivative.copy()
output_dataset.phases = derivative_phases

return output_dataset, metadata
Loading
Loading