博客摘录「 自动微分----pytorch中的梯度运算与反向传播函数(预备知识)5」2024年4月18日

Python控制流的梯度计算

使用自动微分的一个好处是: 即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。 在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):
    b = a * 2
    while b.norm() < 1000:
        b = b * 2
    if b.sum() > 0:
        c = b
    else:
        c = b * 100
    return c

让我们计算梯度。

a = torch.randn(size=(),requires_grad=True)                          #随机生成一个符合正态分布的小数

d = f(a)                                                              #调用函数f,计算关于a的函数
d.backward()                                                          #调用反向传播函数

我们现在可以分析上面定义的 f 函数。 请注意,它在其输入 a 中是分段线性的。 换言之,对于任何 a  ,存在某个常量标量k,使得 f ( a ) = k a  ,其中 k的值取决于输入 a。 因此,我们可以用 d / a验证梯度是否正确。

a.grad, a.grad == d / a                                               #计算a的梯度
(tensor(1024.), tensor(True))

注:

在您给出的例子中,`f(a)` 是一个依赖于输入 `a` 的非线性函数,尽管对于某些特定的输入范围和操作,它的局部行为可能近似线性。由于涉及到的操作包括了基于张量 `b` 的范数判断循环条件以及基于 `b` 的和判断条件分支,所以整体函数不是简单的线性缩放关系。

- 当 `a` 被初始化为从正态分布中抽取的一个标量并要求梯度时,`a.grad` 初始时会被清零。
- 在调用 `f(a)` 后,函数首先将 `a` 乘以 2 存储到 `b`。
- 然后进入一个 `while` 循环,在循环内部不断将 `b` 乘以 2 直到其范数大于等于 1000。
- 接下来是一个条件判断:如果 `b` 的元素之和大于 0,则将 `c` 设置为 `b`;否则将 `c` 设置为 `b` 的 100 倍。
- 函数返回 `c`,并且因为 `a` 要求了梯度,调用 `d.backward()` 将会计算 `d` 关于 `a` 的梯度。

由于 `b` 的每一次更新都是通过乘以 2 实现的,而 `a` 的梯度可以通过链式法则追溯至 `b`,再进一步回溯至 `a`,最终 `a.grad` 应该是每次循环中梯度累积的结果。然而,在实际运行这段代码时,`a.grad` 是否恰好等于 `d / a` 并不保证,因为它取决于具体的 `a` 值和循环终止时 `b` 的状态。

实际情况中,上述代码无法直接得出 `a.grad == d / a` 的结论,因为函数 `f` 并不具备全局的线性性质。不过,您可以执行这段代码来观察 `a.grad` 的具体值,并尝试理解不同输入下梯度是如何随着循环和条件判断变化的。

最近更新

  1. TCP协议是安全的吗?

    2024-04-22 11:34:01       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-22 11:34:01       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-22 11:34:01       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-22 11:34:01       18 阅读

热门阅读

  1. 【每日一题】补档 CF371 D. Vessels | 并查集 | 简单

    2024-04-22 11:34:01       11 阅读
  2. 什么是深度学习?

    2024-04-22 11:34:01       12 阅读
  3. C#中检查一个矩阵是否可逆

    2024-04-22 11:34:01       14 阅读
  4. 金融领域思考-前言

    2024-04-22 11:34:01       10 阅读
  5. hadoop

    hadoop

    2024-04-22 11:34:01      9 阅读
  6. 上海计算机学会2020年7月月赛C++丙组T2感应门

    2024-04-22 11:34:01       10 阅读
  7. Day15-Python基础学习之PySpark

    2024-04-22 11:34:01       9 阅读
  8. CSS简单的选择器

    2024-04-22 11:34:01       10 阅读
  9. Linux第二章

    2024-04-22 11:34:01       11 阅读