交叉熵公式

内容纲要

交叉熵(Cross Entropy)常用于分类问题中衡量两个概率分布之间的差异。它的公式如下:


1. 对于二分类问题(Binary Classification):

给定真实标签 y \in \{0, 1\},预测概率 \hat{y} \in [0, 1],交叉熵损失为:

\mathcal{L}_{\text{CE}} = - \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right]


2. 对于多分类问题(Multi-class Classification):

给定真实标签的 one-hot 向量 \mathbf{y} = (y_1, y_2, ..., y_C),预测概率 \hat{\mathbf{y}} = (\hat{y}_1, \hat{y}_2, ..., \hat{y}_C),其中 C 是类别数,交叉熵损失为:

\mathcal{L}_{\text{CE}} = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)

通常只有一个 y_i = 1,其余为 0,所以实际上就是取:

\mathcal{L}_{\text{CE}} = - \log(\hat{y}_{\text{true class}})


在 PyTorch 中使用交叉熵

在 PyTorch 中,torch.nn.CrossEntropyLoss 用于多分类问题,torch.nn.BCEWithLogitsLoss 用于二分类问题。

1. 二分类问题(Binary Classification)

import torch
import torch.nn as nn

# 真实标签和预测概率(未经 sigmoid)
y_true = torch.tensor([1.0, 0.0, 1.0])  # 真实标签
y_pred = torch.tensor([0.9, 0.2, 0.8])  # 预测的概率

# 使用 BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss()
loss = criterion(y_pred, y_true)
print("Binary Cross Entropy Loss:", loss.item())

2. 多分类问题(Multi-class Classification)

import torch
import torch.nn as nn

# 真实标签(one-hot 向量)
y_true = torch.tensor([2, 0, 1])  # 真实标签是类别的索引
y_pred = torch.tensor([[2.0, 1.0, 0.5], [1.0, 2.0, 0.1], [0.1, 1.5, 2.0]])  # 预测的 logits

# 使用 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
loss = criterion(y_pred, y_true)
print("Multi-class Cross Entropy Loss:", loss.item())

在 TensorFlow 中使用交叉熵

在 TensorFlow 中,tf.keras.losses.BinaryCrossentropy 用于二分类问题,tf.keras.losses.CategoricalCrossentropy 用于多分类问题。

1. 二分类问题(Binary Classification)

import tensorflow as tf

# 真实标签和预测概率(未经 sigmoid)
y_true = tf.constant([1.0, 0.0, 1.0])  # 真实标签
y_pred = tf.constant([0.9, 0.2, 0.8])  # 预测的概率

# 使用 BinaryCrossentropy
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
loss = loss_fn(y_true, y_pred)
print("Binary Cross Entropy Loss:", loss.numpy())

2. 多分类问题(Multi-class Classification)

import tensorflow as tf

# 真实标签(one-hot 向量)
y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])  # one-hot 编码的真实标签
y_pred = tf.constant([[2.0, 1.0, 0.5], [1.0, 2.0, 0.1], [0.1, 1.5, 2.0]])  # 预测的 logits

# 使用 CategoricalCrossentropy
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
loss = loss_fn(y_true, y_pred)
print("Multi-class Cross Entropy Loss:", loss.numpy())

这些代码片段展示了如何在 PyTorch 和 TensorFlow 中计算交叉熵损失,分别适用于二分类和多分类问题。

Leave a Comment

您的电子邮箱地址不会被公开。 必填项已用*标注

close
arrow_upward