llama2.c 源码阅读

1. 概述

前OpenAI著名工程师Andrej Kapathy开源了llama2.c项目,该项目是llama2模型推理代码的C语言实现,用大概970行C代码实现了LLama2模型的推理算法。整个项目代码简洁高效,值得深度阅读。对掌握大模型推理算法的细节有极大的帮助。

2. 源码阅读

2.1 基础算法

RMS归一化公式是:

$$ o_i = w_i \times x_i \times \frac {1}{\sqrt{\frac{1}{n}\sum_{j=0}^{n-1} x_j^2 + \epsilon}} $$

其中,\(\epsilon\) 为防止分母为0的数值。还有RMS因子是对x的归一化,w变量是gain变量,重新缩放标准化后的输入向量。

softmax函数公式是:

$$ o_i = \frac {e^{x_i-x_{max}}}{\sum_{j=0}^{n-1} e^{x_j-x_{max}}} $$

代码如下,注释说的很清楚,减去最大值是为了防止数值溢出,数值更稳定。通过简单数学变换可以得知,最终结果不变。

W (d,n) @ x (n,) -> xout (d,)的矩阵乘法,采用naive的矩阵乘法,即外层循环是行,内层循环是列。代码如下:

2.2. forward计算

模型中一个attention block的计算如下图所示:

项目代码是按照每一个token来计算QKV的,其中参数dim是transformer的向量维度。l是layer序号。

第一步是rmsnorm,即归一化。输入是x (d,),rms权重向量是w->rms_att_weight + l*dim,计算结果输出到s->xb (d,)中。

第二步是QKV的矩阵乘法,注意kv_dim和dim的区别,是为了同时兼容multi head attention和grouped query attention两种算法。如下图所示:

kv_dim是key和value的总维度,dim是transformer的向量总维度。在multi head attention中,kv_dim = dim。在grouped query attention中,kv_dim = dim * n_kv_heads / n_heads。以图中为例,n_kv_heads = 4, n_heads = 8,则kv_dim = dim / 2。

对于各矩阵的维度,以及在MHA、GQA等算法中的关系,参考下图:

Q、K、V三个向量计算的详细代码如下,即Wq(d,d) @ xb(d,) -> q(d,),Wk(dkv,d) @ xb(d,) -> k(dkv,), Wv(dkv,d) @ xb(d,) -> v(dkv,)

接下来需要给Q和K向量添加RoPE位置编码,按照如下公式计算,其中m就是当前token的序号pos。需要注意的是,llama模型是给每一层的Q和K向量都添加这个编码。

$$ \begin{aligned} \theta_i &= \frac{1}{10000^{2i/hs}}= 10000^{-2i/hs} \\ Q(i) &=Q(i)\cos (m\theta_i) - Q(i+1)\sin(m\theta_i)\\ Q(i+1) &=Q(i)\sin (m \theta_i) + Q(i+1)\cos(m\theta_i)\\ K(i) &=K(i)\cos (m \theta_i) - K(i+1)\sin(m\theta_i)\\ K(i+1) &=K(i)\sin (m \theta_i) + K(i+1)\cos(m\theta_i)\\ \end{aligned} $$

详细代码如下,注意在GQA中,K的向量长度小于Q的向量长度,所以在i < kv_dim时,计算Q和K的向量。在i >= kv_dim时,只计算Q的向量。

接下来针对每个头,计算attention score。attention score的计算公式如下:

$$ score(i) = softmax(\frac{ Q_i K^T}{\sqrt{d}})V , \quad Q_i \in \R^{1 \times d},K \in \R^{n\times d},V\in\R^{n\times d} $$

具体计算的时候,先遍历每个head,在每个head中,先计算Qi和K的点积,然后除以sqrt(d),得到att (1,n)向量,最后softmax得到attention score。

在GQA中,由于分组共享了Q和K的向量,在计算attention score的时候,需要把Q和K的向量“展开”还原为(n,d)的矩阵,具体做法是通过h / kv_mul,保证 kv_mul个Q和K向量共享一个权重。

然后计算attention score (1,n)和V (n,d)的乘积,得到xb (1,d)。这个计算并不是完全按照普通矩阵乘来计算的,而是把每个位置的attention score和V的 每一行相乘,然后累加到xb中。这样计算的好处是对cache更加友好,是一种常见的矩阵乘算法。

对于每个头,每个token的attention score计算过程的可视化如图所示:

