继上次了解完信息熵后,没想到在对抗训练中又回到了损失函数的分析上。torch
实现的有点迷惑,正好趁这次机会,连着代码和公式好好推导下。这种损失的优缺点日后会放上来,主要取决论文进度。
一个简单实例
交叉熵的公式定义耳熟能详了:
\begin{equation}\label{ce}
H(p,q) = -\sum_{x\in X} p(x) \log q(x)
\end{equation}
这里带入一个实例算一下:
1 | # example of calculating cross entropy |
需要注意的是,$H(p,q) \neq H(q,p)$。那么再来看深度学习中的定义,其实都一样,只是深度学习中要求真实标签不能是 $\log$ 那一项,因为 $\log$ 的定义域不能取 0。
\begin{equation}
H = -\sum_{c=1}^M y \log \hat{y}
\end{equation}
其中,$y$ 是真实的标签,$M$ 是类别数量,$\hat{y}$ 是自己预测的结果。如果是三分类问题,假设 $y=[1,0,0]$,$\hat{y}=[0.4,0.1,0.5]$,那么此时的交叉熵是 $-\log 0.4$。因为其它标签取值是 0,即不参与运算,换句话说,交叉熵的损失只取决于被正确分类的概率,具体缺陷可以公式推导一下。带入上述代码,得到一致的结果。
torch 实现
用过 tensorflow 和 keras 之后,愈发的发现 pytorch 真好用。torch
实现的交叉熵有一点需要注意,是通过 logsoftmaax + NLLLoss
实现的,实际代码中要注意这里,不要操作失误。
- 当获取网络的输出后,这里设网络的输出是 $O$,那么直接将 $O$ 扔进
logsoftmax
为交叉熵做准备,不需要激活。相当于计算公式 $\eqref{ce}$ 中的 $\log$ 那一项。 - 而后
logsoftmax
的输出进入NLLLoss
,计算真正的交叉熵损失。计算方法为,如果是三分类的任务,logsoftmax
的输出是 $-0.5,-0.1,-0.4$,如果当前类别是第一个类,那么损失就是 0.5,具体原因参考上述公式推导。
如果你感觉哪里错了,可以看文末的参考链接中,有人用各种方法试验了 torch
的交叉熵,验证了 torch
的实现是正确的。
优缺点分析
- 优点:对于 softmax 函数,常用的 MSE 误差越大,下降越慢。但交叉熵能很好改善这个问题,当误差较大时,梯度也大,下降的较快;也避免了某些情况下激活函数进入饱和区,梯度消失的问题。详情参考文末的公式推导。
- 缺点:论文要用,等哪天论文尘埃落定,我在回来填坑。
你可以手推公式,推着推着就发现了。