0%

交叉熵优缺点分析

继上次了解完信息熵后,没想到在对抗训练中又回到了损失函数的分析上。torch实现的有点迷惑,正好趁这次机会,连着代码和公式好好推导下。这种损失的优缺点日后会放上来,主要取决论文进度。

一个简单实例

交叉熵的公式定义耳熟能详了:

\begin{equation}\label{ce}
H(p,q) = -\sum_{x\in X} p(x) \log q(x)
\end{equation}

这里带入一个实例算一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# example of calculating cross entropy
from math import log2

# calculate cross entropy
def cross_entropy(p, q):
return -sum([p[i] * log2(q[i]) for i in range(len(p))])

# define data
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]

# calculate cross entropy H(P, Q)
ce_pq = cross_entropy(p, q)
print('H(P, Q): %.3f bits' % ce_pq) # H(P, Q): 3.288 bits
# calculate cross entropy H(Q, P)
ce_qp = cross_entropy(q, p)
print('H(Q, P): %.3f bits' % ce_qp) # H(Q, P): 2.906 bits

# 3 bit 的信息差异

需要注意的是,$H(p,q) \neq H(q,p)$。那么再来看深度学习中的定义,其实都一样,只是深度学习中要求真实标签不能是 $\log$ 那一项,因为 $\log$ 的定义域不能取 0。

\begin{equation}
H = -\sum_{c=1}^M y \log \hat{y}
\end{equation}

其中,$y$ 是真实的标签,$M$ 是类别数量,$\hat{y}$ 是自己预测的结果。如果是三分类问题,假设 $y=[1,0,0]$,$\hat{y}=[0.4,0.1,0.5]$,那么此时的交叉熵是 $-\log 0.4$。因为其它标签取值是 0,即不参与运算,换句话说,交叉熵的损失只取决于被正确分类的概率,具体缺陷可以公式推导一下。带入上述代码,得到一致的结果。

torch 实现

用过 tensorflow 和 keras 之后,愈发的发现 pytorch 真好用。torch 实现的交叉熵有一点需要注意,是通过 logsoftmaax + NLLLoss 实现的,实际代码中要注意这里,不要操作失误。

  1. 当获取网络的输出后,这里设网络的输出是 $O$,那么直接将 $O$ 扔进 logsoftmax 为交叉熵做准备,不需要激活。相当于计算公式 $\eqref{ce}$ 中的 $\log$ 那一项。
  2. 而后 logsoftmax 的输出进入 NLLLoss,计算真正的交叉熵损失。计算方法为,如果是三分类的任务,logsoftmax 的输出是 $-0.5,-0.1,-0.4$,如果当前类别是第一个类,那么损失就是 0.5,具体原因参考上述公式推导。

如果你感觉哪里错了,可以看文末的参考链接中,有人用各种方法试验了 torch 的交叉熵,验证了 torch 的实现是正确的。

优缺点分析

  1. 优点:对于 softmax 函数,常用的 MSE 误差越大,下降越慢。但交叉熵能很好改善这个问题,当误差较大时,梯度也大,下降的较快;也避免了某些情况下激活函数进入饱和区,梯度消失的问题。详情参考文末的公式推导。
  2. 缺点:论文要用,等哪天论文尘埃落定,我在回来填坑。你可以手推公式,推着推着就发现了。

参考

  1. 交叉熵定义
  2. torch 交叉熵代码验证
  3. 交叉熵优点及公式推导
感谢上学期间打赏我的朋友们。赛博乞讨:我,秦始皇,打钱。

欢迎订阅我的文章