高性能 PyTorch 训练:性能瓶颈的调查与分析

PyTorch

在 2014 年之前,神经网络与深度学习还没有大规模地应用于工业界。研究者们开发了一些基本而有效的工具包,来搭建神经网络。其中的代表就是 Caffe、Torch 和 Theano。由于当时的研究主流方向是卷积神经网络 (CNN) 在计算机视觉 (CV) 中的应用,所以这些框架主要关注的是 layers。这种设计完全可以满足研究者们拼接不同卷积层、了解不同神经网络结构效果的目的。

而在之后,随着循环神经网络 (RNN) 和自然语言处理 (NLP) 的兴起,以 layers 为“first class citizen”的工具包们就开始力不从心了。而工业界也开始对模型构建、训练以及部署的效率提出了新的要求。随着以 Google、Microsoft、Facebook、Amazon 等巨头的加入,以数据流 (Data Flow) 为中心的体系被提出,TensorFlow、CNTK、MXNet、PyTorch 等新一代的深度学习框架应运而生。

PyTorch 是一个基于 Torch 库的开源机器学习库,用于计算机视觉和自然语言处理等应用,主要由 Facebook 的人工智能研究实验室 (FAIR) 开发。它是在修改后的 BSD 2.0 许可协议下发布的免费开源软件。

正如同它名字的前缀一般,PyTorch 主要采用 Python 语言接口。与 TensorFlow 1.x 相比,PyTorch 的编写方式更简单自然,API 更 pythonic,对 debug 也更加友好。因此,PyTorch 在学术界赢得了更多的拥趸,近年来顶级会议中,PyTorch 的代码提交量遥遥领先于第二名的 TensorFlow (Keras)。而在工业界,PyTorch 后来居上,已经逐渐和 TensorFlow 分庭抗礼。

很多学术界最新的成果都是以 PyTorch 构建的,并被作者开源在了 GitHub。但也有很多声音表示 PyTorch 在训练中比 TensorFlow 更慢。

高性能 PyTorch 的训练流程是什么样的?是产生最高准确率的模型?是最快的运行速度?是易于理解和扩展?还是容易并行化?

答案是,上述所有。

结合我自己给 PyTorch 提速的经历,本文将给出一些提升 PyTorch 性能的方向。当然,作为本文的读者,您需要对 Linux 操作系统和 PyTorch 足够熟悉。

了解瓶颈所在

首先,当感受到训练缓慢时,我们应当检查系统的状态来得知代码的性能被哪些因素限制了。计算,是一个由应用程序(代码)、存储设备(硬盘、内存)、运算设备(CPU、GPU)共同参与的过程,有着非常明显的木桶效应,即其中任意一个环节都有可能成为性能瓶颈的来源。

工具

此时,熟悉一些运维工具可以有效地帮助你了解当前整个计算机以及各个硬件设备的工作状态。只有将性能瓶颈定位到 CPU、GPU、I/O 或是代码中,才能开始解决问题。

htop

htop 是一个跨平台的交互式流程查看器。

htop 允许垂直和水平滚动进程列表,以查看它们的完整命令行以及内存和CPU消耗等相关信息。显示的信息可以通过图形设置进行配置,并且可以交互地进行排序和过滤。与进程相关的任务(例如终止和更新)可以在不输入其 PID 的情况下完成。

从 htop 顶部的信息集合中,可以监视 CPU 和内存的使用情况。

一个好的高性能程序,应当尽可能多地进行异步运算,来充分发挥多核 CPU 的能力。同时,尽量多地使用内存,能够大大提高数据的交流效率。当然,这并不意味着你可以把所有的 CPU 和内存资源耗尽,这将使系统不能够正常调度资源,反而拖累计算。

在所有的 Linux 发行版中,你都可以直接从软件仓库中安装 htop,例如:

iotop 和 iostat

iotop 是用来监视每个命令所占用的 I/O 情况的命令行应用程序。

iostat 则是属于 sysstat 工具包中的一个组件,可以监视外部存储设备(硬盘)当前的 I/O 情况。

安装命令分别为:

需要注意的是,因为涉及到 I/O 情况的监视,所以以上两款程序均需要 root 权限才能正常运行。

nvidia-smi

NVIDIA System Management Interface 是基于 NVIDIA Management Library (NVML) 的命令行应用程序,旨在帮助管理和监视 NVIDIA GPU 设备。

