上一次正儿八经写博客是今年 2 月,5 月做了个比赛总结,其余的博客竟然都是刷题和算法,实属无聊。艰难的日子已经过去,准备学点模型部署相关的东西以及参与一个实际的开源项目,争取数据、算法和工程全链路打通。众所周知,对于一个不是很常用的东西,学完就忘,如 spark, Go
等学过的但很少用的东西,已经被我抛到九霄云外了。所以,这次学完模型的 trace
之后,尝试部署一些能实际运行的软件。
基本概念
TorchScript
是 PyTorch
的 JIT
实现。JIT
全程是 Just In Time Compilation,也就是即使编译。在深度学习中 JIT
的思想更是随处可见,最明显的例子就是 Keras
框架的 model.compile 创建的静态图。
- 静态图需要先构建再运行,优势是在运行前可以对图结构进行优化,比如常数折叠、算子融合等,可以获得更快的前向运算速度。缺点也很明显,就是只有在计算图运行起来之后,才能看到变量的值,像
TensorFlow1.x
中的session.run
那样。 - 动态图是一边运行一边构建,优势是可以在搭建网络的时候看见变量的值,便于检查。缺点是前向运算不好优化,因为根本不知道下一步运算要算什么。动态图模型通过牺牲一些高级特性来换取易用性。
那么那到底 JIT
有哪些特性,使得 torch
这样的动态图框架也要走 JIT
这条路呢?或者说在什么情况下不得不用到 JIT
呢?下面主要通过介绍 TorchScript
来分析 JIT
到底带来了哪些好处。
JIT
是 Python
和 C++
的桥梁,我们可以使用 Python
训练模型,然后通过 JIT
将模型转为语言无关的模块,从而让 C++
可以非常方便得调用,从此「使用 Python
训练模型,使用 C++
将模型部署到生产环境」对 PyTorch
来说成为了一件很容易的事。而因为使用了 C++
,我们现在几乎可以把 PyTorch
模型部署到任意平台和设备上:树莓派、iOS、Android 等等。不然每次都要通过 python
调用模型,性能会大打折扣。
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module
)转换为 TorchScript Module
,再进行推断。有两种方式可以转换:
- 使用
TorchScript Module
的更简单的办法是使用Tracing
,Tracing
可以直接将PyTorch
模型(torch.nn.Module
)转换成TorchScript Module
。「trace
」顾名思义,就是需要提供一个「输入」来让模型forward
一遍,以通过该输入的流转路径,获得图的结构。这种方式对于forward
逻辑简单的模型来说非常实用,但如果forward
里面本身夹杂了很多流程控制语句,就会存在问题,因为同一个输入不可能遍历到所有的逻辑分枝。而没有被经过的分支就不会被trace
。 - 可以直接使用
TorchScript Language
来定义一个PyTorch JIT Module
,然后用torch.jit.script
来将他转换成TorchScript Module
并保存成文件。而TorchScript Language
本身也是Python
代码,所以可以直接写在Python
文件中。对于TensorFlow
我们知道不能直接使用Python
中的if
等语句来做条件控制,而是需要用tf.cond
,但对于TorchScript
我们依然能够直接使用if
和for
等条件控制语句,所以即使是在静态图上,PyTorch
依然秉承了「易用」的特性。
简单例子
trace 方法
首先定义一个简单的模型:
1 | import torch |
我们可以绑定输入对模型进行 trace
:
1 | import torch |
可以看到没有出现 if-else
的分支, trace
做的是:运行代码,记录出现的运算,构建 ScriptModule
,但是控制流就丢失了。然后流程丢失并不是好事,在 trace
只会对一个输入进行处理的情况下,对不同的输入得到的结果是不一样的,因为输入只会满足一个分支,因此 trace
的程序也只包含一个分支。
1 | import torch |
因此,我们认为这样的 trace
没有泛化能力。而这种现象普遍发生在动态控制流中,即:具体执行哪个算子取决于输入的数据。
if x[0] == 4: x += 1
是动态控制流model: nn.Sequential = ... [m(x) for x in model]
不是- 不是
1
2
3
4
5
6
7
8class A(nn.Module):
backbone: nn.Module
head: Optiona[nn.Module]
def forward(self, x):
x = self.backbone(x)
if self.head is not None:
x = self.head(x)
return x
在之后的文章中,会介绍如何使 trace
具备泛化能力。
script 方法
script
方法直接分析 python
代码进行转换:使用他们提供的 script
编译器,将 python
的代码进行语法分析,并重新解释为 TorchScript
。
1 | import torch |
TorchScript
代码可以被它自己的解释器(一个受限的Python
解释器)调用。这个解释器不需要获得全局解释锁GIL,这样很多请求可以同时处理。- 这个格式可以让我们保存模型到硬盘上,在另一个环境中加载,例如服务器,也可以使用非
python
的语言。 TorchScript
提供的表示可以做编译器优化,做到更有效地执行。TorchScript
可以与其他后端/设备运行时进行对接,他们只需要处理整个项目,无需关心细节运算。
Trace 和 Script 谁更好?
通过上文我们可以了解到:
trace
只记录走过的tensor
和对tensor
的操作,不会记录任何控制流信息,如if
条件句和循环。因为没有记录控制流的另外的路,也没办法对其进行优化。好处是trace
深度嵌入python
语言,复用了所有python
的语法,在计算流中记录数据流。script
会去理解所有的code
,真正像一个编译器一样去进行词法分析语法分析句法分析,形成AST
树,最后再将AST
树线性化。script
相当于一个嵌入在Python/Pytorch
的DSL
,其语法只是Pytorch
语法的子集,这意味着存在一些op
和语法script
不支持,这样在编译的时候就会遇到问题。此外,script
的编译优化方式更像是CPU
上的传统编译优化,重点对于图进行硬件无关优化,并对if
、loop
进行优化。
在大模型的部署上 trace
更好,因为可以有效的优化复杂的计算图,如下所示:
1 | class A(nn.Module): |
因为 script
试图忠实地表示 Python
代码,所以即使其中一些是不必要的。例如:并不能对 Python
代码中的某些循环或数据结构进行优化。如上例,所以它实际上有变通方法,或者循环可能会在以后的优化过程中得到优化。但关键是:这个编译器并不总是足够聪明。对于复杂的模型, script
可能会生成一个具有不必要复杂性且难以优化的计算图。
tracing
有许多优点,事实上,在 Facebook/Meta
部署的分割和检测模型中,tracing
是默认的选择,仅当必要的时候使用 scripting
。因为 trace
不会破坏代码质量,可以结合 script
来避免一些限制。
python
是一个很大很动态的语言,编译器最多只能支持其语法功能和内置函数的子集,同理,script
也不例外。这个编译器支持 Python
的哪个子集?一个粗略的答案是:编译器对最基本的语法有很好的支持,但对任何更复杂的东西(类、内置函数、动态类型等)的支持度很低或者不支持。但并没有明确的答案:即使是编译器的开发者,通常也需要运行代码,看看能不能编译去判断是否支持。
所以不完整的 Python
编译器限制了用户编写代码的方式。尽管没有明确的约束列表,但可以从经验中看出它们对大型项目的影响:script
的问题会影响代码质量。很多项目只停留在了代码能 script
成功这一层面,使用基础语法,没有自定义类型,没有继承,没有内置函数,没有 lambda
等等的高级特性。因为这些高级的功能编译器并不支持或者部分支持,就会导致在某些情况下成功,但在其他情况下失败。而且由于没有明确的规范哪些是被支持的,因此用户无法推理或解决故障。因此,最终用户会仅仅停留在代码成功搬移,而不考虑可维护性和性能问题,会导致开发者因为害怕报错而停止进一步的探索高级特性。
如此下去,代码质量可能会严重恶化:垃圾代码开始积累,因为优良的代码有时无法编译。此外,由于编译器的语法限制,无法轻松进行抽象以清理垃圾代码。该项目的可维护状况逐渐走下坡路。如果认为 script
似乎适用于我的项目,基于过去在一些支持 script
的项目中的经验,我可能会出于以下原因建议不要这样做:
- 编译成功可能比你想象的更脆弱(除非将自己限制在基本语法上):你的代码现在可能恰好可以编译,但是有一天你会在模型中添加一些更改,并发现编译失败;
- 基本语法是不够的:即使目前你的项目似乎不需要更复杂的抽象和继承,但如果预计项目会增长,未来将需要更多的语言特性。
以多任务检测器为例:
- 可能有 10 个输入,因此最好使用一些结构/类。
- 检测器有许多架构选择,这使得继承很有用。
- 大型、不断增长的项目肯定需要不断发展的抽象来保持可维护性。
因此,这个问题的现状是:script
迫使你编写垃圾的代码,因此我们仅在必要时使用它。
Trace 细节
trace
让模型的 trace
更清楚,对代码质量有很少的影响。
如果模型不是以 Pytorch
格式表示的计算图,则 script
和 trace
都不起作用。例如,如果模型具有 DataParallel
子模块,或者如果模型将张量转换为 numpy
数组并调用 OpenCV
函数等,则必须对其进行重构。除了这个明显的限制之外,对 trace
只有两个额外的要求:
输入/输出格式是
Tensor
类型时才能被trace
。但是,这里的格式约束不适用于子模块:子模块可以使用任何输入/输出格式:类、kwargs
以及Python
支持的任何内容。格式要求仅适用于最外层的模型,因此很容易解决。如果模型使用更丰富的格式,只需围绕它创建一个简单的包装器,它可以与Tuple[Tensor]
相互转换。shape
。tensor.size(0)
是eager
模式下的整数,但它是tracing mode
下的tensor
。这个差异在trace
时是必要的,shape
的计算可以被捕获为计算图中的算子。由于不同的返回类型,如果返回的一部分是shape
是整数则无法trace
,这通常可以简单的解决。此外,一个有用的函数是torch.jit.is_tracing
,它检查代码是否在trace
模式下执行。
我们来看个例子:
1 | 1), torch.rand(2) a, b = torch.rand( |
在 trace f2
函数时,lex(x)
是一个定值而非 tensor
,这样在传入其他长度的数据时就回报错。除了 len()
,这个问题也可能出现在:
.item()
将张量转换为int/float
。- 将
Torch
类型转换为numpy/python
原语的任何其他代码。
tensor.size()
在 trace
期间返回 Tensor
,以便在图中捕获形状计算。用户应避免意外将张量形状转换为常量。使用 tensor.size(0)
而不是 len(tensor)
,因为后者是一个 int
。这个函数对于将大小转换为张量很有用,在 trace
和 eager
模式下都可以使用。对于自定义类,实现 .size()
方法或使用 .__len__()
而不是 len()
,不要通过 int()
转换大小,因为它们会捕获常量。
这就是 trace
所需要的一切。最重要的是,模型实现中允许使用任何 Python
语法,因为 trace
根本不关心语法。
Trace 的泛化问题
Trace 和 Script 混合
1 | def f(x): |
注意这种代码在 trace
时不会报错,只有 warning
的输出,因此我们要特别关注。trace
和 script
都有各自的问题,最好的方法是混合使用他们。避免影响代码质量,主要的部分进行 trace
,必要时进行 script
。如果有一个 module
里面有很多选择,但是我们不希望在 TorchScript
里出现,那么应该使用 tracing
而不是 scripting
,这个时候,trace
将内联 script
模块的代码。
1 | import torch |
我们简化一下:
1 | model.submodule = torch.jit.script(model.submodule) |
对于不能正确 trace
的子模块,可以进行 script
处理。但是并不推荐,更建议使用 @script_if_tracing
,因为这样修改 script
仅限于子模块的内部,而不影响模块的接口。使用 @script_if_tracing
装饰器,在 torch.jit.trace
时,@script_if_tracing
装饰器可以通过 script
编译。通常,这只需要对前向逻辑进行少量重构,以分离需要编译的部分(具有控制流的部分):
1 | def forward(self, ...): |
只 script
需要的部分,代码质量相对于全部 script
被破坏的很少,被 @script_if_tracing
装饰的函数必须是不包含 tensor
模块运算的纯函数。因此,有时需要进行更多重构:
1 | # Before: |
同样的,我们可以在 script
中嵌套 trace
:
1 | model.submodule = torch.jit.trace(model.submodule, submodule_inputs) |
这里的子模块是 trace
,但是实际中并不常用,因为会影响子模块的推理(当且仅当子模块的输入和输出都是 tensor
时才适用),这是很大的限制。但是 trace
作为子模块的时候也有很试用的场景:
1 | class A(nn.Module): |
@script_if_tracing
不能处理这样的控制流,因为它只支持纯函数。如果子模块很复杂不能被 script
,使用 trace
trace
子模块是很好的选择,这里就是 self.submodule2
和 self.submodule1
,类 A
还是要 script
的。
Script 优势
事实上,对于大多数视觉模型,动态控制流仅在少数易于编写 script
的子模块中需要。script
相对于 trace
,有两个有点:
- 一个数据有很多属性的控制流,
trace
无法处理 trace
只支持forward
方法,script
支持更多的方法
实际上,上述两个功能都在做同样的事情:它们允许以不同的方式使用导出的模型,即根据调用者的请求执行不同的运算符序列。下面是一个这样的特性很有用的示例场景:如果 Detector
是 script
化,调用者可以改变它的 do_keypoint
属性来控制它的行为,或者如果需要直接调用 predict_keypoint
方法。
1 | class Detector(nn.Module): |
这种要求并不常见。但是如果需要,如何在 trace
中实现这一点?我有一个不是很优雅的解决方案:Tracing
只能捕获一个序列的算子,所以自然的方式是对模型进行两次 Tracing
:
1 | det1 = torch.jit.trace(Detector(do_keypoint=True), inputs) |
然后我们可以为它们的模型设置别名(以不重复存储),并将两个 trace
合并到一个模块中以编写 script
:
1 | det2.submodule.weight = det1.submodule.weight |
单元测试
还可以使用单元测试来判断 trace
是否成功:
1 | assert allclose(torch.jit.trace(model, input1)(input2), model(input2)) |
程序优化
此外,还可以通过优化程序,避免掉不必要的特殊情况:
1 | if x.numel() > 0: |
设备
此外还需要注意设备问题,在 trace
期间会记录使用的设备,而 trace
不会对不同的设备进行泛化,但是部署时都会有固定的设备,这个问题不用担心。
1 | def f(x): |
结论
trace
有明显的局限性:本文大部分时间都在讨论 trace
的局限性以及如何解决它们。我实际上认为这是 trace
的优点:它有明确的限制和解决方案,所以你可以推断它是否有效。相反, script
更像是一个黑匣子:在尝试之前没有人知道它是否有效。
trace
具有较小的代码破坏范围: trace
和 script
都会影响代码的编写方式,但 trace
的代码破坏范围要小得多,并且造成的损害要小得多:
- 它限制了输入/输出格式,但仅限于最外层的模块。
- 在
trace
中混合script
,但可以只更改受影响模块的内部实现,而不是它们的接口。
另一方面, script
对以下方面有影响:
- 涉及的每个模块和子模块的接口,接口需要高级语法特性,针对接口编程时,千万别在接口设计上妥协。
- 这也可能最终影响训练,因为接口通常在训练和推理之间共享。
这也是为什么 script
会对代码质量造成很大损害的原因。Detectron2
支持 script
,但不推荐其他大型项目以可 script
且不丢失抽象为目标,因为这实在有点难度,除非它们也能像阿里巴巴那样得到 PyTorch
团队的支持。
PyTorch
深受用户喜爱,最重要的是编写 Python
控制流。但是 Python
的其他语法也很重要。如果能够编写 Python
控制流( 使用 script
)意味着失去其他优秀的语法,我宁愿放弃编写 Python
控制流的能力。事实上,如果 PyTorch
对 Python
控制流不那么执着,并且像这样(类似于 tf.cond
的 API
)为我提供了诸如 torch.cond
之类的符号控制流:
1 | def f(x): |
然后 f
可以正确 trace
,不再需要担心 script
。
保存和加载模型
1 | traced.save('wrapped_rnn.pt') |