图中可以清楚看出,每个token都计算了一遍和其他token的相关度,再进行加权求和得到最终的attention score。

具体代码如下:

从代码中也能看出,为什么需要把K和V的矩阵进行cache。因为对于一个位置的token而言,Q矩阵每次参与计算的只有当前位置的一行,而K和V矩阵,则是每行都需要 参与计算。最终得到的也是该位置的(1,d)向量作为attention score。因此,为了减少计算量,把K和V矩阵进行cache也是理所当然。

接下来的计算就非常简单,注释也非常直观。详细步骤如下:

  1. 计算Wo (d,d) @ xb^T (d,)得到xb2 (d,)
  2. 通过残差连接,叠加x (d,)向量:x += xb2
  3. x再经过一个RMSNorm(x),得到xb (d,)
  4. 计算hb和hb2:W1(hd, d) @ xb (d,) -> hb1(hd,) , W3(hd, d) @ xb (d,) -> hb2(hd, )
  5. hb经过silu非线性激活函数变换,计算方式为:$$silu(hb) = hb (1/ (1 + e^{-hb}))$$
  6. 然后计算逐位相乘 hb * hb2, 得到hb (hd,)
  7. 计算W2(d, hd) @ hb (hd,) -> xb (d,)
  8. 最终再通过残差连接,叠加xb向量:x += xb

继续每一层的计算,每一层的输入都是x,输出也是x,循环计算。在每一层都算完以后,最后再计算:

  1. RMSNorm(x),把x向量进行归一化。
  2. 计算Wc(dvoc, d) @ x (d,) -> logits (dvoc,),其中dvoc为词典大小。

至此,最终得到的logits就是该位置的在token词典中的分类概率。

2.3 抽样方法

拿到logits之后,需要通过抽样来最终确定输出哪个token,常见的抽样方法有greedy(argmax),随机抽样,以及top-p (nucleus) 抽样。

2.3.1 Greedy Sampling

Greedy Sampling是直接选择概率最大的token作为输出。代码简单直观,如下:

2.3.2 Random Sampling

Random Sampling是随机选择一个token作为输出。代码也很简单,如下:

2.3.3 Top-p (Nucleus) Sampling

Top-p (Nucleus) Sampling是随机选择概率大于某个阈值的token作为输出。代码也很简单,如下:

2.3.4 选择抽样策略

具体执行抽样前,需要做一些变换,比如:

  • 除以temperature,用来调整概率分布,温度越高,概率分布越平滑
  • 计算softmax(logits),得到概率分布 代码如下所示:

然后根据不同的采样策略,选择不同的采样函数。

2.4 encode和decode
2.4.1 encode

encode函数将输入文本转化为token id序列。token id为int类型,长度为max_len。encode算法非常直观,先是在tokenize词典中查询每个UTF-8字符。如果找不到,则将文本编码为byte fallback。注意每个UTF-8字符长度是1到3个字节之间,需要针对UTF-8编码的规范进行判断。

代码如下:

其次,尝试合并临近的字符,并查询tokenize词典,如果存在,则将临近的token缩对应的字符串合并为一个token。 并反复迭代,直到找不到相邻的两个token可以合并为一个token为止。代码也很直观,如下:

2.4.2 decode

decode函数将token id序列转化为文本。代码也直观,有一些比较tricky之处,代码也注释清楚:

2.5 文本生成

文本生成是最基础的inference逻辑,对话也是基于文本生成而实现的。整个代码逻辑也非常简单:

  1. 将每一个token id逐个进行forward计算
  2. 判断当前token位置是否还在prompt长度内,如果不在则执行sampling策略,通过logits向量选取下一个token
  3. 否则直接从prompt中读取下一个token。
  4. 将下一个token进行decode,并打印出来。

代码详见:

2.6 其他

其他部分的代码就是一些简单的数据结构定义,以及helper函数和main函数,这里就不再赘述了。

3. 总结

总体来说,这个项目是一个toy项目,代码逻辑比较简单,但是也提供了非常多的细节参考。特别是兼容了MHA和GQA算法,对于理解这些算法的原理非常有帮助。

但也要看出,这个代码中并没有实现prefill阶段,而是采用逐个token输入的方式填充kv cache。效率的确比较低,但好在逻辑清晰,容易理解。

如果需要进一步优化这个代码,其实有很多可优化点,例如prefill的并行加载优化,减少重复decode等,但这些都超出了这个项目的范围,留给读者自己探索。

参考链接


llama2.c 源码阅读

发布者

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注