HDN minimum example #396

Open · wants to merge 45 commits into base: main
45 commits
fa5c320  multifile pred, avg_psnr (CatEek, Dec 6, 2024)
9a35653  Merge branch 'main' into splits_prediction_refac (CatEek, Dec 10, 2024)
7675ea7  inference mode lvae (CatEek, Dec 15, 2024)
36def88  inference fix (CatEek, Dec 16, 2024)
ec7f0b2  lvae pred func upd (CatEek, Dec 18, 2024)
78fa078  Merge remote-tracking branch 'origin/main' into splits_prediction_refac (CatEek, Dec 18, 2024)
0eef8c7  reduce data fix (CatEek, Dec 19, 2024)
cbb29a8  hdn init configs (CatEek, Dec 23, 2024)
f514d7f  basic config fixture + test (CatEek, Dec 24, 2024)
1f2b0e9  out channels test (CatEek, Dec 24, 2024)
de7a939  hdn lightning init tests (CatEek, Dec 24, 2024)
4f7de87  hdn trainstep wip (CatEek, Dec 25, 2024)
5b71750  hdn trainloop test (CatEek, Dec 25, 2024)
12b42e4  hdn logvar test (CatEek, Dec 26, 2024)
b9598b0  train/val steps test (CatEek, Dec 26, 2024)
10e1919  tests pass (CatEek, Dec 26, 2024)
147c280  wip (CatEek, Jan 3, 2025)
2bc3525  3d check config (CatEek, Jan 4, 2025)
9be50db  wip (CatEek, Jan 6, 2025)
201eea5  Merge remote-tracking branch 'origin/main' into hdn_config (CatEek, Jan 22, 2025)
9f91543  hdn_configs (CatEek, Jan 22, 2025)
6fe54a0  conf fixes wip (CatEek, Jan 23, 2025)
ff50695  hdn conf wip (CatEek, Jan 23, 2025)
39bd601  hdn conf factory (CatEek, Jan 24, 2025)
5d9d534  hdn conf factory (CatEek, Feb 4, 2025)
0c5572f  wip (CatEek, Feb 5, 2025)
131f9c4  batch unpack fix (CatEek, Feb 5, 2025)
c2559e0  input shape fix (CatEek, Feb 10, 2025)
0fcfbdc  train vae test pass (CatEek, Feb 11, 2025)
4f7aa06  ds patch_transform unpack fix (CatEek, Feb 11, 2025)
77b5b59  wip (CatEek, Feb 11, 2025)
197cbfa  Merge remote-tracking branch 'origin/main' into hdn_config (CatEek, Feb 11, 2025)
51f0e5f  style(pre-commit.ci): auto fixes [...] (pre-commit-ci[bot], Feb 11, 2025)
bdc43b3  config wip (CatEek, Feb 13, 2025)
e69a656  Merge remote-tracking branch 'origin/hdn_config' into hdn_config (CatEek, Feb 13, 2025)
4c8570a  Merge remote-tracking branch 'origin/main' into hdn_config (CatEek, Feb 13, 2025)
814ba84  post-merge fixes (CatEek, Feb 13, 2025)
9e1d95f  style(pre-commit.ci): auto fixes [...] (pre-commit-ci[bot], Feb 13, 2025)
fb570a6  rnd device fix for cpu tests (CatEek, Feb 13, 2025)
e0ede24  ds tests fix (CatEek, Feb 13, 2025)
1b852dd  pred ds tests fix (CatEek, Feb 13, 2025)
4e10d1c  careamist train array vae test (CatEek, Feb 13, 2025)
c93a131  hdn conf test (CatEek, Feb 13, 2025)
491c7d3  Merge remote-tracking branch 'origin/hdn_config' into hdn_config (CatEek, Feb 13, 2025)
88f2188  style(pre-commit.ci): auto fixes [...] (pre-commit-ci[bot], Feb 13, 2025)
12 changes: 11 additions & 1 deletion src/careamics/careamist.py
@@ -13,7 +13,12 @@
)
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
from careamics.config import (
Configuration,
UNetBasedAlgorithm,
VAEBasedAlgorithm,
load_configuration,
)
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
@@ -28,6 +33,7 @@
PredictDataModule,
ProgressBarCallback,
TrainDataModule,
VAEModule,
create_predict_datamodule,
)
from careamics.model_io import export_to_bmz, load_pretrained
@@ -141,6 +147,10 @@ def __init__(
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
)
elif isinstance(self.cfg.algorithm_config, VAEBasedAlgorithm):
self.model = VAEModule(
algorithm_config=self.cfg.algorithm_config,
)
else:
raise NotImplementedError("Architecture not supported.")

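A minimal end-to-end sketch of how this dispatch is expected to be exercised, assuming the new `create_hdn_configuration` factory mirrors the signature of `create_care_configuration` (the exact arguments are not part of this diff) and using a random placeholder array as training data:

```python
import numpy as np

from careamics import CAREamist
from careamics.config import create_hdn_configuration

# Hypothetical call: the argument names below mirror the existing
# create_care_configuration factory and are assumptions, not taken from this diff.
config = create_hdn_configuration(
    experiment_name="hdn_minimal",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],  # the LVAE architecture now requires >= 64 in Y and X
    batch_size=4,
    num_epochs=10,
)

# With this change, CAREamist routes VAE-based algorithm configs (such as HDN)
# to VAEModule instead of raising NotImplementedError.
careamist = CAREamist(source=config)

train_data = np.random.rand(512, 512).astype(np.float32)  # placeholder array
careamist.train(train_source=train_data)
```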
4 changes: 4 additions & 0 deletions src/careamics/config/__init__.py
@@ -12,6 +12,7 @@
"Configuration",
"DataConfig",
"GaussianMixtureNMConfig",
"HDNAlgorithm",
"InferenceConfig",
"LVAELossConfig",
"MultiChannelNMConfig",
@@ -22,6 +23,7 @@
"VAEBasedAlgorithm",
"algorithm_factory",
"create_care_configuration",
"create_hdn_configuration",
"create_n2n_configuration",
"create_n2v_configuration",
"load_configuration",
@@ -30,6 +32,7 @@

from .algorithms import (
CAREAlgorithm,
HDNAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
UNetBasedAlgorithm,
@@ -40,6 +43,7 @@
from .configuration_factories import (
algorithm_factory,
create_care_configuration,
create_hdn_configuration,
create_n2n_configuration,
create_n2v_configuration,
)
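For reference, a small sketch of reaching the new exports through `algorithm_factory`, assuming the factory dispatches on the `"algorithm"` key; the nested `loss` and `model` fields are illustrative and based on the `HDNAlgorithm` definition added later in this PR:

```python
from careamics.config import HDNAlgorithm, algorithm_factory

# Illustrative dictionary: "loss_type", "architecture" and "input_shape" are the
# field names used elsewhere in this PR; everything else relies on defaults.
algo = algorithm_factory(
    {
        "algorithm": "hdn",
        "loss": {"loss_type": "hdn"},
        "model": {"architecture": "LVAE", "input_shape": [64, 64]},
    }
)
assert isinstance(algo, HDNAlgorithm)
```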
2 changes: 2 additions & 0 deletions src/careamics/config/algorithms/__init__.py
@@ -2,13 +2,15 @@

__all__ = [
"CAREAlgorithm",
"HDNAlgorithm",
"N2NAlgorithm",
"N2VAlgorithm",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
]

from .care_algorithm_model import CAREAlgorithm
from .hdn_algorithm_model import HDNAlgorithm
from .n2n_algorithm_model import N2NAlgorithm
from .n2v_algorithm_model import N2VAlgorithm
from .unet_algorithm_model import UNetBasedAlgorithm
98 changes: 98 additions & 0 deletions src/careamics/config/algorithms/hdn_algorithm_model.py
@@ -0,0 +1,98 @@
"""HDN algorithm configuration."""

from typing import Literal

from bioimageio.spec.generic.v0_3 import CiteEntry
from pydantic import ConfigDict

from careamics.config.algorithms.vae_algorithm_model import VAEBasedAlgorithm
from careamics.config.architectures import LVAEModel
from careamics.config.loss_model import LVAELossConfig

HDN = "HDN"

HDN_DESCRIPTION = ""
HDN_REF = CiteEntry(
text='Prakash, M., Delbracio, M., Milanfar, P., Jug, F. 2022. "Interpretable '
'Unsupervised Diversity Denoising and Artefact Removal." The International '
"Conference on Learning Representations (ICLR).",
doi="10.1561/2200000056",
)


class HDNAlgorithm(VAEBasedAlgorithm):
"""HDN algorithm configuration."""

model_config = ConfigDict(validate_assignment=True)

algorithm: Literal["hdn"] = "hdn"

loss: LVAELossConfig

model: LVAEModel # TODO add validators

def get_algorithm_friendly_name(self) -> str:
"""
Get the algorithm friendly name.

Returns
-------
str
Friendly name of the algorithm.
"""
return HDN

def get_algorithm_keywords(self) -> list[str]:
"""
Get algorithm keywords.

Returns
-------
list[str]
List of keywords.
"""
return [
"restoration",
"UNet",
"VAE",
"3D" if self.model.is_3D() else "2D",
"CAREamics",
"pytorch",
]

def get_algorithm_references(self) -> str:
"""
Get the algorithm references.

This is used to generate the README of the BioImage Model Zoo export.

Returns
-------
str
Algorithm references.
"""
return HDN_REF.text + " doi: " + HDN_REF.doi

def get_algorithm_citations(self) -> list[CiteEntry]:
"""
Return a list of citation entries of the current algorithm.

This is used to generate the model description for the BioImage Model Zoo.

Returns
-------
List[CiteEntry]
List of citation entries.
"""
return [HDN_REF]

def get_algorithm_description(self) -> str:
"""
Get the algorithm description.

Returns
-------
str
Algorithm description.
"""
return HDN_DESCRIPTION
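A hedged instantiation sketch of the new configuration class; the constructor arguments for `LVAELossConfig` and `LVAEModel` are assumptions based on the field names referenced in other hunks of this PR:

```python
from careamics.config.algorithms import HDNAlgorithm
from careamics.config.architectures import LVAEModel
from careamics.config.loss_model import LVAELossConfig

# "loss_type", "architecture" and "input_shape" appear elsewhere in this PR;
# remaining fields (output_channels, predict_logvar, ...) rely on their defaults,
# which the HDN validators require to stay at 1 and None respectively.
hdn = HDNAlgorithm(
    loss=LVAELossConfig(loss_type="hdn"),
    model=LVAEModel(architecture="LVAE", input_shape=[64, 64]),
)

print(hdn.get_algorithm_friendly_name())  # "HDN"
print(hdn.get_algorithm_keywords())       # ["restoration", "UNet", "VAE", "2D", ...]
```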
25 changes: 24 additions & 1 deletion src/careamics/config/algorithms/vae_algorithm_model.py
@@ -40,7 +40,7 @@ class VAEBasedAlgorithm(BaseModel):
# defined in SupportedAlgorithm
# TODO: Use supported Enum classes for typing?
# - values can still be passed as strings and they will be cast to Enum
algorithm: Literal["musplit", "denoisplit"]
algorithm: Literal["hdn", "musplit", "denoisplit"]

# NOTE: these are all configs (pydantic models)
loss: LVAELossConfig
@@ -64,6 +64,14 @@ def algorithm_cross_validation(self: Self) -> Self:
Self
The validated model.
"""
# hdn
if self.algorithm == SupportedAlgorithm.HDN:
if self.loss.loss_type != SupportedLoss.HDN:
raise ValueError(
f"Algorithm {self.algorithm} only supports loss `hdn`."
)
if self.model.multiscale_count > 1:
raise ValueError("Algorithm `hdn` does not support multiscale models.")
# musplit
if self.algorithm == SupportedAlgorithm.MUSPLIT:
if self.loss.loss_type != SupportedLoss.MUSPLIT:
@@ -108,6 +116,12 @@ def output_channels_validation(self: Self) -> Self:
f"Number of output channels ({self.model.output_channels}) must match "
f"the number of noise models ({len(self.noise_model.noise_models)})."
)

if self.algorithm == SupportedAlgorithm.HDN:
assert self.model.output_channels == 1, (
f"Number of output channels ({self.model.output_channels}) must be 1 "
"for algorithm `hdn`."
)
return self

@model_validator(mode="after")
@@ -127,6 +141,15 @@ def predict_logvar_validation(self: Self) -> Self:
"Gaussian likelihood model `predict_logvar` "
f"({self.gaussian_likelihood.predict_logvar}).",
)
if self.algorithm == SupportedAlgorithm.HDN:
assert (
self.model.predict_logvar is None
), "Model `predict_logvar` must be `None` for algorithm `hdn`."
if self.gaussian_likelihood is not None:
assert self.gaussian_likelihood.predict_logvar is None, (
"Gaussian likelihood model `predict_logvar` must be `None` "
"for algorithm `hdn`."
)
return self

def __str__(self) -> str:
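The effect of the new cross-validation rules can be summarized with a hedged sketch of configurations that should now be rejected (constructor arguments are assumptions, as above; pydantic surfaces both the ValueError and the assertion failures as validation errors):

```python
from careamics.config.algorithms import HDNAlgorithm
from careamics.config.architectures import LVAEModel
from careamics.config.loss_model import LVAELossConfig

# Wrong loss type: algorithm_cross_validation raises
# "Algorithm hdn only supports loss `hdn`."
try:
    HDNAlgorithm(
        loss=LVAELossConfig(loss_type="musplit"),  # assumed to be a valid literal
        model=LVAEModel(architecture="LVAE", input_shape=[64, 64]),
    )
except ValueError as err:
    print(err)

# predict_logvar must stay None for HDN; predict_logvar_validation trips otherwise
# ("pixelwise" is the non-None value used by the LVAE model elsewhere in CAREamics).
try:
    HDNAlgorithm(
        loss=LVAELossConfig(loss_type="hdn"),
        model=LVAEModel(
            architecture="LVAE", input_shape=[64, 64], predict_logvar="pixelwise"
        ),
    )
except ValueError as err:
    print(err)
```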
15 changes: 10 additions & 5 deletions src/careamics/config/architectures/lvae_model.py
@@ -15,11 +15,9 @@ class LVAEModel(ArchitectureModel):
model_config = ConfigDict(validate_assignment=True, validate_default=True)

architecture: Literal["LVAE"]
"""Name of the architecture."""

input_shape: list[int] = Field(default=[64, 64], validate_default=True)
"""Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""

input_shape: list[int] = Field(default=(64, 64), validate_default=True)
"""Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D."""
encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)

# TODO make this per hierarchy step ?
@@ -126,6 +124,13 @@ def validate_input_shape(cls, input_shape: list) -> list:
f"Input shape must be greater than 1 in all dimensions"
f"(got {input_shape})."
)

if any(s < 64 for s in input_shape[-2:]):
raise ValueError(
f"Input shape must be greater or equal to 64 in XY dimensions"
f"(got {input_shape})."
)

return input_shape

@field_validator("encoder_n_filters")
Expand Down Expand Up @@ -255,4 +260,4 @@ def is_3D(self) -> bool:
bool
Whether the model is 3D or not.
"""
return self.conv_dims == 3
return len(self.input_shape) == 3
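A quick sketch of what the tightened `input_shape` validation accepts and rejects, assuming `LVAEModel` can be built standalone with only these two fields:

```python
from careamics.config.architectures import LVAEModel

# 64x64 (or larger) 2D patches pass the new check.
LVAEModel(architecture="LVAE", input_shape=[64, 64])

# Anything below 64 in Y or X is now rejected by validate_input_shape
# (pydantic reports the ValueError raised inside the validator).
try:
    LVAEModel(architecture="LVAE", input_shape=[32, 32])
except ValueError as err:
    print(err)

# Note: with this PR, is_3D() is derived from len(input_shape) == 3 rather than
# from conv_dims, so a (Z, Y, X) input_shape alone marks the model as 3D.
```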
2 changes: 2 additions & 0 deletions src/careamics/config/configuration.py
@@ -12,6 +12,7 @@

from careamics.config.algorithms import (
CAREAlgorithm,
HDNAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
)
@@ -22,6 +23,7 @@
CAREAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
HDNAlgorithm,
]

