深度学习损失计算

深度学习损失计算

1.如何计算当前epoch的损失?

深度学习中的损失计算,通常为数据集的平均损失,即每个样本的平均损失值。计算步骤如下:

  • 计算单个批次的损失。每次迭代中,用当前模型预测值和真实值计算损失。假设 _loss 是这次迭代中计算得到的损失。
  • 转换为标量。利用item()方法将其转换为标量值。_loss.item()
  • 乘以批次大小。乘以批次大小的原因是,希望总损失是所有数据点的损失总和,而不是批次平均损失。
  • 累加损失loss += _loss.item() * batch_size 将当前批次的总损失累加到变量 loss 中。这样所有批次遍历结束后,就得到一个epoch的总损失。
  • 计算当前epoch的样本平均损失。通过总损失除以总的数据样本数,来得到平均损失。average_loss = loss/len(dataloader.dataset)【注意:除的是总的数据样本数(len(dataloader.dataset))!不是总的批次数(len(dataloader))!】

示例代码如下:

for epoch in total_epoch:  # epoch迭代
    
    total_loss = 0.0  # 初始化总损失

    for inputs, targets in dataloader:  # batch迭代
        outputs = model(inputs)  # 获取预测值
        _loss = criterion(outputs, targets)  # 计算当前批次损失,为批次平均损失

        batch_size = inputs.size(0)  # 获取批次大小
        total_loss += _loss.item() * batch_size  # 计算当前批次的总损失

	# 计算当前epoch的平均损失
    average_loss = total_loss / len(dataloader.dataset)  

2.为什么要计算样本平均损失,而不是计算批次平均损失?

由于每个批次的大小可能不一样,特别是在数据集的大小不是批次大小的整数倍时,所以使用 len(dataloader) 会导致错误的平均损失计算。

下面用一个简单的例子,解释这两种计算方式的不同:

假设数据集有 105 个样本,每个批次大小为 10,这样会有 11 个批次,其中最后一个批次只有 5 个样本。结合上面的伪代码,假设损失值 _loss.item() 是 1,对于 10 个批次的损失是 10,最后一个批次的损失是 5。那么:

  • t o t a l _ l o s s = ( 1 ∗ 10 ) ∗ 10 + ( 1 ∗ 5 ) ∗ 1 = 105 total\_loss = (1 * 10) * 10 + (1 * 5) * 1 = 105 total_loss=(110)10+(15)1=105
  • l e n ( d a t a l o a d e r . d a t a s e t ) = 105 len(dataloader.dataset) = 105 len(dataloader.dataset)=105
  • l e n ( d a t a l o a d e r ) = 11 len(dataloader) = 11 len(dataloader)=11

计算结果:

  • 样本平均损失计算:average_loss = total_loss / len(dataloader.dataset) 105 / 105 = 1 105/105 = 1 105/105=1
  • 批次平均损失计算:average_loss = total_loss / len(dataloader) 105 / 11 ≈ 9.545 105/11 \approx 9.545 105/119.545

显然,第一种方式是正确的,反映了每个样本的真实平均损失。

😃😃😃

相关推荐

  1. 深度学习损失计算

    2024-07-16 22:36:04       20 阅读
  2. 深度学习 - softmax交叉熵损失计算

    2024-07-16 22:36:04       25 阅读
  3. 深度学习损失函数

    2024-07-16 22:36:04       29 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-16 22:36:04       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-16 22:36:04       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-16 22:36:04       58 阅读
  4. Python语言-面向对象

    2024-07-16 22:36:04       69 阅读

热门阅读

  1. Python字典基础与高级详解

    2024-07-16 22:36:04       19 阅读
  2. 代码随想录打卡第二十五天

    2024-07-16 22:36:04       21 阅读
  3. [笔试题] C语言部分练习2

    2024-07-16 22:36:04       17 阅读