一般来说,GPU 的流处理器使用率越高,就说明 GPU 是在以更高的效率运转的。设备的当前功率也能从侧面反映这个问题。换言之,如果你发现你的流处理器利用率低于 50%,则说明模型没能很好地利用 GPU 的并行能力。

在通过 sh 脚本安装 NVIDIA GPU 驱动后,nvidia-smi 会被自动安装。如果是从发行版的仓库中安装的驱动,可以尝试在软件源中搜索安装 nvidia-smi

nvtop

nvtop 代表 NVIDIA top,由开发者 Maxime Schmitt 发布于 GitHub,是一款用于观察和记录 NVIDIA GPU 使用情况的 (h)top 任务监视器。你会发现 nvtop 有着与 htop 非常相似的 UI。

它可以用于 GPU,并以曲线图的形式输出在一段时间内 GPU 流处理器和显存使用情况的变化。

相比于 nvidia-sminvtop 会关注到更关键的设备信息,并给出其在时间序列上的变化情况。

你可以参考 Syllo/nvtop 中的描述自行编译安装 nvtop

py-spy

py-spy 是一个针对 Python 程序的采样分析器。

它使您可以直观地看到 Python 程序正在花费时间,而无需重新启动程序或以任何方式修改代码。 py-spy 的开销非常低:为了提高速度,它是用 Rust 编写的,并且不会在所分析的 Python 程序相同的进程中运行。这意味着对生产 Python 代码使用 py-spy 是安全的。

py-spy 可以生成如下的 SVG 图像,来帮助你统计每一个 package、model 甚至每一个 function 在运行时所耗费的时间。

py-spy 同样能够以 top 的方式实时显示 Python 程序中哪些函数花费的时间最多。

只需要在 pypi 中安装即可:

问题分析与解决思路

有了以上工具所收集的信息,我们就可以开始分析限制程序性能发挥的原因了。

PyTorch Workflow

而首先,我们要了解 PyTorch 的工作流程。

因为 PyTorch 使用 Python 接口,同时在底层调用了相当多的 C 库,所以在使用 PyTorch 时,很多细节对用户是不暴露的。实际上,在常见的训练过程中,用户和 PyTorch 一起,大致完成了以下的步骤:

  1. 构建模型。将编写好的模型类实例化为 nn.Module 对象。
  2. 准备数据。将训练数据和测试数据进行预处理,然后组织为 Dataloader 的形式,并设置好数据增强方案。
  3. 定义 Loss Function 和 Optimizer。
  4. 主训练循环。

其中,主训练循环决定了网络经过多少次完整的数据集,即我们常说的 epoch:

  1. 从 Dataloader 中提取当前 batch 的数据。一般 Dataloader 中只记录了数据的 index 信息,所以每次训练循环时,对应的数据都会从硬盘被读取到内存,然后再从内存放入显存中,交由 GPU 进行后续步骤。
  2. 数据经过模型。
  3. 将得到的输出送到 Loss Function 中计算损失,随后进行 backward 求导。
  4. Optimizer 执行梯度下降,更新参数。

一般来说,PyTorch 训练的过程的快慢决定于主训练循环。主循环中的每一步都将被执行上万次乃至几十万次,任何的效率提升都能够带来极大的收益。

CPU 瓶颈

CPU 和 GPU 的计算特点,决定了它们不同的功用:CPU 具有更高的主频和精度,适合于进行串行任务;GPU 拥有几千到上万个 Stream 核心,可以进行大规模的并行任务。

所以,对于数据增强等没有相互依赖的任务交给 CPU 来进行,很大程度上会拖慢训练的进程。在每次的数据导入时,都会产生一定时间的等待。这是一种非常普遍的 CPU 瓶颈,即将不适合 CPU 的任务交给它来处理。

GPU 瓶颈

如果你熟悉梯度下降 (Gradient Descent) 的原理,那么你一定能够理解 batch size 对训练速度的影响。梯度下降将一个 batch 中的平均梯度作为总体梯度方向的近似,进行一次参数更新。Batch size 越大,那么 GPU 内同时并行计算的数据也就越多,相应的训练速度会有很大的提升。

Batch size 的设定对最终的训练结果有一定的影响,但是在一定范围内的调整并不会产生非常大的扰动。

