深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

西笑生 2024-07-07 11:35:02 阅读 58

深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

flyfish

本系列的主要内容是在2017年所写,GPT使用了交叉熵损失函数,所以就温故而知新,文中代码又用新版的PyTorch写了一遍,在看交叉熵损失函数遇到问题时,可先看链接提供的基础知识,可以有更深的理解。

深入理解交叉熵损失 CrossEntropyLoss - one-hot 编码

深入理解交叉熵损失 CrossEntropyLoss - 对数

深入理解交叉熵损失 CrossEntropyLoss - 概率基础

深入理解交叉熵损失 CrossEntropyLoss - 概率分布

深入理解交叉熵损失 CrossEntropyLoss - 损失函数

深入理解交叉熵损失 CrossEntropyLoss - 归一化

深入理解交叉熵损失 CrossEntropyLoss - 信息论(交叉熵)

深入理解交叉熵损失 CrossEntropyLoss - Softmax

深入理解交叉熵损失 CrossEntropyLoss - nn.LogSoftmax

深入理解交叉熵损失 CrossEntropyLoss - 似然

深入理解交叉熵损失CrossEntropyLoss - 乘积符号在似然函数中的应用

深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss

深入理解交叉熵损失 CrossEntropyLoss - nn.CrossEntropyLoss

深入理解交叉熵损失CrossEntropyLoss

深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLossLogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss解释直观解释 Softmax和负对数似然

二分类问题手动计算步骤代码实现

多分类问题手动计算步骤代码验证

在 PyTorch 中,

<code>torch.nn.CrossEntropyLoss 是一个常用的

损失函数,主要用于多分类任务。它结合了

nn.LogSoftmax

nn.NLLLoss,并且内部进行了优化以避免

数值稳定性问题。

具体来说,torch.nn.CrossEntropyLoss 计算的是预测值与目标值之间的交叉熵损失。对于多分类问题,交叉熵损失是最常用的损失函数,因为它直接衡量了两个概率分布(预测概率分布和实际分布)之间的差异。

LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss

nn.CrossEntropyLoss 在内部已经包含了 LogSoftmax 和 NLLLoss 的操作。

编写代码验证,分别是 LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss。

import torch

import torch.nn as nn

# 输入张量 (batch_size=2, num_classes=3)

input_tensor = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])

# 目标张量 (batch_size=2)

target_tensor = torch.tensor([2, 0])

# 使用 nn.LogSoftmax 和 nn.NLLLoss

log_softmax = nn.LogSoftmax(dim=1)

log_probs = log_softmax(input_tensor)

nll_loss = nn.NLLLoss()

loss = nll_loss(log_probs, target_tensor)

print(f'Loss using LogSoftmax and NLLLoss: { loss.item()}')

# 使用 nn.CrossEntropyLoss

cross_entropy_loss = nn.CrossEntropyLoss()

loss_ce = cross_entropy_loss(input_tensor, target_tensor)

print(f'Loss using CrossEntropyLoss: { loss_ce.item()}')

输出结果

Loss using LogSoftmax and NLLLoss: 1.4076058864593506

Loss using CrossEntropyLoss: 1.4076058864593506

解释

对于单个样本,交叉熵损失的定义如下:

CrossEntropyLoss

=

i

=

1

C

y

i

log

(

y

^

i

)

\text{CrossEntropyLoss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)

CrossEntropyLoss=−i=1∑C​yi​log(y^​i​)

其中:

C

C

C 是类别的数量。

y

i

y_i

yi​ 是真实标签的一个one-hot编码(若样本属于类别

i

i

i,则

y

i

=

1

y_i = 1

yi​=1,否则

y

i

=

0

y_i = 0

yi​=0)。

y

^

i

\hat{y}_i

y^​i​ 是模型预测的第

i

i

i 类的概率。

直观解释 Softmax和负对数似然

交叉熵损失结合了两个概念:

Softmax

首先将模型输出的原始分数(logits)通过 softmax 函数转换成概率分布,Softmax 函数将 logits 转换为概率分布。对于一个有

C

C

C 个类别的分类问题,Softmax 公式如下:

y

^

i

=

exp

(

z

i

)

j

=

1

C

exp

(

z

j

)

\hat{y}_i = \frac{\exp(z_i)}{\sum_{j=1}^{C} \exp(z_j)}

y^​i​=∑j=1C​exp(zj​)exp(zi​)​

其中

z

i

z_i

zi​ 是第

i

i

i 类的 logit。

负对数似然

