【学习】pytorch中with torch.no_grad():和 model.eval()的区别

model.eval()和with torch.no_grad()的区别
先给出结论: 这两个的所起的作用是不同的。

model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout. 相当于告诉网络,目前在eval模式,dropout层不会起作用(而在训练阶段,dropout会随机以一定的概率丢弃掉网络中间层的神经元,而且在实际操作过程中一般不在卷积层设置该操作,而是在全连接层设置,因为全连接层的参数量一般远大于卷积层),batch normalization也有不同的作用(在训练过程中会对每一个特征维做归一化操作,对每一批量输入算出mean和std,而在eval模式下BN层将能够使用全部训练数据的均值和方差,即测试过程中不再针对测试样本计算mean和std,而是用训练好的值)。

with torch.no_grad()

当我们计算梯度的时候,我们需要缓存前向传播过程中大量的中间输出,因为在反向传播pytoch自动计算梯度时需要用到这些值。而我们在测试时,我们不需要计算梯度,那么也就意味着我们不需要在forward的时候保存这些中间输出。此外,在测试阶段,我们也不需要构造计算图(这也需要一定的存储开销)。Pytorch为我们提供了一个上下文管理器,torch.no_grad,在with torch.no_grad() 管理的环境中进行计算,不会生成计算图,不会存储为计算梯度而缓存的中间值。

结论

当网络中出现batch normalization或者dropout这样的在training,eval时表现不同的层,应当使用model.eval()。在测试时用with torch.no_grad()会节省存储空间。

另外需要注意的是,即便不使用with torch.no_grad(),在测试只要你不调用loss.backward()就不会计算梯度,with torch.no_grad()的作用只是节省存储空间。当然,在测试阶段,可以两者一起使用,效果更好。
总的来说,区别不大。

相关推荐

  1. pytorch,load_state_dicttorch.load区别

    2024-02-22 06:34:03       6 阅读
  2. PyTorch ,TensorFlowCaffe之间区别

    2024-02-22 06:34:03       39 阅读
  3. Mybatis${}#{}区别

    2024-02-22 06:34:03       23 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-02-22 06:34:03       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-22 06:34:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-22 06:34:03       20 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-22 06:34:03       20 阅读

热门阅读

  1. Sora后观察:AI大模型产业落地的八个锚点

    2024-02-22 06:34:03       27 阅读
  2. 华为配置直连三层组网隧道转发示例

    2024-02-22 06:34:03       22 阅读
  3. Linux 环境变量

    2024-02-22 06:34:03       24 阅读
  4. Mybatis中各个方法

    2024-02-22 06:34:03       26 阅读
  5. Redis

    2024-02-22 06:34:03       24 阅读
  6. 鸿蒙 gnss 开关使能流程

    2024-02-22 06:34:03       25 阅读
  7. HTML5 扩展了 HTMLDocument 类型

    2024-02-22 06:34:03       26 阅读
  8. css3实现无缝滚动,鼠标经过暂停

    2024-02-22 06:34:03       28 阅读