大语言模型的工程技巧(四)——梯度检查点

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。

本文将讨论如何利用梯度检查点算法来减少模型在训练时候(更准确地说是运行反向传播算法时)的内存开支。这在训练超大规模的模型时会用到。

关于其他的工程技巧可以参考:

关于大语言模型的讨论请参考:

一、标准反向传播

根据梯度的定义,变量的梯度与其本身的值密切相关。因此,要想得到某个变量的梯度,必须先知道这个变量的值。这也是为什么在进行反向传播算法之前,需要先对计算图进行向前传播,并记录每个节点的计算结果,如图1左侧部分所示。这样在计算节点的梯度时,可以利用这些事先缓存的结果,直接启动反向传播过程,从而得到梯度,如图1中的节点d所示。这种方法也被称为标准反向传播。这种方式能够确保梯度计算以最高效的方式进行。

图1

图1

二、内存极简算法

然而,采用标准反向传播算法会造成较大的内存开销。为了在计算过程中尽可能地压缩内存使用,可以采用一种以时间换空间的方法。在这种算法中,一旦向前传播完成,仅会保留顶点的计算结果,而中间节点的结果会被清空(叶子节点的值会保留)。在反向传播遇到中间计算节点没有缓存时,则重新触发向前传播,以获取所需节点的结果。这就是内存极简的反向传播算法。以节点d为例,为了计算其梯度,需要首先从节点a开始重新触发向前传播直到节点d,并缓存计算结果。然后使用这个缓存的结果以及节点e的梯度,计算出节点d的梯度。对于其他节点,也采用类似的步骤计算梯度。通过这种方式,在完成反向传播的同时,节省了内存开销。以图1为例,内存极简算法只需要3个存储空间,而标准算法需要5个存储空间。

三、梯度检查点

尽管内存极简算法在降低内存开销方面取得了显著成果,但它涉及大量的重复计算,运行时间相对较长。为了在内存使用和运行时间之间取得平衡,下面引入梯度检查点(Gradient Checkpoint)。这一算法的核心思想是选择一些中间节点作为存储点,以便在再次触发向前传播时,以这些存储点作为起点开始传播,避免从头开始重复计算。这种方式在一定程度上减少重复计算,从而提高运行效率。需要注意的是,由于需要存储额外的中间结果,梯度检查点会稍微增加一些内存开销。

关于梯度检查点算法,PyTorch中已经提供了便捷的封装函数,即torch.utils.checkpoint。这个工具能够帮助我们更方便地应用梯度检查点算法,以平衡内存开锁和运行时间。更多细节请参考这个链接

相关推荐

  1. 语言模型高质量提示词工程技巧指南

    2024-05-26 01:30:34       31 阅读
  2. 语言模型技术-算法原理

    2024-05-26 01:30:34       33 阅读
  3. 提示工程 1—常用语言模型参数说明

    2024-05-26 01:30:34       13 阅读
  4. 语言模型提示工程简介

    2024-05-26 01:30:34       22 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-26 01:30:34       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-26 01:30:34       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-26 01:30:34       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-26 01:30:34       20 阅读

热门阅读

  1. 【数据结构与算法 | 基础篇】数组模拟栈

    2024-05-26 01:30:34       12 阅读
  2. 银发经济:老龄化社会中的机遇与挑战

    2024-05-26 01:30:34       10 阅读
  3. 基于Amazon Cognito的安全登录与资源访问

    2024-05-26 01:30:34       11 阅读
  4. ORACLE 6节点组成的ACFS文件系统异常的分析思路

    2024-05-26 01:30:34       11 阅读
  5. Nginx 从入门到精通-Nginx-Web服务器的瑞士军刀

    2024-05-26 01:30:34       12 阅读
  6. PostgreSQL入门教程

    2024-05-26 01:30:34       9 阅读
  7. 系统分析师-案例分析-数据库

    2024-05-26 01:30:34       15 阅读
  8. 巧用count与count()

    2024-05-26 01:30:34       10 阅读
  9. React hooks - forwardRef+useImperativeHandle

    2024-05-26 01:30:34       11 阅读
  10. 安卓adb 抓取模拟器日志

    2024-05-26 01:30:34       11 阅读
  11. Compose在xml中使用滑动冲突处理

    2024-05-26 01:30:34       13 阅读
  12. 【webrtc】MediaEngine的实现CompositeMediaEngine创建VOE

    2024-05-26 01:30:34       10 阅读