计算这些概率分布与真实标签之间的负对数似然。在获得概率分布后,交叉熵损失计算真实标签的负对数概率。如果真实标签对应的类别概率很高,损失就小;如果概率很低,损失就大。这驱动模型在训练过程中提高真实标签类别的预测概率。

以下是一个简单的示例,展示如何计算交叉熵损失:

import torch

import torch.nn as nn

# 假设我们有两个样本,每个样本属于3个类别中的一个

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.3]])

# 真实标签

labels = torch.tensor([0, 1])

# 使用 nn.CrossEntropyLoss 计算损失

criterion = nn.CrossEntropyLoss()

loss = criterion(logits, labels)

print(f'Cross Entropy Loss: { loss.item()}')

Cross Entropy Loss: 0.37882310152053833

在这个示例中:

logits 是模型输出的原始分数。labels 是真实的类别标签。nn.CrossEntropyLoss 会先将 logits 转换为概率分布,然后计算真实标签的负对数似然损失。

二分类问题

二分类交叉熵损失的公式为:

CrossEntropyLoss

=

(

y

log

(

y

^

)

+

(

1

y

)

log

(

1

y

^

)

)

\text{CrossEntropyLoss} = - (y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}))

CrossEntropyLoss=−(ylog(y^​)+(1−y)log(1−y^​))

手动计算步骤

计算 Sigmoid 激活值

假设:

真实标签

y

=

1

y = 1

y=1模型输出的logits为

z

=

1.5

z = 1.5

z=1.5

计算过程:

σ

(

z

)

=

1

1

+

exp

(

1.5

)

\sigma(z) = \frac{1}{1 + \exp(-1.5)}

σ(z)=1+exp(−1.5)1​

我们使用更高精度来计算:

exp

(

1.5

)

0.22313016014842982

\exp(-1.5) \approx 0.22313016014842982

exp(−1.5)≈0.22313016014842982

σ

(

z

)

=

1

1

+

0.22313016014842982

1

1.22313016014842982

0.8175744761936437

\sigma(z) = \frac{1}{1 + 0.22313016014842982} \approx \frac{1}{1.22313016014842982} \approx 0.8175744761936437

σ(z)=1+0.223130160148429821​≈1.223130160148429821​≈0.8175744761936437

计算交叉熵损失

CrossEntropyLoss

=

(

y

log

(

σ

(

z

)

)

+

(

1

y

)

log

(

1

σ

(

z

)

)

)

\text{CrossEntropyLoss} = - (y \log(\sigma(z)) + (1 - y) \log(1 - \sigma(z)))

CrossEntropyLoss=−(ylog(σ(z))+(1−y)log(1−σ(z)))

CrossEntropyLoss

=

log

(

0.8175744761936437

)

\text{CrossEntropyLoss} = - \log(0.8175744761936437)

CrossEntropyLoss=−log(0.8175744761936437)

log

(

0.8175744761936437

)

0.2014132779827524

\log(0.8175744761936437) \approx -0.2014132779827524

log(0.8175744761936437)≈−0.2014132779827524

CrossEntropyLoss

0.2014132779827524

\text{CrossEntropyLoss} \approx 0.2014132779827524

CrossEntropyLoss≈0.2014132779827524

代码实现

import torch

import torch.nn as nn

import math

# 真实标签和 logits

labels = torch.tensor([1.0])

logits = torch.tensor([1.5])

# 使用 BCEWithLogitsLoss

criterion = nn.BCEWithLogitsLoss()

loss = criterion(logits, labels)

print(f'Binary Classification Cross Entropy Loss: { loss.item()}')

# 手动计算 sigmoid 和交叉熵损失

sigmoid = 1 / (1 + math.exp(-1.5))

manual_loss = - (1 * math.log(sigmoid) + (1 - 1) * math.log(1 - sigmoid))

print(f'Manually Computed Cross Entropy Loss: { manual_loss}')

输出结果

Binary Classification Cross Entropy Loss: 0.20141397416591644

Manually Computed Cross Entropy Loss: 0.2014132779827524

多分类问题

假设有3个类别:

真实标签为第3类,所以one-hot编码

y

=

[

0

,

0

,

1

]

y = [0, 0, 1]

y=[0,0,1]。模型预测的logits为

logits

=

[

0.1

,

0.2

,

0.7

]

\text{logits} = [0.1, 0.2, 0.7]

logits=[0.1,0.2,0.7]。

手动计算步骤

计算Softmax

y

^

i

=

exp

(

z

i

)

k

=

1

C

exp

(

z

k

)