主流观点中,在不过分影响最终的模型性能的前提下,batch size 的选取以最大化利用显存和流处理器为佳。

I/O 瓶颈

I/O 瓶颈是最常见、最普遍的训练效率影响因素。

正如上文中所描述的,数据将会在硬盘、内存和显存中不断地转移和复制。不同存储设备的读写速度,可能有几个数量级上的差异。

出现 I/O 瓶颈的标志主要有:

  1. 系统 I/O 读写(尤其是硬盘读写)占用过高;
  2. 内存占用和 CPU 占用都普遍偏低;
  3. 显存占用较高的情况下,流处理器的利用率过低。

将数据预读入内存中、异步进行数据加载都是有效的解决方案。一个简单稳定的方式是直接使用 DALI 库。

NVIDIA Data Loading Library (DALI) 是一个可移植的开源库,用于解码和增强图像、视频和语音,以加速深度学习应用程序。DALI 通过重叠训练和预处理减少了延迟和训练时间,缓解了瓶颈。它为流行的深度学习框架中内置的数据加载器和数据迭代器提供了一个插件,便于集成或重定向到不同的框架。

用图像训练神经网络需要开发人员首先对这些图像进行归一化处理。此外,图像通常会被压缩以节省存储空间。因此,开发人员构建了多阶段数据处理流程,包括加载、解码、裁剪、调整大小和许多其他增强算子。这些目前在 CPU 上执行的数据处理流水线已经成为瓶颈,限制了整体吞吐量。

DALI 是内置数据加载器和数据迭代器的高性能替代品。开发人员现在可以在 GPU 上运行他们的数据处理工作。

参考链接


高性能 PyTorch 训练 (1):性能瓶颈的调查与分析

python2.7解决UnicodeEncodeError: 'ascii' codec can't encode character问题

最近业务中需要用 Python 写一些脚本。尽管脚本的交互只是命令行 + 日志输出,但是为了让界面友好些,我还是决定用中文输出日志信息。

遇到了异常:

为了解决问题,研究了一下 Python 的字符编码处理。网上也有不少文章讲 Python 的字符编码,但是看过一遍,觉得自己可以讲得更明白些。

下面先复述一下 Python 字符串的基础,熟悉此内容的可以跳过。

 1.引入

对应 C/C++ 的 char 和 wchar_t, Python 也有两种字符串类型,str 与 unicode:

前面的申明:# -*- coding: utf-8 -*- 表明,上面的 Python 代码由 utf-8 编码。

为了保证输出不会在 linux 终端上显示乱码,需要设置好 linux 的环境变量:export LANG=en_US.UTF-8

如果你和我一样是使用 SecureCRT,请设置 Session Options/Terminal/Appearance/Character EncodingUTF-8 ,保证能够正确的解码 linux 终端的输出。

两个 Python 字符串类型间可以用 encode / decode 方法转换:

为什么从 unicode 转 str 是 encode,而反过来叫 decode? 

因为 Python 认为 16 位的 unicode 才是字符的唯一内码,而大家常用的字符集如 gb2312,gb18030/gbk,utf-8,以及 ascii 都是字符的二进制(字节)编码形式。把字符从 unicode 转换成二进制编码,当然是要 encode。

反过来,在 Python 中出现的 str 都是用字符集编码的 ansi 字符串。Python 本身并不知道 str 的编码,需要由开发者指定正确的字符集 decode。

