diff --git a/tests/test_all.py b/tests/test_all.py index e5c66b7..fa9c101 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1082,3 +1082,27 @@ def test_issue_97(tmp_path: Path, tiff_image: Path): assert result.exit_code == 0 metas = list(results_dir.glob("run_metadata_*.json")) assert len(metas) == 2 + + +def test_issue_125(tmp_path: Path): + from wsinfer.cli.infer import _get_info_for_save + from wsinfer._modellib.models import Weights + from wsinfer._modellib.transforms import PatchClassification + + w = Weights( + name="foo", + architecture="resnet34", + # We are testing whether we can still save if file is a Path instance. + file=Path(__file__), + num_classes=1, + transform=PatchClassification( + resize_size=299, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) + ), + patch_size_pixels=350, + spacing_um_px=0.25, + class_names=["tumor"], + ) + + info = _get_info_for_save(w) + with open(tmp_path / "foo.json", "w") as f: + json.dump(info, f) diff --git a/wsinfer/_modellib/models.py b/wsinfer/_modellib/models.py index 12bc096..76b4e84 100644 --- a/wsinfer/_modellib/models.py +++ b/wsinfer/_modellib/models.py @@ -70,7 +70,7 @@ class Weights: class_names: List[str] url: Optional[str] = None url_file_name: Optional[str] = None - file: Optional[str] = None + file: Optional[Union[str, Path]] = None metadata: Optional[Dict[str, Any]] = None def __post_init__(self): diff --git a/wsinfer/cli/infer.py b/wsinfer/cli/infer.py index 642c1df..0fea2f8 100644 --- a/wsinfer/cli/infer.py +++ b/wsinfer/cli/infer.py @@ -154,6 +154,9 @@ def get_stdout(args) -> str: weights_file = str( Path(torch.hub.get_dir()) / "checkpoints" / weights.url_file_name ) + else: + # Weights file could have been a pathlib.Path object. + weights_file = str(weights_file) return { "model_weights": {