\hat{y}_i = \frac{\exp(z_i)}{\sum_{k=1}^{C} \exp(z_k)}

y^​i​=∑k=1C​exp(zk​)exp(zi​)​

具体计算:

y

^

1

=

exp

(

0.1

)

exp

(

0.1

)

+

exp

(

0.2

)

+

exp

(

0.7

)

\hat{y}_1 = \frac{\exp(0.1)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}

y^​1​=exp(0.1)+exp(0.2)+exp(0.7)exp(0.1)​

y

^

2

=

exp

(

0.2

)

exp

(

0.1

)

+

exp

(

0.2

)

+

exp

(

0.7

)

\hat{y}_2 = \frac{\exp(0.2)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}

y^​2​=exp(0.1)+exp(0.2)+exp(0.7)exp(0.2)​

y

^

3

=

exp

(

0.7

)

exp

(

0.1

)

+

exp

(

0.2

)

+

exp

(

0.7

)

\hat{y}_3 = \frac{\exp(0.7)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}

y^​3​=exp(0.1)+exp(0.2)+exp(0.7)exp(0.7)​

计算得到:

exp

(

0.1

)

1.1052

\exp(0.1) \approx 1.1052

exp(0.1)≈1.1052

exp

(

0.2

)

1.2214

\exp(0.2) \approx 1.2214

exp(0.2)≈1.2214

exp

(

0.7

)

2.0138

\exp(0.7) \approx 2.0138

exp(0.7)≈2.0138

总和:

exp

(

0.1

)

+

exp

(

0.2

)

+

exp

(

0.7

)

1.1052

+

1.2214

+

2.0138

=

4.3404

\exp(0.1) + \exp(0.2) + \exp(0.7) \approx 1.1052 + 1.2214 + 2.0138 = 4.3404

exp(0.1)+exp(0.2)+exp(0.7)≈1.1052+1.2214+2.0138=4.3404

各个概率:

y

^

1

=

1.1052

4.3404

0.2546

\hat{y}_1 = \frac{1.1052}{4.3404} \approx 0.2546

y^​1​=4.34041.1052​≈0.2546

y

^

2

=

1.2214

4.3404

0.2814

\hat{y}_2 = \frac{1.2214}{4.3404} \approx 0.2814

y^​2​=4.34041.2214​≈0.2814

y

^

3

=

2.0138

4.3404

0.4639

\hat{y}_3 = \frac{2.0138}{4.3404} \approx 0.4639

y^​3​=4.34042.0138​≈0.4639

计算交叉熵损失

CrossEntropyLoss

=

(

0

log

(

0.2546

)

+

0

log

(

0.2814

)

+

1

log

(

0.4639

)

)

\text{CrossEntropyLoss} = - (0 \cdot \log(0.2546) + 0 \cdot \log(0.2814) + 1 \cdot \log(0.4639))

CrossEntropyLoss=−(0⋅log(0.2546)+0⋅log(0.2814)+1⋅log(0.4639))

CrossEntropyLoss

=

log

(

0.4639

)

0.769

\text{CrossEntropyLoss} = - \log(0.4639) \approx 0.769

CrossEntropyLoss=−log(0.4639)≈0.769

代码验证

import torch

import torch.nn as nn

import torch.nn.functional as F

# 模拟输入的 logits 和真实标签

logits = torch.tensor([[0.1, 0.2, 0.7]], requires_grad=True)

labels = torch.tensor([2])

# 使用 CrossEntropyLoss

criterion = nn.CrossEntropyLoss()

loss = criterion(logits, labels)

print(f'Computed Cross Entropy Loss (using nn.CrossEntropyLoss): { loss.item()}')

# 手动计算 softmax 和交叉熵损失

softmax_probs = F.softmax(logits, dim=1)

manual_loss = -torch.log(softmax_probs[0, labels])

print(f'Manually Computed Cross Entropy Loss: { manual_loss.item()}')

输出结果

Computed Cross Entropy Loss (using nn.CrossEntropyLoss): 0.7679495811462402

Manually Computed Cross Entropy Loss: 0.7679495811462402

注意在多分类问题的代码中,我们提供了logits而不是softmax后的概率,因为nn.CrossEntropyLoss会在内部应用softmax。

在二分类问题中,我们可以使用 nn.BCEWithLogitsLoss,它会在内部应用 Sigmoid 激活函数,并计算二分类的交叉熵损失。

在多分类问题中,我们可以使用 nn.CrossEntropyLoss,它会在内部应用 Softmax 激活函数,并计算多分类的交叉熵损失



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。