上一次正儿八经写博客是今年 2 月,5 月做了个比赛总结,其余的博客竟然都是刷题和算法,实属无聊。艰难的日子已经过去,准备学点模型部署相关的东西以及参与一个实际的开源项目,争取数据、算法和工程全链路打通。众所周知,对于一个不是很常用的东西,学完就忘,如 spark, Go 等学过的但很少用的东西,已经被我抛到九霄云外了。所以,这次学完模型的 trace 之后,尝试部署一些能实际运行的软件。
基本概念
TorchScript 是 PyTorch 的 JIT 实现。JIT 全程是 Just In Time Compilation,也就是即使编译。在深度学习中 JIT 的思想更是随处可见,最明显的例子就是 Keras 框架的 model.compile 创建的静态图。
- 静态图需要先构建再运行,优势是在运行前可以对图结构进行优化,比如常数折叠、算子融合等,可以获得更快的前向运算速度。缺点也很明显,就是只有在计算图运行起来之后,才能看到变量的值,像
TensorFlow1.x中的session.run那样。 - 动态图是一边运行一边构建,优势是可以在搭建网络的时候看见变量的值,便于检查。缺点是前向运算不好优化,因为根本不知道下一步运算要算什么。动态图模型通过牺牲一些高级特性来换取易用性。
那么那到底 JIT 有哪些特性,使得 torch 这样的动态图框架也要走 JIT 这条路呢?或者说在什么情况下不得不用到 JIT 呢?下面主要通过介绍 TorchScript 来分析 JIT 到底带来了哪些好处。
JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等。不然每次都要通过 python 调用模型,性能会大打折扣。
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module)转换为 TorchScript Module,再进行推断。有两种方式可以转换:
- 使用
TorchScript Module的更简单的办法是使用Tracing,Tracing可以直接将PyTorch模型(torch.nn.Module)转换成TorchScript Module。「trace」顾名思义,就是需要提供一个「输入」来让模型forward一遍,以通过该输入的流转路径,获得图的结构。这种方式对于forward逻辑简单的模型来说非常实用,但如果forward里面本身夹杂了很多流程控制语句,就会存在问题,因为同一个输入不可能遍历到所有的逻辑分枝。而没有被经过的分支就不会被trace。 - 可以直接使用
TorchScript Language来定义一个PyTorch JIT Module,然后用torch.jit.script来将他转换成TorchScript Module并保存成文件。而TorchScript Language本身也是Python代码,所以可以直接写在Python文件中。对于TensorFlow我们知道不能直接使用Python中的if等语句来做条件控制,而是需要用tf.cond,但对于TorchScript我们依然能够直接使用if和for等条件控制语句,所以即使是在静态图上,PyTorch依然秉承了「易用」的特性。
简单例子
trace 方法
首先定义一个简单的模型:
1 | import torch |
我们可以绑定输入对模型进行 trace:
1 | import torch |
可以看到没有出现 if-else 的分支, trace 做的是:运行代码,记录出现的运算,构建 ScriptModule,但是控制流就丢失了。然后流程丢失并不是好事,在 trace 只会对一个输入进行处理的情况下,对不同的输入得到的结果是不一样的,因为输入只会满足一个分支,因此 trace 的程序也只包含一个分支。
1 | import torch |
因此,我们认为这样的 trace 没有泛化能力。而这种现象普遍发生在动态控制流中,即:具体执行哪个算子取决于输入的数据。
if x[0] == 4: x += 1是动态控制流model: nn.Sequential = ... [m(x) for x in model]不是- 不是
1
2
3
4
5
6
7
8class A(nn.Module):
backbone: nn.Module
head: Optiona[nn.Module]
def forward(self, x):
x = self.backbone(x)
if self.head is not None:
x = self.head(x)
return x
在之后的文章中,会介绍如何使 trace 具备泛化能力。
script 方法
script 方法直接分析 python 代码进行转换:使用他们提供的 script 编译器,将 python 的代码进行语法分析,并重新解释为 TorchScript。
1 | import torch |
TorchScript代码可以被它自己的解释器(一个受限的Python解释器)调用。这个解释器不需要获得全局解释锁GIL,这样很多请求可以同时处理。- 这个格式可以让我们保存模型到硬盘上,在另一个环境中加载,例如服务器,也可以使用非
python的语言。 TorchScript提供的表示可以做编译器优化,做到更有效地执行。TorchScript可以与其他后端/设备运行时进行对接,他们只需要处理整个项目,无需关心细节运算。
Trace 和 Script 谁更好?
通过上文我们可以了解到:
trace只记录走过的tensor和对tensor的操作,不会记录任何控制流信息,如if条件句和循环。因为没有记录控制流的另外的路,也没办法对其进行优化。好处是trace深度嵌入python语言,复用了所有python的语法,在计算流中记录数据流。script会去理解所有的code,真正像一个编译器一样去进行词法分析语法分析句法分析,形成AST树,最后再将AST树线性化。script相当于一个嵌入在Python/Pytorch的DSL,其语法只是Pytorch语法的子集,这意味着存在一些op和语法script不支持,这样在编译的时候就会遇到问题。此外,script的编译优化方式更像是CPU上的传统编译优化,重点对于图进行硬件无关优化,并对if、loop进行优化。
在大模型的部署上 trace 更好,因为可以有效的优化复杂的计算图,如下所示:
1 | class A(nn.Module): |
因为 script 试图忠实地表示 Python 代码,所以即使其中一些是不必要的。例如:并不能对 Python 代码中的某些循环或数据结构进行优化。如上例,所以它实际上有变通方法,或者循环可能会在以后的优化过程中得到优化。但关键是:这个编译器并不总是足够聪明。对于复杂的模型, script 可能会生成一个具有不必要复杂性且难以优化的计算图。
tracing 有许多优点,事实上,在 Facebook/Meta 部署的分割和检测模型中,tracing 是默认的选择,仅当必要的时候使用 scripting。因为 trace 不会破坏代码质量,可以结合 script 来避免一些限制。
python 是一个很大很动态的语言,编译器最多只能支持其语法功能和内置函数的子集,同理,script 也不例外。这个编译器支持 Python 的哪个子集?一个粗略的答案是:编译器对最基本的语法有很好的支持,但对任何更复杂的东西(类、内置函数、动态类型等)的支持度很低或者不支持。但并没有明确的答案:即使是编译器的开发者,通常也需要运行代码,看看能不能编译去判断是否支持。
所以不完整的 Python 编译器限制了用户编写代码的方式。尽管没有明确的约束列表,但可以从经验中看出它们对大型项目的影响:script 的问题会影响代码质量。很多项目只停留在了代码能 script 成功这一层面,使用基础语法,没有自定义类型,没有继承,没有内置函数,没有 lambda 等等的高级特性。因为这些高级的功能编译器并不支持或者部分支持,就会导致在某些情况下成功,但在其他情况下失败。而且由于没有明确的规范哪些是被支持的,因此用户无法推理或解决故障。因此,最终用户会仅仅停留在代码成功搬移,而不考虑可维护性和性能问题,会导致开发者因为害怕报错而停止进一步的探索高级特性。
如此下去,代码质量可能会严重恶化:垃圾代码开始积累,因为优良的代码有时无法编译。此外,由于编译器的语法限制,无法轻松进行抽象以清理垃圾代码。该项目的可维护状况逐渐走下坡路。如果认为 script 似乎适用于我的项目,基于过去在一些支持 script 的项目中的经验,我可能会出于以下原因建议不要这样做:
- 编译成功可能比你想象的更脆弱(除非将自己限制在基本语法上):你的代码现在可能恰好可以编译,但是有一天你会在模型中添加一些更改,并发现编译失败;
- 基本语法是不够的:即使目前你的项目似乎不需要更复杂的抽象和继承,但如果预计项目会增长,未来将需要更多的语言特性。
以多任务检测器为例:
- 可能有 10 个输入,因此最好使用一些结构/类。
- 检测器有许多架构选择,这使得继承很有用。
- 大型、不断增长的项目肯定需要不断发展的抽象来保持可维护性。
因此,这个问题的现状是:script 迫使你编写垃圾的代码,因此我们仅在必要时使用它。
Trace 细节
trace 让模型的 trace 更清楚,对代码质量有很少的影响。
如果模型不是以 Pytorch 格式表示的计算图,则 script 和 trace 都不起作用。例如,如果模型具有 DataParallel 子模块,或者如果模型将张量转换为 numpy 数组并调用 OpenCV 函数等,则必须对其进行重构。除了这个明显的限制之外,对 trace 只有两个额外的要求:
输入/输出格式是
Tensor类型时才能被trace。但是,这里的格式约束不适用于子模块:子模块可以使用任何输入/输出格式:类、kwargs以及Python支持的任何内容。格式要求仅适用于最外层的模型,因此很容易解决。如果模型使用更丰富的格式,只需围绕它创建一个简单的包装器,它可以与Tuple[Tensor]相互转换。shape。tensor.size(0)是eager模式下的整数,但它是tracing mode下的tensor。这个差异在trace时是必要的,shape的计算可以被捕获为计算图中的算子。由于不同的返回类型,如果返回的一部分是shape是整数则无法trace,这通常可以简单的解决。此外,一个有用的函数是torch.jit.is_tracing,它检查代码是否在trace模式下执行。
我们来看个例子:
1 | a, b = torch.rand(1), torch.rand(2) |
在 trace f2 函数时,lex(x) 是一个定值而非 tensor,这样在传入其他长度的数据时就回报错。除了 len(),这个问题也可能出现在:
.item()将张量转换为int/float。- 将
Torch类型转换为numpy/python原语的任何其他代码。
tensor.size() 在 trace 期间返回 Tensor,以便在图中捕获形状计算。用户应避免意外将张量形状转换为常量。使用 tensor.size(0) 而不是 len(tensor),因为后者是一个 int。这个函数对于将大小转换为张量很有用,在 trace 和 eager 模式下都可以使用。对于自定义类,实现 .size() 方法或使用 .__len__() 而不是 len(),不要通过 int() 转换大小,因为它们会捕获常量。
这就是 trace 所需要的一切。最重要的是,模型实现中允许使用任何 Python 语法,因为 trace 根本不关心语法。
Trace 的泛化问题
Trace 和 Script 混合
1 | def f(x): |
注意这种代码在 trace 时不会报错,只有 warning 的输出,因此我们要特别关注。trace 和 script 都有各自的问题,最好的方法是混合使用他们。避免影响代码质量,主要的部分进行 trace,必要时进行 script。如果有一个 module 里面有很多选择,但是我们不希望在 TorchScript 里出现,那么应该使用 tracing 而不是 scripting,这个时候,trace 将内联 script 模块的代码。
1 | import torch |
我们简化一下:
1 | model.submodule = torch.jit.script(model.submodule) |
对于不能正确 trace 的子模块,可以进行 script 处理。但是并不推荐,更建议使用 @script_if_tracing,因为这样修改 script 仅限于子模块的内部,而不影响模块的接口。使用 @script_if_tracing 装饰器,在 torch.jit.trace 时,@script_if_tracing 装饰器可以通过 script 编译。通常,这只需要对前向逻辑进行少量重构,以分离需要编译的部分(具有控制流的部分):
1 | def forward(self, ...): |
只 script 需要的部分,代码质量相对于全部 script 被破坏的很少,被 @script_if_tracing 装饰的函数必须是不包含 tensor 模块运算的纯函数。因此,有时需要进行更多重构:
1 | # Before: |
同样的,我们可以在 script 中嵌套 trace:
1 | model.submodule = torch.jit.trace(model.submodule, submodule_inputs) |
这里的子模块是 trace,但是实际中并不常用,因为会影响子模块的推理(当且仅当子模块的输入和输出都是 tensor 时才适用),这是很大的限制。但是 trace 作为子模块的时候也有很试用的场景:
1 | class A(nn.Module): |
@script_if_tracing 不能处理这样的控制流,因为它只支持纯函数。如果子模块很复杂不能被 script,使用 trace trace 子模块是很好的选择,这里就是 self.submodule2 和 self.submodule1,类 A 还是要 script 的。
Script 优势
事实上,对于大多数视觉模型,动态控制流仅在少数易于编写 script 的子模块中需要。script 相对于 trace,有两个有点:
- 一个数据有很多属性的控制流,
trace无法处理 trace只支持forward方法,script支持更多的方法
实际上,上述两个功能都在做同样的事情:它们允许以不同的方式使用导出的模型,即根据调用者的请求执行不同的运算符序列。下面是一个这样的特性很有用的示例场景:如果 Detector 是 script 化,调用者可以改变它的 do_keypoint 属性来控制它的行为,或者如果需要直接调用 predict_keypoint 方法。
1 | class Detector(nn.Module): |
这种要求并不常见。但是如果需要,如何在 trace 中实现这一点?我有一个不是很优雅的解决方案:Tracing 只能捕获一个序列的算子,所以自然的方式是对模型进行两次 Tracing:
1 | det1 = torch.jit.trace(Detector(do_keypoint=True), inputs) |
然后我们可以为它们的模型设置别名(以不重复存储),并将两个 trace 合并到一个模块中以编写 script:
1 | det2.submodule.weight = det1.submodule.weight |
单元测试
还可以使用单元测试来判断 trace 是否成功:
1 | assert allclose(torch.jit.trace(model, input1)(input2), model(input2)) |
程序优化
此外,还可以通过优化程序,避免掉不必要的特殊情况:
1 | if x.numel() > 0: |
设备
此外还需要注意设备问题,在 trace 期间会记录使用的设备,而 trace 不会对不同的设备进行泛化,但是部署时都会有固定的设备,这个问题不用担心。
1 | def f(x): |
结论
trace 有明显的局限性:本文大部分时间都在讨论 trace 的局限性以及如何解决它们。我实际上认为这是 trace 的优点:它有明确的限制和解决方案,所以你可以推断它是否有效。相反, script 更像是一个黑匣子:在尝试之前没有人知道它是否有效。
trace 具有较小的代码破坏范围: trace 和 script 都会影响代码的编写方式,但 trace 的代码破坏范围要小得多,并且造成的损害要小得多:
- 它限制了输入/输出格式,但仅限于最外层的模块。
- 在
trace中混合script,但可以只更改受影响模块的内部实现,而不是它们的接口。
另一方面, script 对以下方面有影响:
- 涉及的每个模块和子模块的接口,接口需要高级语法特性,针对接口编程时,千万别在接口设计上妥协。
- 这也可能最终影响训练,因为接口通常在训练和推理之间共享。
这也是为什么 script 会对代码质量造成很大损害的原因。Detectron2 支持 script,但不推荐其他大型项目以可 script 且不丢失抽象为目标,因为这实在有点难度,除非它们也能像阿里巴巴那样得到 PyTorch 团队的支持。
PyTorch 深受用户喜爱,最重要的是编写 Python 控制流。但是 Python 的其他语法也很重要。如果能够编写 Python 控制流( 使用 script )意味着失去其他优秀的语法,我宁愿放弃编写 Python 控制流的能力。事实上,如果 PyTorch 对 Python 控制流不那么执着,并且像这样(类似于 tf.cond 的 API)为我提供了诸如 torch.cond 之类的符号控制流:
1 | def f(x): |
然后 f 可以正确 trace,不再需要担心 script。
保存和加载模型
1 | traced.save('wrapped_rnn.pt') |