编码实践 | 一文读懂Self-Attention机制

介绍

在当代自然语言处理(NLP)和深度学习领域,Transformer模型及其核心组成部分——自注意力(Self-Attention)机制——已经成为一种革命性的架构,对序列建模任务产生了深远的影响。自从Vaswani等人在2017年的论文《Attention is All You Need》中首次提出以来,Transformer模型已经成为了多种复杂任务的基石,包括机器翻译、文本生成、语音识别以及图像处理等。

在Transformer中,自注意力(Self-Attention)因其卓越的性能,引发了学界和工业界的广泛关注。注意力机制让模型在每个时间步骤都能访问所有序列元素。其中的关键在于选择性,也就是确定在特定上下文中哪些词最重要。它通过纳入与输入上下文有关的信息来增强输入嵌入的信息内容。换句话说,自注意力机制让模型能够权衡输入序列中不同元素的重要性,并动态调整它们对输出的影响。

在本文,我们根据文章《Understanding and Coding Self-Attention, Multi-Head Attention, Cross-Attention, and Causal-Attention in LLMs》中的内容,手撕自注意力(Self-Attention)。其中文版发表在机器之心往期文章《大模型时代还不理解自注意力?这篇文章教你从头写代码实现》上。相比于上述两篇文章,本文加入了作者自己的思考和经验。

Embedding

在这一章节,我们通过Embedding编码,将输入序列中的离散符号(如单词或字符)转换为连续的、高维的向量表示。简单来讲,这样做的原因是深度学习模型无法直接理解原始的文本数据,而必须通过学习这些向量中的信息来理解文本的语义和语法结构。更深入的内容可以查看知乎文章《一文读懂Embedding的概念,以及它和深度学习的关系》

对于一段句子,我们希望通过Embedding将其转化为连续的向量表示,这里以句子「Life is short, eat dessert first」为例。

在预处理阶段,我们需要对句子中的单词进行去重,然后将每一个单词与一个整数索引进行映射。这个过程可以用Python轻松表示。

输入:

sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s 
      in enumerate(sorted(sentence.replace(',', '').split()))}

print(dc)

输出:

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

通过将单词与整数索引的映射,句子可以被索引表示。

输入:

import torch

sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)

输出:

tensor([0, 4, 5, 2, 1, 3])

好了!接下来我们就要进行Embedding的关键步骤了。

在Embedding中,每个单词都被一个多维向量所表示。向量的维度由词库大小所决定。举个例子,Llama 2 的嵌入大小为 4096。为不显冗余,在本文,我们使用三维作为样例。

输入:

vocab_size = 50_000

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

输出:

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])

如输出所示,我们句子中的6个单词,被表示为了六行向量,每个向量的维度为3。这意味着,句子中的每个单词被一个由三个数字构成的向量所表示。

定义权重矩阵

自这一章节,我们将介绍Transformer中赫赫有名的"QKV"机制。

自注意力使用了三个权重矩阵,分别记为 $W_q$、$W_k$ 和 $W_v$;它们作为模型参数,会在训练过程中不断调整。这些矩阵的作用是将输入分别投射成序列的查询、键和值分量。

相应的查询、键和值序列可通过权重矩阵 $W$ 和嵌入的输入 $x$ 之间的矩阵乘法来获得:

  • 查询序列:对于属于序列 $1……T$ 的 $i$,有 $q^{(i)}=x^{(i)}W_q$
  • 键序列:对于属于序列 $1……T$ 的 $i$,有 $k^{(i)}=x^{(i)}W_k$
  • 值序列:对于属于序列 $1……T$ 的 $i$,有 $v^{(i)}=x^{(i)}W_v$
  • 索引 $i$ 是指输入序列中的 $token$ 索引位置,其长度为 $T$。

在本文,$q^{(i)}$ 和$k^{(i)}$都是维度为$d_k$的向量。投射矩阵$W_q$和$W_k$的形状为 $d × d_k$,而$W_v$的形状是$d × d_v$。其中,$d$表示每个词向量$x$的维度数,本文中为3。

