Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom plots rebase #199

Merged
merged 8 commits into from
Oct 3, 2024
31 changes: 29 additions & 2 deletions refl1d/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@

from dataclasses import dataclass, field
import sys
from copy import deepcopy
import os
from math import pi, log10, floor
import traceback
import json
from typing import Optional, Any, Union, Dict, Callable, Literal, Tuple, List, Literal
from typing import Optional, Union, Callable, Literal, List, Literal, Protocol, TypedDict
from warnings import warn

import numpy as np
from bumps import parameter
from bumps.parameter import Parameter, Constraint, tag_all
from bumps.fitproblem import Fitness
from bumps.fitproblem import Fitness, FitProblem
from bumps.dream.state import MCMCDraw

from . import material, profile
from . import __version__
Expand Down Expand Up @@ -51,12 +53,24 @@ def plot_sample(sample, instrument=None, roughness_limit=0):
experiment.plot()


class WebviewPlotFunction(Protocol):
def __call__(
self, model: "ExperimentBase", problem: FitProblem, state: MCMCDraw, n_samples: Optional[int]
) -> dict: ...


class WebviewPlotInfo(TypedDict):
change_with: Literal["parameter", "uncertainty"]
func: WebviewPlotFunction


class ExperimentBase:
probe = None # type: Optional[Probe]
interpolation = 0
_probe_cache = None
_substrate = None
_surface = None
_webview_plots: dict[str, WebviewPlotInfo]

def parameters(self):
raise NotImplementedError()
Expand Down Expand Up @@ -323,6 +337,17 @@ def save_refl(self, basename):
theory = self.reflectivity(interpolation=self.interpolation)
self.probe.save(filename=basename + "-refl-interp.dat", theory=theory)

def register_webview_plot(
self, plot_title: str, plot_function: WebviewPlotFunction, change_with: Literal["parameter", "uncertainty"]
):
# Plot function syntax: f(model, problem, state)
# change_with = 'parameter' or 'uncertainty'
self._webview_plots[plot_title] = dict(change_with=change_with, func=plot_function)

@property
def webview_plots(self):
return self._webview_plots


@dataclass(init=False)
class Experiment(ExperimentBase):
Expand Down Expand Up @@ -430,6 +455,7 @@ def __init__(
if auto_tag:
tag_all(self.probe.parameters(), "probe")
tag_all(self.sample.parameters(), "sample")
self._webview_plots = {}

@property
def ismagnetic(self):
Expand Down Expand Up @@ -725,6 +751,7 @@ def __init__(self, samples=None, ratio=None, probe=None, name=None, coherent=Fal
self._surface = self.samples[0][-1].material
self._cache = {}
self.name = name if name is not None else probe.name
self._webview_plots = {}

def update(self):
self._cache = {}
Expand Down
Loading