交叉熵(Cross Entropy)常用于分类问题中衡量两个概率分布之间的差异。它的公式如下:
1. 对于二分类问题(Binary Classification):
给定真实标签 y \in \{0, 1\}
,预测概率 \hat{y} \in [0, 1]
,交叉熵损失为:
\mathcal{L}_{\text{CE}} = - \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right]
2. 对于多分类问题(Multi-class Classification):
给定真实标签的 one-hot 向量 \mathbf{y} = (y_1, y_2, ..., y_C)
,预测概率 \hat{\mathbf{y}} = (\hat{y}_1, \hat{y}_2, ..., \hat{y}_C)
,其中 C
是类别数,交叉熵损失为:
\mathcal{L}_{\text{CE}} = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)
通常只有一个
y_i = 1
,其余为 0,所以实际上就是取:
\mathcal{L}_{\text{CE}} = - \log(\hat{y}_{\text{true class}})
在 PyTorch 中使用交叉熵
在 PyTorch 中,torch.nn.CrossEntropyLoss
用于多分类问题,torch.nn.BCEWithLogitsLoss
用于二分类问题。
1. 二分类问题(Binary Classification)
import torch
import torch.nn as nn
# 真实标签和预测概率(未经 sigmoid)
y_true = torch.tensor([1.0, 0.0, 1.0]) # 真实标签
y_pred = torch.tensor([0.9, 0.2, 0.8]) # 预测的概率
# 使用 BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss()
loss = criterion(y_pred, y_true)
print("Binary Cross Entropy Loss:", loss.item())
2. 多分类问题(Multi-class Classification)
import torch
import torch.nn as nn
# 真实标签(one-hot 向量)
y_true = torch.tensor([2, 0, 1]) # 真实标签是类别的索引
y_pred = torch.tensor([[2.0, 1.0, 0.5], [1.0, 2.0, 0.1], [0.1, 1.5, 2.0]]) # 预测的 logits
# 使用 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
loss = criterion(y_pred, y_true)
print("Multi-class Cross Entropy Loss:", loss.item())
在 TensorFlow 中使用交叉熵
在 TensorFlow 中,tf.keras.losses.BinaryCrossentropy
用于二分类问题,tf.keras.losses.CategoricalCrossentropy
用于多分类问题。
1. 二分类问题(Binary Classification)
import tensorflow as tf
# 真实标签和预测概率(未经 sigmoid)
y_true = tf.constant([1.0, 0.0, 1.0]) # 真实标签
y_pred = tf.constant([0.9, 0.2, 0.8]) # 预测的概率
# 使用 BinaryCrossentropy
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
loss = loss_fn(y_true, y_pred)
print("Binary Cross Entropy Loss:", loss.numpy())
2. 多分类问题(Multi-class Classification)
import tensorflow as tf
# 真实标签(one-hot 向量)
y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) # one-hot 编码的真实标签
y_pred = tf.constant([[2.0, 1.0, 0.5], [1.0, 2.0, 0.1], [0.1, 1.5, 2.0]]) # 预测的 logits
# 使用 CategoricalCrossentropy
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
loss = loss_fn(y_true, y_pred)
print("Multi-class Cross Entropy Loss:", loss.numpy())
这些代码片段展示了如何在 PyTorch 和 TensorFlow 中计算交叉熵损失,分别适用于二分类和多分类问题。