深度学习进阶(二)多头自注意力机制(Multi-Head Attention)

第一篇中,我们已经得到了自注意力的核心公式:

\[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{softmax}\left(\frac{ \mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V} \]

再概述一下自注意力的本质:通过一次全局加权,将序列中的所有信息重新融合到每一个位置上,最终强化信息表示。

但单头的自注意力还是有些局限:一组 \((\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V)\) 只能用一种方式去理解序列
这其实和在卷积层中使用多个卷积核是相似的道理,我们不能只用一个卷积核去提取纹理、色彩、形状等所有特征。
同理,我们不能指望一组参数矩阵就能学习到序列在语义、语法、情感等多个方面的关联。

因此,我们在实际算法设计中,使用的往往是多头自注意力
其原理并不复杂,只是在单头自注意力基础上的简单改进。

1. 多头注意力的核心思想

多头注意力的核心思想很直观:

在“一层自注意力层”中,不只做一次注意力,而是做多次注意力,每次关注不同的信息子空间。

具体来说就是:不再只用一组 \((\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V)\), 而是同时学习 \(h\) 组不同的参数矩阵,这样的每一组参数矩阵就是一个“头”,综合所有头的注意力信息,得到最终输出。

比如,第 \(t\) 个头为:

\[\mathbf{Q}_t = \mathbf{X} \mathbf{W}_Q^{(t)} \]

\[\mathbf{K}_t = \mathbf{X} \mathbf{W}_K^{(t)} \]

\[\mathbf{V}_t = \mathbf{X} \mathbf{W}_V^{(t)} \]

然后,每一个头都独立进行注意力计算:

\[\mathbf{Z}_t = \mathrm{Attention}(\mathbf{Q}_t, \mathbf{K}_t, \mathbf{V}_t) \]

由于每个注意力头拥有独立的初始化参数矩阵,所以每一个头都是一个“观察角度”,它们分别回答在不同语义空间下相关性问题。
其计算过程和单头自注意力并无区别,但一个新的问题是:

如何融合多头输出?

2. 多头的融合方式

2.1 拼接

通过多头自注意力,现在我们得到了多个输出:

\[\mathbf{Z}_1, \mathbf{Z}_2, \dots, \mathbf{Z}_h \]

而要回答这些信息怎么合在一起,首先要了解 Transformer 对每个头的维度划分设计

\[d_k = d_v = \frac{d}{h} \]

这里的 \(d\) 仍是表示序列中一个位置,或者说一个 token 的特征维度。

举个例子来说明:
假定每个 token进入模型的特征维度:

\[d = 8 \]

并且使用:

\[h = 2 \quad (\text{两个注意力头}) \]

那么每个头的维度为:

\[d_k = d_v = \frac{d}{h} = \frac{8}{2} = 4 \]

这代表每个头的 Query / Key 向量维度和 Value 输出维度都为 4,分别计算注意力得到输出:

\[\mathbf{Z_1},\mathbf{Z_2} \in \mathbb{R}^{n \times 4} \]

明确了维度变化后,我们就能进行多头输出融合的第一步:拼接

\[\mathbf{Z} = \text{Concat}(\mathbf{Z}_1, \mathbf{Z}_2, \dots, \mathbf{Z}_h) \]

思路很明确,就是把所有头的输出先直接拼在一起

\[\mathbf{Z} \in \mathbb{R}^{n \times (h \cdot d_v=d)} \]

可以发现,拼接后的维度重新恢复到了原始的模型维度 \(d\)
这样的切分逻辑可以在固定模型维度 \(d\) 的前提下,使多头不会增加总体计算复杂度的数量级,从而避免因头数增加而产生计算量爆炸问题。
同时,这种让输入输出维度相同的设计也和 Transformer 的后续逻辑相关。

2.2 线性变换

拼接完还没有结束,实际上,在这之后,\(\mathbf{Z}\) 还需要经过一个线性层:

\[\mathbf{Z}_{\text{final}} = \mathbf{Z} \mathbf{W}_O,\mathbf{W}_O \in \mathbb{R}^{d \times d} \]

你会发现,这里没有加入偏置
实际上,线性层本身自然是拥有偏置的,但许多 Transformer 实现都会选择关闭偏置,这是因为 Transformer block 的结构中仍存在后续线性变换以及归一化操作,这里的单个线性层中的偏置项对整体表达能力的影响较小,因此在理论公式中经常省略。

现在把整体写成一行如下:

\[\mathrm{MultiHead}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \text{Concat}(\mathbf{Z}_1,\dots,\mathbf{Z}_h)\mathbf{W}_O \]

这就是多头注意力的融合公式。

到这里,你可能有这样一个问题:只拼接不行吗?为什么还要再过一个线性层?
我们举个例子来回答这个问题:
假设对于某个位置 \(i\),经过多头注意力后,我们得到了两个头的输出:

\[\mathbf{z}_1^{(i)} = [语法,结构,顺序,主谓] \]

\[\mathbf{z}_2^{(i)} = [语义,情感,主题,语境] \]

拼接之后得到:

\[\mathbf{z}^{(i)} = [语法,结构,顺序,主谓 , 语义,情感,主题,语境] \]

此时,不同头的信息只是被“并排放在一起”,但它们之间并没有发生任何交互或关联。
也就是说:有点用,但还不够。

现在再进行融合:

\[\mathbf{z}_{final}^{(i)} = \mathbf{z}^{(i)} \mathbf{W}_O \]

这步计算实际上是对拼接后的所有特征进行一次“重新加权组合”。
假设经过学习后,\(\mathbf{W}_O\) 做出的组合类似于:

  1. 把“主谓” + “语义”组合,得到更准确的句法语义关系。
  2. 把“结构” + “语境”组合,得到更高层次的上下文理解。
  3. 把“情感”适当放大或抑制。

最终形成新的表示:

\[\mathbf{z}_{final}^{(i)} = [综合特征_1, 综合特征_2, \dots] \]

由此,所有头的信息被打散并重新组合,模型可以自由地学习跨多头的特征关系。

这就是多头自注意力机制的详细内容,我们由此实现了从多个角度对输入信息的强化表示,从模型整体角度来说,多头注意力本质上是一个用于建模序列内部关系的计算模块。

因此,在 Transformer 中,多头注意力并不是单独使用的,而是被嵌入到一个更完整的结构单元中,这个单元就是 Transformer Block,我们在下一篇中再对其展开介绍。

文章摘自:https://www.cnblogs.com/Goblinscholar/p/19811074