(补充一句,其实 Python 是可以知道 str 编码的。因为我们在代码前面申明了 # -*- coding: utf-8 -*-,这表明代码中的 str 都是用 utf-8 编码的,我不知道 Python 为什么不这样做。)

如果用错误的字符集来 encode/decode 会怎样?

这就遇到了我在本文开头贴出的异常:UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-3: ordinal not in range(128)

现在我们知道了这是个字符串编码异常。接下来, 为什么 Python 这么容易出现字符串编/解码异常? 

这要提到处理 Python 编码时容易遇到的两个陷阱。第一个是有关字符串连接的:

简单的字符串连接也会出现解码错误?

陷阱一:在进行同时包含 str 与 unicode 的运算时,Python 一律都把 str 转换成 unicode 再运算,当然,运算结果也都是 unicode。

由于 Python 事先并不知道 str 的编码,它只能使用 sys.getdefaultencoding() 编码去 decode。在我的印象里,sys.getdefaultencoding() 的值总是 'ascii' ——显然,如果需要转换的 str 有中文,一定会出现错误。

除了字符串连接,% 运算的结果也是一样的:

我不理解为什么 sys.getdefaultencoding() 与环境变量 $LANG 全无关系。如果 Python 用 $LANG 设置 sys.getdefaultencoding() 的值,那么至少开发者遇到 UnicodeDecodeError 的几率会降低 50%。

另外,就像前面说的,我也怀疑为什么 Python 在这里不参考 # -*- coding: utf-8 -*- ,因为 Python 在运行前总是会检查你的代码,这保证了代码里定义的 str 一定是 utf-8 。

对于这个问题,我的唯一建议是在代码里的中文字符串前写上 u。另外,在 Python 3 已经取消了 str,让所有的字符串都是 unicode ——这也许是个正确的决定。

其实,sys.getdefaultencoding() 的值是可以用“后门”方式修改的,我不是特别推荐这个解决方案,但是还是贴一下,因为后面有用:

可以看到,问题魔术般的解决了。但是注意! sys.setdefaultencoding() 的效果是全局的,如果你的代码由几个不同编码的 Python 文件组成,用这种方法只是按下了葫芦浮起了瓢,让问题变得复杂。

另一个陷阱是有关标准输出的。

刚刚怎么来着?我一直说要设置正确的 linux $LANG 环境变量。那么,设置错误的 $LANG,比如 zh_CN.GBK 会怎样?(避免终端的影响,请把 SecureCRT 也设置成相同的字符集。)

显然会是乱码,但是不是所有输出都是乱码。

为什么是 unicode 而不是 str 的字符显示是正确的? 首先我们需要了解 print。与所有语言一样,这个 Python 命令实际上是把字符打印到标准输出流 —— sys.stdout。而 Python 在这里变了个魔术,它会按照 sys.stdout.encoding 来给 unicode 编码,而把 str 直接输出,扔给操作系统去解决。

这也是为什么要设置 linux $LANG 环境变量与 SecureCRT 一致,否则这些字符会被 SecureCRT 再转换一次,才会交给桌面的 Windows 系统用编码 CP936 或者说 GBK 来显示。

通常情况,sys.stdout.encoding 的值与 linux $LANG 环境变量保持一致:

但是,这里有 陷阱二:一旦你的 Python 代码是用管道 / 子进程方式运行,sys.stdout.encoding 就会失效,让你重新遇到 UnicodeEncodeError。

比如,用管道方式运行上面的 example4.py 代码:

可以看到,第一:sys.stdout.encoding 的值变成了 None;第二:Python 在 print 时会尝试用 ascii 去编码 unicode.

由于 ascii 字符集不能用来表示中文字符,这里当然会编码失败。

怎么解决这个问题? 不知道别人是怎么搞定的,总之我用了一个丑陋的办法:

  这个方法仍然有个副作用:直接输出中文 str 会失败,因为 codecs 模块的 writer 与 sys.stdout 的行为相反,它会把所有的 str 用 sys.getdefaultencoding() 的字符集转换成 unicode 输出。

显然,sys.getdefaultencoding() 的值是 'ascii', 编码失败。

解决办法就像 example3.py 里说的,你要么给 str 加上 u 申明成 unicode,要么通过“后门”去修改 sys.getdefaultencoding():

总而言之,在 Python 2 下进行中文输入输出是个危机四伏的事,特别是在你的代码里混合使用 str 与 unicode 时。

有些模块,例如 json,会直接返回 unicode 类型的字符串,让你的 % 运算需要进行字符解码而失败。而有些会直接返回 str, 你需要知道它们的真实编码,特别是在 print 的时候。

为了避免一些陷阱,上文中说过,最好的办法就是在 Python 代码里永远使用 u 定义中文字符串。另外,如果你的代码需要用管道 / 子进程方式运行,则需要用到 example6.py 里的技巧。

 2.python 自动解编码机制导致报错

1.string 和 unicode 对象合并

2.列表合并

3.格式化字符串

4.打印 unicode 对象

5.输出到文件

1,2,3 的例子中,python 自动用 ascii 把 string 解码为 unicode 对象然后再进行相应操作,所以都是 decode 错误, 4 和 5 python 自动用 ascii 把 unicode 对象编码为字符串然后输出,所以都是 encode 错误。

只要涉及到 unicode 对象和 string 的转换以及 unicode 对象输出、输入的地方可能都会触发 python 自动进行解码/编码,比如写入数据库、写入到文件、读取 socket 等等。

到此,这两个异常产生的真正原因了基本已经清楚了: unicode 对象需要编码为相应的 string(字符串)才可以存储、传输、打印,字符串需要解码为对应的 unicode 对象才能完成 unicode 对象的各种操作,lenfind 等。

3.如何避免这些的错误

1.理解编码或解码的转换方向

无论何时发生编码错误,首先要理解编码方向,然后再针对性解决。

2.设置默认编码为 utf-8

在文件头写入

python 会查找: coding: name or coding=name,并设置文件编码格式为 name,此方式是告诉 python 默认编码不再是 ascii ,而是要使用声明的编码格式。

3.输入对象尽早解码为 unicode,输出对象尽早编码为字节流

无论何时有字节流输入,都需要尽早解码为 unicode 对象。任何时候想要把 unicode 对象写入到文件、数据库、socket 等外界程序,都需要进行编码。

4.使用 codecs 模块来处理输入输出 unicode 对象

codecs 模块可以自动的完成解编码的工作。

参考链接


python2.7 的中文编码处理,解决UnicodeEncodeError: 'ascii' codec can't encode character 问题

Python编写AES加密代码

最近需要 Python 实现 AES 加解密操作,在目前的 macOS Big Sur (11.4) 上使用 PyCryptodomePyCrypto 都会报错:

经过测试,我们可以使用 cryptography 来实现这个功能,如下:

参考 AES 算法实现代码:

参考链接


ubuntu 20.04将Python3交叉编译移植到Android平台

最近想在Android环境中集成Python3,参考了一下网上的实现,发现已经有项目实现这个功能的,具体的编译过程参考下面:

参考链接


使用pip安装模块时提示:No module named pip

系统输入cmd:
windows 解决方法:

升级pip版本:

linux解决方法:

参考链接


使用pip安装模块时提示: No module named pip

Python发送form-data请求提交表单数据

最近在处理服务器接口的时候,需要使用Python模拟发送表单数据到服务器,搜索了一下,找到如下例子:

参考链接


ubuntu 18.04系统报错ModuleNotFoundError: No module named 'pip._internal'

重新安装的ubuntu 18.04系统上初次安装python3-pip之后,执行升级命令,出现如下错误信息:

解决方法为重新升级安装一次 pip,如下:

参考链接


No module named 'pip._internal.cli.main'

使用mitmproxy + python做拦截代理

本文是一个较为完整的 mitmproxy 教程,侧重于介绍如何开发拦截脚本,帮助读者能够快速得到一个自定义的代理工具。

本文假设读者有基本的 python 知识,且已经安装好了一个 python 3 开发环境。如果你对 nodejs 的熟悉程度大于对 python,可移步到 anyproxy,anyproxy 的功能与 mitmproxy 基本一致,但使用 js 编写定制脚本。除此之外我就不知道有什么其他类似的工具了,如果你知道,欢迎评论告诉我。

本文基于 mitmproxy v4,当前版本号为 v4.0.1

继续阅读使用mitmproxy + python做拦截代理

读取的XML节点中带有冒号怎么办?

昨天,编程读取XML的时候,遇上了类似下面的一段XML

起初没有特别的留意,于是乎就像平时读取XML一样使用了。

但是,运行报错,不允许传入冒号:之类的字符,后来查阅资料发现,节点中,冒号前的a代表是的命名空间,冒号后的才是根节点名称。在Root节点中,也对命名空间进行了声明 xmlns:a="http://ww.abc.com/" ,知道了这么一回事后,再来看看如何去读取,正确的读取是:

从代码可以看出,我声明了一个XNamespace类型的变量,并且把XML文件中出现的命名空间 http://ww.abc.com/ 赋值给它,然后再读取节点的时候,与真正的节点名称book进行拼接就可以了!

XML中出现命名空间的原因是,当你需要使用多个XML一起工作时,由于两个文档都包含带有不同内容和定义的节点元素,就会发生命名冲突,加上命名空间使用可以避免发生冲突,这与C#编程中类的命名空间的用处差不多。

另外,如果需要了解更多操作XML的可以访问下面这篇文章,写得很详细:

http://www.cnblogs.com/nsky/archive/2013/03/05/2944725.html

参考链接