From 36949a16ae337ee9c23cdb777218c3ef888e49c7 Mon Sep 17 00:00:00 2001 From: Jeremy Rand Date: Mon, 23 Sep 2024 18:32:09 +0000 Subject: [PATCH] Save Model: support TorchScript Various software requires models in TorchScript format. As one example, this can be used to convert PyTorch models to ncnn models via PNNX, without using ONNX as an intermediary. --- .../packages/chaiNNer_pytorch/pytorch/io/save_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py b/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py index ba4eb3c7f..9e039614d 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py @@ -21,6 +21,7 @@ class WeightFormat(Enum): PTH = "pth" ST = "safetensors" + PT = "pt" @io_group.register( @@ -42,6 +43,7 @@ class WeightFormat(Enum): option_labels={ WeightFormat.PTH: "PyTorch (.pth)", WeightFormat.ST: "SafeTensors (.safetensors)", + WeightFormat.PT: "TorchScript (.pt)", }, ), ], @@ -58,5 +60,13 @@ def save_model_node( torch.save(model.model.state_dict(), full_path) elif weight_format == WeightFormat.ST: save_file(model.model.state_dict(), full_path) + elif weight_format == WeightFormat.PT: + size = 3 + size += model.size_requirements.get_padding(size, size)[0] + dummy_input = torch.rand(1, model.input_channels, size, size) + dummy_input = dummy_input.to(model.device) + + trace = torch.jit.trace(model.model, example_inputs=dummy_input) + trace.save(full_path) else: raise ValueError(f"Unknown weight format: {weight_format}")