diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py b/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py index ba4eb3c7f..9e039614d 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py @@ -21,6 +21,7 @@ class WeightFormat(Enum): PTH = "pth" ST = "safetensors" + PT = "pt" @io_group.register( @@ -42,6 +43,7 @@ class WeightFormat(Enum): option_labels={ WeightFormat.PTH: "PyTorch (.pth)", WeightFormat.ST: "SafeTensors (.safetensors)", + WeightFormat.PT: "TorchScript (.pt)", }, ), ], @@ -58,5 +60,13 @@ def save_model_node( torch.save(model.model.state_dict(), full_path) elif weight_format == WeightFormat.ST: save_file(model.model.state_dict(), full_path) + elif weight_format == WeightFormat.PT: + size = 3 + size += model.size_requirements.get_padding(size, size)[0] + dummy_input = torch.rand(1, model.input_channels, size, size) + dummy_input = dummy_input.to(model.device) + + trace = torch.jit.trace(model.model, example_inputs=dummy_input) + trace.save(full_path) else: raise ValueError(f"Unknown weight format: {weight_format}")