【DL】FocalLoss的PyTorch实现
此篇不介绍FocalLoss的原理,仅展示PyTorch实现FocalLoss的两种方式。个人认为相关原理已在文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》中讲得很清晰,故此篇不再介绍。
方式一
同时计算一个batch中所有样本关于FocalLoss的损失值,并返回整个batch的平均值(来自文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》,个人补充了一些注释)。注意:该实现默认会把标签为0的样本的损失置零(即忽略类别0),如果类别0是有效类别,需要关闭这一行为:
import torch
from torch import nn
import random
class FocalLoss(nn.Module):
    """Multi-class focal loss, computed for the whole batch in one pass.

    Reference: https://github.com/lonePatient/TorchBlocks

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        alpha: Class weighting. Either a scalar applied to every class, or a
            list holding one weight per class.
        epsilon: Small constant added to the probabilities before ``log``.
        device: Device on which the per-class ``alpha`` tensor is created.
        ignore_zero_class: When True (the default, matching the historical
            behavior), samples labeled ``0`` contribute zero loss. Set it to
            False when class 0 is a real class.
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None,
                 ignore_zero_class=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            # Use the torch.tensor factory: the legacy torch.Tensor(...)
            # constructor does not reliably accept a non-CPU ``device``.
            self.alpha = torch.tensor(alpha, dtype=torch.float32, device=device)
        else:
            self.alpha = alpha
        self.epsilon = epsilon
        self.ignore_zero_class = ignore_zero_class

    def forward(self, input, target):
        """Return the mean focal loss over the batch.

        Args:
            input: model output (logits), shape [batch_size, num_cls].
            target: ground-truth class indices, shape [batch_size].

        Returns:
            A scalar (0-dim) tensor: the per-sample losses averaged over the
            batch.
        """
        num_labels = input.size(-1)  # number of classes
        # Turn the label vector into a column of indices on input's device.
        idx = target.view(-1, 1).long().to(input.device)
        # Each row is the one-hot encoding of that sample's label.
        one_hot_key = torch.zeros(idx.size(0), num_labels,
                                  dtype=torch.float32, device=input.device)
        one_hot_key.scatter_(1, idx, 1)
        if self.ignore_zero_class:
            # Drop the contribution of every sample labeled 0.
            one_hot_key[:, 0] = 0
        probs = torch.softmax(input, dim=-1)
        alpha = self.alpha
        if isinstance(alpha, torch.Tensor):
            # Keep the class weights on the same device as the probabilities.
            alpha = alpha.to(probs.device)
        loss = (-alpha * one_hot_key
                * torch.pow(1 - probs, self.gamma)
                * (probs + self.epsilon).log())
        # Only the true-class column is non-zero, so the row sum is the
        # per-sample loss.
        return loss.sum(1).mean()
# Fix the random seeds so that runs are reproducible.
def setup_seed(seed):
    """Seed the torch, NumPy and Python RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed passed to every random number generator.
    """
    # Imported here because the module never imports NumPy at file level;
    # without it the ``np`` reference below raises NameError.
    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # Demo: one forward/backward pass on random data with per-class weights.
    criterion = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    # Seed the random number generators.
    setup_seed(20)
    # Random logits, torch.Size([3, 5]) = [sample_num, class_num].
    logits = torch.randn(3, 5, requires_grad=True)
    # Random integer labels in [0, 5), torch.Size([3]) = [sample_num].
    labels = torch.empty(3, dtype=torch.long).random_(5)
    batch_loss = criterion(logits, labels)
    # print(batch_loss)
    batch_loss.backward()
方式二
一个batch中逐个样本计算关于FocalLoss的损失值,将它们求平均,返回一个batch内所有样本的FocalLoss的平均值:
import torch
from torch import nn
import random
class FocalLoss(nn.Module):
    """Multi-class focal loss, computed one sample at a time.

    Reference: https://github.com/lonePatient/TorchBlocks

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        alpha: Class weighting. Either a scalar applied to every class, or a
            list holding one weight per class.
        epsilon: Small constant added to the probabilities before ``log``.
        device: Device on which the per-class ``alpha`` tensor is created.
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            # Use the torch.tensor factory: the legacy torch.Tensor(...)
            # constructor does not reliably accept a non-CPU ``device``.
            self.alpha = torch.tensor(alpha, dtype=torch.float32, device=device)
        else:
            self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """Return the mean focal loss over the batch.

        The loss of each sample is computed separately inside a loop and the
        results are averaged.

        Args:
            input: model output (logits), shape [batch_size, num_cls].
            target: ground-truth class indices, shape [batch_size].

        Returns:
            A scalar (0-dim) tensor: the per-sample losses averaged over the
            batch. An empty batch yields NaN (mean of an empty tensor).
        """
        num_labels = input.size(-1)  # number of classes
        # scatter_ needs integer indices living on the same device as the mask.
        labels = target.view(-1).long().to(input.device)
        alpha = self.alpha
        if isinstance(alpha, torch.Tensor):
            # Keep the class weights on the same device as the logits.
            alpha = alpha.to(input.device)

        per_sample = []
        for i, sample in enumerate(input):
            # One-hot encoding of this sample's label, shape [1, num_cls].
            one_hot_key = torch.zeros(1, num_labels,
                                      dtype=torch.float32, device=input.device)
            one_hot_key.scatter_(1, labels[i].view(1, 1), 1)
            probs = torch.softmax(sample, dim=-1)
            loss_this_sample = (-alpha * one_hot_key
                                * torch.pow(1 - probs, self.gamma)
                                * (probs + self.epsilon).log())
            # Only the true-class entry is non-zero; the sum has shape [1].
            per_sample.append(loss_this_sample.sum(1))

        if not per_sample:
            # Empty batch: mean of an empty tensor (NaN) instead of failing
            # on a Python list.
            return input.new_zeros(0).mean()
        return torch.cat(per_sample).mean()
# Fix the random seeds so that runs are reproducible.
def setup_seed(seed):
    """Seed the torch, NumPy and Python RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed passed to every random number generator.
    """
    # Imported here because the module never imports NumPy at file level;
    # without it the ``np`` reference below raises NameError.
    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # Demo: one forward/backward pass on random data with per-class weights.
    criterion = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    # Seed the random number generators.
    setup_seed(20)
    # Random logits, torch.Size([3, 5]) = [sample_num, class_num].
    logits = torch.randn(3, 5, requires_grad=True)
    # Random integer labels in [0, 5), torch.Size([3]) = [sample_num].
    labels = torch.empty(3, dtype=torch.long).random_(5)
    batch_loss = criterion(logits, labels)
    # print(batch_loss)
    batch_loss.backward()