From 1840b77873083fc9dff45f5c3135753bcd43e4fb Mon Sep 17 00:00:00 2001 From: tjlane Date: Sat, 26 Oct 2024 11:18:47 +0100 Subject: [PATCH] tv result improvements --- meteor/tv.py | 50 ++++++++++++++++++++------------------------ test/unit/test_tv.py | 2 +- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/meteor/tv.py b/meteor/tv.py index f5aa925..9614308 100644 --- a/meteor/tv.py +++ b/meteor/tv.py @@ -20,34 +20,31 @@ @dataclass class TvDenoiseResult: + # constants for JSON format + _scan_name = "scan" + _weight_name = "weight" + _negentropy_name = "negentropy" + initial_negentropy: float optimal_weight: float optimal_negentropy: float map_sampling_used_for_tv: float weights_scanned: list[float] negentropy_at_weights: list[float] - - def json(self) -> dict: - - data = [] - for idx in range(len(self.weights_scanned)): - data.append( - { - "weight": self.weights_scanned[idx], - "negentropy": self.negentropy_at_weights[idx], - } - ) - - json_payload = { - "initial_negentropy": self.initial_negentropy, - "optimal_weight": self.optimal_weight, - "optimal_negentropy": self.optimal_negentropy, - "map_sampling_used_for_tv": self.map_sampling_used_for_tv, - "data": data - } + def json(self) -> dict: + json_payload = asdict(self) + json_payload.pop("weights_scanned") + json_payload.pop("negentropy_at_weights") + json_payload[self._scan_name] = [ + { + self._weight_name: self.weights_scanned[idx], + self._negentropy_name: self.negentropy_at_weights[idx], + } + for idx in range(len(self.weights_scanned)) + ] return json_payload - + def to_json_file(self, filename: Path) -> None: with filename.open("w") as f: json.dump(self.json(), f, indent=4) @@ -55,13 +52,15 @@ def to_json_file(self, filename: Path) -> None: @classmethod def from_json(cls, json_payload: dict) -> TvDenoiseResult: try: - data = json_payload.pop("data") - json_payload["weights_scanned"] = [float(point["weight"]) for point in data] - json_payload["negentropy_at_weights"] = [float(point["negentropy"]) for point in data] + data = json_payload.pop(cls._scan_name) + json_payload["weights_scanned"] = [float(point[cls._weight_name]) for point in data] + json_payload["negentropy_at_weights"] = [ + float(point[cls._negentropy_name]) for point in data + ] return cls(**json_payload) except Exception as exptn: - msg = f"could not load json payload; mis-formatted" + msg = "could not load json payload; mis-formatted" raise ValueError(msg) from exptn @classmethod @@ -70,9 +69,6 @@ def from_json_file(cls, filename: Path) -> TvDenoiseResult: json_payload = json.load(f) return cls.from_json(json_payload) - - - def _tv_denoise_array(*, map_as_array: np.ndarray, weight: float) -> np.ndarray: """Closure convienence function to generate more readable code.""" diff --git a/test/unit/test_tv.py b/test/unit/test_tv.py index c11edfa..9511316 100644 --- a/test/unit/test_tv.py +++ b/test/unit/test_tv.py @@ -1,8 +1,8 @@ from __future__ import annotations +from dataclasses import asdict from pathlib import Path from typing import Sequence -from dataclasses import asdict import numpy as np import pandas as pd