Transformer架构


公式
Parameter
Parameter | Description |
batch size | |
the model size / hidden state dimension / positional encoding size | |
number of attention heads | |
head dimension, usually | |
sequence length | |
number of transformer layers | |
hidden state of nth layer | |
Weight matrix of first feed-forward layer | |
Weight matrix of second feed-forward layer | |
Weight matrix of query projection | |
Weight matrix of key projection | |
Weight matrix of value projection | |
Weight matrix of output projection |
Transformer layer
FNN layer
MultiHeadAttention
SelfAttention
单次训练所需要的算力
一次MatMul所需要的算力
forward
MatMul是深度神经网络中最常见的的操作,也是最消耗算力的步骤,我们想先把他拆解开。
我们以向量乘法 为例,其中 ,,。
计算结果向量C中一个元素,需要K次乘法(A向量中一行的每个元素逐个乘以B向量一列的每个元素,A向量中一行和B向量中一列的元素个数都为K),和K次加法,用于把每个乘的结果相加。计算矩阵中一个元素需要进行 次的计算
结果矩阵C中一共有 个,所以整个矩阵计算一共需要计算 次,也就是 FLOPs
backward
计算矩阵A和B的梯度后更新:
- 需要的算力是 FLOPs
- 需要的算力是 FLOPs
- total: FLOPs
total
总的来看,一次MatMul考虑forward和backward过程,一共需要算力: FLOPs
Transformer所需算力
现在,我们可以根据前面的公式,来逐个拆解计算一个Transformer layer需要的算力情况。
为了简化结果,我们下面的计算只考虑 MatMul,而不考虑逐个元素的操作,比如说:layer normalization, GeLU activations 和 residual connections.
FFN FLOPs
- FFN公式(3) 所需要的算力是:
- FFN公式(5) 所需要的算力是:
合起来算的话,前向传播需要 ,反向传播需要 ,所以FFN需要的算力是:
中是根据bs(batch size)做了一次广播,把相同的权重应用在batch size的每一个 的矩阵中
QKVO FLOPs
对应的是multi head attention公式中的(6)(7)(8)(9),每个都是矩阵乘法,每次计算量都是
合并起来算的话,前向传播需要 ,反向传播需要 ,所以QKVO所需要的算力是:
ATT FLOPs
对应的是最后计算self attention的步骤,self attention计算的是一个head的数据,因此下面的算力结果需要乘以
- 计算公式(10) ,需要的算力是
- 计算公式(12) ,需要的算力是
因为 ,所以, ,多头self attention的算力合并起来是:
所以attention的算力,前向传播需要 ,反向传播需要 ,所以ATT所需要的算力是:
但是,我们因为Transformer是自回归模型,当前token只会和前面的token计算attention,所以矩阵结果实际是个lower tridiagonal matrix。那么公式(10)和(12)的算力会减少为 ,最终ATT需要的算力是:
算力和context length的关系
为了评估context length在训练过程中的算力占比,我们可以用公式
从公式结果看,在d固定的情况下,随着context length增加, 的算力占比会增加。下面有个图展示了在LLaMA_7B(d = 4096)下,context length变化对算力的有影响。

以长上下文评测榜单:https://longbench2.github.io/#leaderboard的数据作为参考,目前长上下文基本上都是128K。
维持d=4096不变,attention的算力比率会从2K的4.07%,增加到128K的260%。整体算力需要上涨。
如果说context length=2K的模型训练需要一周,那么当context length=128K的时候,则需要3.5周才能训练完成

平衡算力和参数量的关系
从公式15中,我们还能看出来,算力占比和d有关。大部分情况下,随着模型参数量的增加,为了能够更加丰富的表达token,也会随之增加d,比如说LLaMA_7B的d=4096,但是LLaMA-65B却翻倍增加到了d = 8192。

从结果上,相比于7B的模型,65B的模型在从2K增加到,128K,训练时长只增加了130%。
Loading Comments...