Skip to content

Commit

Permalink
fix undefined path + add test of valid yaml config (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaczmarj authored Jan 21, 2023
1 parent ee4e669 commit 1158030
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
26 changes: 26 additions & 0 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,32 @@ def test_invalid_modeldefs(modeldef, tmp_path: Path):
Weights.from_yaml(path)


def test_valid_modeldefs(tmp_path: Path):
from wsinfer._modellib.models import Weights

weights_file = tmp_path / "weights.pt"
modeldef = dict(
version="1.0",
name="foo",
architecture="resnet34",
file=str(weights_file),
num_classes=2,
transform=dict(resize_size=224, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
patch_size_pixels=350,
spacing_um_px=0.25,
class_names=["foo", "bar"],
)
path = tmp_path / "foobar.yaml"
with open(path, "w") as f:
yaml.safe_dump(modeldef, f)

with pytest.raises(FileNotFoundError):
Weights.from_yaml(path)

weights_file.touch()
assert Weights.from_yaml(path)


def test_model_registration(tmp_path: Path):
from wsinfer._modellib import models

Expand Down
6 changes: 3 additions & 3 deletions wsinfer/_modellib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __post_init__(self):
raise ValueError("length of class_names must be equal to num_classes")

@staticmethod
def _validate_input(d) -> None:
def _validate_input(d: dict, config_path: Path) -> None:
"""Raise error if invalid input."""

if not isinstance(d, dict):
Expand Down Expand Up @@ -165,7 +165,7 @@ def _validate_input(d) -> None:
if len(d["class_names"]) != d["num_classes"]:
raise ValueError("mismatch between length of class_names and num_classes.")
if "file" in d.keys():
file = Path(path).parent / d["file"]
file = Path(config_path).parent / d["file"]
file = file.resolve()
if not file.exists():
raise FileNotFoundError(f"'file' not found: {file}")
Expand All @@ -176,7 +176,7 @@ def from_yaml(cls, path):

with open(path) as f:
d = yaml.safe_load(f)
cls._validate_input(d)
cls._validate_input(d, config_path=Path(path))

transform = PatchClassification(
resize_size=d["transform"]["resize_size"],
Expand Down

0 comments on commit 1158030

Please sign in to comment.