交叉熵损失:简单解释,数学原理

阅读水平:适合对人工智能基础知识感兴趣的初学者。无需机器学习先验知识。
想象一下教你的狗区分网球和零食,并奖励正确的猜测。在这个简单的场景中,你刚刚实现了一个基本的“损失函数”——驱动机器学习的反馈机制。
从面部识别到语言翻译,我发现交叉熵损失仍然是现代人工智能系统的基石。
交叉熵:置信度计量器
交叉熵函数就像一位理想的教练,它不仅衡量预测是对还是错,还衡量其自信程度。
- 正确且自信? 轻微惩罚
- 正确但不确定? 中等惩罚
- 错误但不确定? 中等惩罚
- 错误且自信? 严重惩罚
这创造了完美的学习环境——只在有充分理由时才自信。
其运作的数学原理
对于是/否问题(例如“这封邮件是垃圾邮件吗?”),交叉熵的计算方式如下:
其中:
- 是真实答案(1 表示“是”,0 表示“否”)
- 是模型的置信度(从 0 到 1)
- 是产生的惩罚
这个公式的两个方面值得解释
为什么是负号? 机器学习系统通过最小化损失函数(找到可能的最低值)来工作。然而,在使用概率时,我们实际上希望最大化正确预测的可能性。负号将这个最大化问题转化为最小化问题——本质上是告诉算法:“最小化我们想要最大化的负值。”
为什么是对数? 对数有多个关键用途:
- 它们将微小的概率乘法转换为更简单的加法
- 它们对自信的错误施加指数级更严厉的惩罚
- 它们在处理非常小的概率时提供数值稳定性
- 它们与信息论直接相关,其中对数概率表示信息比特
对数创造了完美的惩罚曲线——高置信度预测中的小错误比低置信度预测中的相同错误受到更严重的惩罚。
让我们用实际数字来看看这个
场景 1: 邮件是垃圾邮件 ,AI 有 90% 的置信度 损失 = - 正确且自信的轻微惩罚
场景 2: 邮件是垃圾邮件 ,AI 只有 10% 的置信度
损失 = - 大约差 22 倍!正确但太不确定
场景 3: 邮件不是垃圾邮件 ,AI 有 90% 的置信度是
损失 = - 自信地犯错也会受到同样的严重惩罚
从错误中学习
回溯到 2018 年,我的团队的图像识别模型曾自信地将一只吉娃娃错误分类为蓝莓松饼。惩罚计算揭示了问题为何如此严重。
- 真实答案:“吉娃娃”
- AI 预测:是吉娃娃的概率为 1%
- 损失 =
如果它表现出不确定性(50% 的置信度),损失将只有 0.693——几乎低了 7 倍。这种显著的差异表明了为什么交叉熵在教导模型校准其置信度方面表现出色。
从数学到机器:真实代码
以下是一个简化版的实现,它为数十亿美元的 AI 系统提供动力:
import numpy as np
def binary_cross_entropy(y_true, y_pred):
# Prevent log(0) which would crash with -infinity
y_pred = np.clip(y_pred, 0.000001, 0.999999)
# The cross-entropy formula
loss = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
return loss
# Examples
print(f"Correct and confident: {binary_cross_entropy(1, 0.9):.3f}") # 0.105
print(f"Wrong and confident: {binary_cross_entropy(1, 0.1):.3f}") # 2.303
仅仅十行代码就驱动了价值数十亿美元的系统的学习。
多项选择问题
对于多类别问题(例如将图像分类为狗/猫/马/人),交叉熵扩展为:
这会将每种可能选择的损失相加,只有正确的选择才会对损失产生贡献。
考虑一个将图像分类为 {狗、猫、老虎、狮子} 的模型
# True label is "cat" (one-hot encoded)
true_label = [0, 1, 0, 0] # [dog, cat, tiger, lion]
# Different prediction scenarios with their losses:
# Confident+correct [0.01, 0.97, 0.01, 0.01]: Loss = 0.03
# Uncertain [0.20, 0.29, 0.31, 0.20]: Loss = 1.24
# Confident+wrong [0.97, 0.01, 0.01, 0.01]: Loss = 4.61
惩罚的巨大差异促使模型做出准确、校准良好的预测。
使用交叉熵损失训练神经网络
让我们看看交叉熵在训练用于图像分类的简单神经网络(使用 PyTorch)时的实际应用。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Load and preprocess MNIST dataset (handwritten digits)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True)
# Define a simple neural network
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10) # 10 classes (digits 0-9)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# Initialize the model, loss function, and optimizer
model = SimpleNN()
loss_fn = nn.CrossEntropyLoss() # PyTorch's cross-entropy loss
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Training loop
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
# Forward pass
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss_value = loss.item()
current = batch * len(X)
print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
# Train for 5 epochs
epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(trainloader, model, loss_fn, optimizer)
print("Training complete!")
# Test on a few examples
model.eval()
test_images, test_labels = next(iter(trainloader))
with torch.no_grad():
# Get predictions
outputs = model(test_images[:5])
# Get confidence scores using softmax
probabilities = nn.functional.softmax(outputs, dim=1)
for i in range(5):
true_label = test_labels[i].item()
pred_label = probabilities[i].argmax().item()
confidence = probabilities[i][pred_label].item() * 100
print(f"Image {i+1}:")
print(f" True label: {true_label}")
print(f" Predicted: {pred_label} with {confidence:.2f}% confidence")
# Calculate cross-entropy loss for this example
loss = -torch.log(probabilities[i][true_label])
print(f" Loss: {loss.item():.4f}")
print()
在这个例子中,nn.CrossEntropyLoss()
结合了两个操作:
- 一个 softmax 函数,将原始模型输出(logits)转换为概率
- 对这些概率进行负对数似然损失计算
训练循环会重复执行以下步骤:
- 使用当前模型进行预测
- 计算预测与真实标签之间的交叉熵损失
- 计算损失相对于模型参数的梯度
- 朝向减少损失的方向更新模型参数
随着训练的进行,模型学会将更高的概率分配给正确的类别,从而减少交叉熵损失。模型的置信度逐渐与其准确性保持一致——这正是我们想要的!
最终的测试部分演示了交叉熵损失如何创建置信度和正确性之间的直接关系,精确地显示了不同预测场景下产生的损失量。
高级技术
研究人员已经开发出交叉熵的巧妙扩展:
标签平滑
我们不使用纯粹的 0/1 标签进行训练,而是使用略微“平滑”的值,例如 0.1/0.9。
其中 。这可以防止过度自信并提高模型的鲁棒性。
焦点损失
对于大多数样本都容易的问题(例如安全摄像头,其中大多数帧都没有显示重要内容),焦点损失将学习重点放在困难样本上。
其中 会降低分类良好样本的损失。这项突破性技术彻底改变了图像中的目标检测。
交叉熵损失函数扮演着理想老师的角色:它要求模型诚实地面对不确定性,奖励校准良好的置信度,并根据错误的大小施加相应的惩罚。
如果这篇介绍激发了你的好奇心,可以考虑探索以下资源,以加深理解:
克劳德·香农(Claude Shannon)1948 年的奠基性论文 “通信的数学理论”,它建立了信息论和熵概念
库尔巴克(Kullback)和莱布勒(Leibler)关于 “信息与充分性”(1951)的工作,它规范了与交叉熵相关的 KL 散度
迈克尔·尼尔森(Michael Nielsen)出色的 在线书籍章节,提供了神经网络中交叉熵的直观解释
Mao、Mohri 和 Zhong (2023) 的最新论文 “交叉熵损失函数:理论分析与应用”,它推动了我们对交叉熵保证的理论理解
Zhang 和 Sabuncu 关于 “用于含噪声标签的深度神经网络训练的广义交叉熵损失” (2018) 的工作,它提供了从不完美数据中学习的解决方案