Skip to content

Commit

Permalink
tv result improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Oct 26, 2024
1 parent c60d2e5 commit 1840b77
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
50 changes: 23 additions & 27 deletions meteor/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,47 @@

@dataclass
class TvDenoiseResult:
# constants for JSON format
_scan_name = "scan"
_weight_name = "weight"
_negentropy_name = "negentropy"

initial_negentropy: float
optimal_weight: float
optimal_negentropy: float
map_sampling_used_for_tv: float
weights_scanned: list[float]
negentropy_at_weights: list[float]

def json(self) -> dict:

data = []
for idx in range(len(self.weights_scanned)):
data.append(
{
"weight": self.weights_scanned[idx],
"negentropy": self.negentropy_at_weights[idx],
}
)

json_payload = {
"initial_negentropy": self.initial_negentropy,
"optimal_weight": self.optimal_weight,
"optimal_negentropy": self.optimal_negentropy,
"map_sampling_used_for_tv": self.map_sampling_used_for_tv,
"data": data
}

def json(self) -> dict:
json_payload = asdict(self)
json_payload.pop("weights_scanned")
json_payload.pop("negentropy_at_weights")
json_payload[self._scan_name] = [
{
self._weight_name: self.weights_scanned[idx],
self._negentropy_name: self.negentropy_at_weights[idx],
}
for idx in range(len(self.weights_scanned))
]
return json_payload

def to_json_file(self, filename: Path) -> None:
with filename.open("w") as f:
json.dump(self.json(), f, indent=4)

@classmethod
def from_json(cls, json_payload: dict) -> TvDenoiseResult:
try:
data = json_payload.pop("data")
json_payload["weights_scanned"] = [float(point["weight"]) for point in data]
json_payload["negentropy_at_weights"] = [float(point["negentropy"]) for point in data]
data = json_payload.pop(cls._scan_name)
json_payload["weights_scanned"] = [float(point[cls._weight_name]) for point in data]
json_payload["negentropy_at_weights"] = [
float(point[cls._negentropy_name]) for point in data
]
return cls(**json_payload)

except Exception as exptn:
msg = f"could not load json payload; mis-formatted"
msg = "could not load json payload; mis-formatted"
raise ValueError(msg) from exptn

@classmethod
Expand All @@ -70,9 +69,6 @@ def from_json_file(cls, filename: Path) -> TvDenoiseResult:
json_payload = json.load(f)
return cls.from_json(json_payload)





def _tv_denoise_array(*, map_as_array: np.ndarray, weight: float) -> np.ndarray:
"""Closure convienence function to generate more readable code."""
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_tv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from dataclasses import asdict
from pathlib import Path
from typing import Sequence
from dataclasses import asdict

import numpy as np
import pandas as pd
Expand Down

0 comments on commit 1840b77

Please sign in to comment.