Skip to content

Commit

Permalink
Save Model: support TorchScript
Browse files Browse the repository at this point in the history
Various software requires models in TorchScript format. As one example,
this can be used to convert PyTorch models to ncnn models via PNNX,
without using ONNX as an intermediary.
  • Loading branch information
Jeremy Rand committed Sep 23, 2024
1 parent 89efb3f commit 36949a1
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
class WeightFormat(Enum):
PTH = "pth"
ST = "safetensors"
PT = "pt"


@io_group.register(
Expand All @@ -42,6 +43,7 @@ class WeightFormat(Enum):
option_labels={
WeightFormat.PTH: "PyTorch (.pth)",
WeightFormat.ST: "SafeTensors (.safetensors)",
WeightFormat.PT: "TorchScript (.pt)",
},
),
],
Expand All @@ -58,5 +60,13 @@ def save_model_node(
torch.save(model.model.state_dict(), full_path)
elif weight_format == WeightFormat.ST:
save_file(model.model.state_dict(), full_path)
elif weight_format == WeightFormat.PT:
size = 3
size += model.size_requirements.get_padding(size, size)[0]
dummy_input = torch.rand(1, model.input_channels, size, size)
dummy_input = dummy_input.to(model.device)

trace = torch.jit.trace(model.model, example_inputs=dummy_input)
trace.save(full_path)
else:
raise ValueError(f"Unknown weight format: {weight_format}")

0 comments on commit 36949a1

Please sign in to comment.