Skip to content

Commit

Permalink
feat: add direct from ground truth (#32)
Browse files Browse the repository at this point in the history
* feat: add direct from ground truth

* guard import

* add example

* update

* add explanation
  • Loading branch information
tlambert03 authored May 19, 2024
1 parent 36dabb9 commit ed70c3a
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 7 deletions.
172 changes: 172 additions & 0 deletions examples/direct_truth.ipynb

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions src/microsim/schema/sample/direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from microsim.schema._base_model import SimBaseModel
from microsim.schema.backend import NumpyAPI

if TYPE_CHECKING:
from microsim._data_array import DataArray


class FixedArrayTruth(SimBaseModel):
type: Literal["fixed-array"] = "fixed-array"
array: Any

def render(self, space: DataArray, xp: NumpyAPI | None = None) -> DataArray:
if space.shape != self.array.shape:
raise ValueError(
"This GroundTruth may only be used with simulation space of shape: "
f"{self.array.shape}. Got: {space.shape}"
)

xp = xp or NumpyAPI()
return space + xp.asarray(self.array).astype(space.dtype)
11 changes: 8 additions & 3 deletions src/microsim/schema/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from pydantic import Field, model_validator

from microsim._data_array import DataArray
from microsim._data_array import ArrayProtocol, DataArray
from microsim.schema._base_model import SimBaseModel
from microsim.schema.backend import NumpyAPI

from .direct import FixedArrayTruth
from .fluorophore import Fluorophore
from .matslines import MatsLines

Distribution = MatsLines
Distribution = MatsLines | FixedArrayTruth


class FluorophoreDistribution(SimBaseModel):
Expand All @@ -21,13 +22,17 @@ def render(self, space: DataArray, xp: NumpyAPI | None = None) -> DataArray:

@model_validator(mode="before")
def _vmodel(cls, value: Any) -> Any:
if isinstance(value, Distribution):
if isinstance(value, (MatsLines | FixedArrayTruth)):
return {"distribution": value}
if isinstance(value, dict):
if "distribution" not in value and "type" in value:
return {"distribution": value}
return value

@classmethod
def from_array(cls, array: ArrayProtocol) -> "FluorophoreDistribution":
return cls(distribution=FixedArrayTruth(array=array))


class Sample(SimBaseModel):
labels: list[FluorophoreDistribution]
39 changes: 35 additions & 4 deletions src/microsim/schema/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,34 @@
from typing import TYPE_CHECKING, Annotated

if TYPE_CHECKING:
from typing import Self
from typing import Self, Unpack

from .backend import NumpyAPI

from pydantic import AfterValidator, Field, model_validator

from microsim._data_array import DataArray
from microsim._data_array import ArrayProtocol, DataArray

from ._base_model import SimBaseModel
from .detectors import Detector
from .lens import ObjectiveLens
from .modality import Modality, Widefield
from .optical_config import FITC, OpticalConfig
from .sample import Sample
from .sample import FluorophoreDistribution, Sample
from .settings import Settings
from .space import Space, _RelativeSpace
from .space import ShapeScaleSpace, Space, _RelativeSpace

if TYPE_CHECKING:
from typing import TypedDict

class SimluationKwargs(TypedDict, total=False):
output_space: Space | dict | None
objective_lens: ObjectiveLens
channels: list[OpticalConfig]
detector: Detector | None
modality: Modality
settings: Settings
output_path: "OutPath" | None


def _check_extensions(path: Path) -> Path:
Expand All @@ -42,6 +54,25 @@ class Simulation(SimBaseModel):
settings: Settings = Field(default_factory=Settings)
output_path: OutPath | None = None

@classmethod
def from_ground_truth(
self,
ground_truth: ArrayProtocol,
scale: tuple[float, ...],
**kwargs: "Unpack[SimluationKwargs]",
) -> "Self":
"""Shortcut to create a simulation directly from a ground truth array.
In this case, we bypass derive the `truth_space` and `sample` objects directly
from a pre-calculated ground truth array. `scale` must also be provided as a
tuple of floats, one for each dimension of the ground truth array.
"""
return self(
truth_space=ShapeScaleSpace(shape=ground_truth.shape, scale=scale),
sample=Sample(labels=[FluorophoreDistribution.from_array(ground_truth)]),
**kwargs,
)

@model_validator(mode="after")
def _resolve_spaces(self) -> "Self":
if isinstance(self.truth_space, _RelativeSpace):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,11 @@ def test_sim_from_json() -> None:
"""

ms.Simulation.model_validate_json(json_string)


def test_simulation_from_ground_truth() -> None:
ground_truth = np.random.rand(64, 128, 128)
scale = (0.04, 0.02, 0.02)
sim = ms.Simulation.from_ground_truth(ground_truth=ground_truth, scale=scale)
assert sim.truth_space.scale == scale
np.testing.assert_array_almost_equal(sim.ground_truth(), ground_truth)

0 comments on commit ed70c3a

Please sign in to comment.