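TorchScript is an intermediate representation of a PyTorch model (a subclass of nn.Module) that can be run in high-performance environments such as C++.
There are two ways to convert a model: by tracing and by annotation (scripting).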
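1. Conversion via tracing
Tracing is the more commonly used approach, but it has a drawback. It records only the operations executed for one example input, so data-dependent control flow and any Python logic that depends on the input shape are frozen to that input. In practice this can effectively fix the input size.
The tracing example from the official documentation: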
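The root cause is that tracing runs the model once on the example input and keeps only the operations that actually executed, so a data-dependent `if` is frozen into whichever branch the example took. The minimal sketch below illustrates this with a hypothetical function `branchy` (not from the official docs), and compares it with `torch.jit.script`, which compiles both branches.

```python
import warnings

import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical function with data-dependent control flow.
    if x.sum() > 0:
        return x * 2
    return x + 10


pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing with a positive example records only the `x * 2` branch.
# PyTorch emits a TracerWarning about converting a tensor to a Python bool;
# it is silenced here to keep the output short.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(branchy, pos)

# Scripting compiles the whole function, including both branches.
scripted = torch.jit.script(branchy)

print(branchy(neg))   # eager Python takes the `x + 10` branch: tensor([9., 9., 9.])
print(traced(neg))    # trace still runs the recorded `x * 2`: tensor([-2., -2., -2.])
print(scripted(neg))  # script takes the correct branch: tensor([9., 9., 9.])
```

The TracerWarning that is silenced here is worth leaving on in real code: it is the tracer's signal that the recorded graph may not generalize to other inputs.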
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
The traced ScriptModule can now be evaluated identically to a regular PyTorch module:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
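One way to check that claim is to run the eager model and the traced module on the same fresh input and compare. The sketch below uses a small hypothetical `nn.Sequential` as a stand-in for ResNet-18 so that it runs without torchvision; `.code` prints the TorchScript that the tracer recorded.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for resnet18: a tiny conv net with the same input layout.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
model.eval()  # use BatchNorm running statistics, as at inference time

example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)

x = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    eager_out = model(x)
    traced_out = traced(x)

print(isinstance(traced, torch.jit.ScriptModule))  # a traced module is a ScriptModule
print(torch.allclose(eager_out, traced_out))       # same numbers as the eager model
print(traced.code)                                  # TorchScript recorded by the tracer
```

Calling `model.eval()` before tracing matters for layers such as BatchNorm and Dropout, whose behavior differs between training and inference.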
Most examples found online also use tracing.
2. Conversion via annotation (scripting)
Example from the official documentation:
Converting to TorchScript via Annotation
Under certain circumstances, such as if your model employs particular forms of control flow, you may want to write your model in TorchScript directly and annotate your model accordingly. For example, say you have the following vanilla PyTorch model:
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output
Because the forward method of this module uses control flow that is dependent on the input, it is not suitable for tracing. Instead, we can convert it to a ScriptModule. In order to convert the module to a ScriptModule, one needs to compile the module with torch.jit.script as follows:
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
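To see that the compiled module really keeps both branches, a minimal sketch can call `sm` with one input whose sum is positive and one whose sum is negative. With `MyModule(10, 20)` the weight is 10×20, so the first branch (matrix-vector product) returns shape (10,) and the second (broadcast addition) returns shape (10, 20).

```python
import torch


class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output


my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)

pos = torch.ones(20)   # sum > 0: matrix-vector product
neg = -torch.ones(20)  # sum <= 0: broadcast addition

print(sm(pos).shape)  # torch.Size([10])
print(sm(neg).shape)  # torch.Size([10, 20])
print(sm.code)        # the compiled TorchScript still contains the if/else
```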
If you need to exclude some methods in your nn.Module because they use Python features that TorchScript doesn't support yet, you could annotate those with @torch.jit.ignore.
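A minimal sketch of `@torch.jit.ignore`, using a hypothetical module `Doubler` whose helper `top_values` uses a lambda, which the TorchScript compiler does not support. Because the helper is ignored, the compiler only reads its type annotations, and the call from `forward` is dispatched back to the Python interpreter at run time.

```python
import torch


class Doubler(torch.nn.Module):
    # Hypothetical module used only to illustrate @torch.jit.ignore.

    @torch.jit.ignore
    def top_values(self, x: torch.Tensor) -> str:
        # Lambdas are not supported by the TorchScript compiler, but this body
        # is never compiled: @torch.jit.ignore leaves it as a Python function.
        values = sorted(x.tolist(), key=lambda v: -v)
        return ",".join(str(v) for v in values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The compiler uses only the signature of top_values; at run time
        # this call goes back to the Python interpreter.
        self.top_values(x)
        return x * 2


scripted = torch.jit.script(Doubler())
x = torch.tensor([1.0, 3.0, 2.0])
print(scripted(x))              # tensor([2., 6., 4.])
print(Doubler().top_values(x))  # 3.0,2.0,1.0
```

Note that the PyTorch documentation states that models with ignored functions cannot be exported, and points to `@torch.jit.unused` (which replaces the method with one that raises if called) when the module must still be saved.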
sm is an instance of ScriptModule that is ready for serialization.
3. Serializing Your Script Module to a File
Once you have a ScriptModule in your hands, either from tracing or annotating a PyTorch model, you are ready to serialize it to a file. Later on, you'll be able to load the module from this file in C++ and execute it without any dependency on Python. Say we want to serialize the ResNet18 model shown earlier in the tracing example. To perform this serialization, simply call save on the module and pass it a filename:
traced_script_module.save("traced_resnet_model.pt")
This will produce a traced_resnet_model.pt file in your working directory. If you would also like to serialize sm, call sm.save("my_module_model.pt").
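Before moving to C++, the saved file can be sanity-checked from Python. The sketch below saves the scripted `MyModule`, loads it back with `torch.jit.load` (roughly the Python counterpart of the C++ loader), and checks that both branches of the if/else survive the round trip. It writes to a temporary directory instead of the working directory; that choice is an assumption of this example.

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output


sm = torch.jit.script(MyModule(10, 20))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my_module_model.pt")
    sm.save(path)                  # equivalent to torch.jit.save(sm, path)
    loaded = torch.jit.load(path)  # the archive carries its own code and weights

pos = torch.ones(20)
neg = -torch.ones(20)
print(torch.allclose(loaded(pos), sm(pos)))  # True
print(torch.allclose(loaded(neg), sm(neg)))  # True
```

Because the archive contains the compiled code as well as the parameters, loading it does not require the `MyModule` class definition, which is what makes the C++ side possible.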
We have now officially left the realm of Python and are ready to cross over to the sphere of C++.