由于我们要计算查询和键向量的点积,因此这两个向量的元素数量必须相同($d_q=d_k$)。很多大模型也会使用同样大小的值向量,也就是 $d_q=d_k=d_v$。但是,值向量$v^{(i)}$的元素数量可以是任意值,其决定了所得上下文向量的大小。

在接下来的代码中,我们将设定$d_q=d_k=2$,而$d_v=4$。投射矩阵的初始化如下:

输入:

torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

在原论文《Attention is All You Need》中,$d_q$、$d_k$和$d_v$通常设置为64,而模型的总维度为512

计算非归一化的注意力权重

在本章,我们以第二个Token为样例,来进行演示:


如上图所示,我们需要将输入$x$与 $W_q$、$W_k$ 、$W_v$分别相乘。

代码如下:

x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

输出:

torch.Size([2])
torch.Size([2])
torch.Size([4])

推而广之,我们可以将embedded_sentence与$W_k$ 、$W_v$分别相乘,在后续的步骤中会被用到。

输入:

keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

输出:

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])

现在我们已经拥有了$query^{(2)}$和所有的keysvalues,让我们接下来计算非归一化注意力权重$\omega$,如下图所示:


如上图所示,$\omega (i,j)$ 是查询和键序列之间的点积 $\omega (i,j) = q^{(i)}k^{(j)}$。

举个例子,我们能以如下方式计算查询第2个Token与第5个Token之间的非归一化注意力矩阵:

输入:

omega_24 = query_2.dot(keys[4])
print(omega_24)

输出:

tensor(1.2903)

推而广之,我们可以将query_2keys相乘,获得第2个Token与其他Token之间的非归一化注意力矩阵。

输入:

omega_2 = query_2 @ keys.T
print(omega_2)

输出:

tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374])

计算注意力权重

在上一章,我们已经计算出了第2个Token与其他Token之间的非归一化注意力矩阵。在实际应用中,我们还需要对注意力矩阵进行归一化。

对注意力矩阵进行归一化的目的在于,将每个序列位置上的注意力权重都在0到1之间,且所有位置的权重加起来等于1。这样的概率解释使模型能够以概率的形式表达对序列中不同部分的关注程度,其中较高的权重表示模型在该位置上给予更多的关注。


如上图所示,在自注意力机制中,会先用 $1/√{d_k} $对$\omega$进行缩放,然后使用softmax 函数进行归一化。

按$d_k$进行缩放可确保权重向量的欧几里得长度都大致在同等尺度上。这有助于防止注意力权重变得太小或太大 —— 这可能导致数值不稳定或影响模型在训练期间收敛的能力。

我们可以这样用代码实现注意力权重的计算:

输入:

import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

输出:

tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229])

最后一步是计算上下文向量 $z^{(2)}$,即原始查询输入$x^{(2)}$ 经过注意力加权后的版本,其通过注意力权重将所有其它输入元素作为了上下文:


输入:

context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

输出:

torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110])

请注意,这个输出向量的维度($d_v=4$)比输入向量($d=3$)多,因为我们之前已经设定了 $d_v > d$。但是,$d_v$的嵌入大小可以任意选择。

自注意力

现在,总结一下之前小节中自注意力机制的代码实现。

我们可以将之前的代码总结成一个紧凑的Self-Attention 类:

import torch.nn as nn

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized attention weights    
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

输入:

torch.manual_seed(123)

# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

输出:

tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2627, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)

可以从第二行看到,其值与前一节中 context_vector_2 的值完全一样:tensor ([0.5313, 1.3607, 0.7891, 1.3110])。

总结

在本文,我们根据文章《Understanding and Coding Self-Attention, Multi-Head Attention, Cross-Attention, and Causal-Attention in LLMs》中的内容,手撕自注意力(Self-Attention),提供了代码和解释。