本文分享书中对“如何实现transformer模型”的阅读理解。
一、引言
transformer模型,以自注意力机制为主要特点,即对输入按照查询Q、键K、值V三部分进行分解,通过自注意力计算,然后叠加得到输出,这项技术在自然语言处理(NLP)领域得到了成熟应用。
transformer模型包括编码器和解码器,每个编、解码器由多个transformer层组成,每个transformer层由嵌入层,残差连接和归一化层,自注意力层和前馈网络层组成。
Llama2模型的transformer架构,与GPT3模型类似,都是纯解码器架构,这是由此类模型的应用需求决定的。
二、原生transformer模型
在实际应用中,transformer模型有许多变体,书中把标准的transformer理论模型称为原生transformer模型。
原生transformer模型,由6层编码器和6层解码器组成,在每个transformer层里,输入经过嵌入层,再通过多头注意力层与由两个线性层和一个激活函数构成的前馈网络层的结合处理,通过残差连接和层归一化处理输出到下一环节。
原生transformer模型的架构图,如下:
三、Llama2的transformer模型
Llama2 7B模型有32个transformer层,在每一层里分组查询注意力层和前馈网络层结合构成注意力模块,输入经过嵌入层处理,通过残差连接和层归一化,进入到注意力模块的处理,然后再经过残差连接和归一化输出到下一层。
在嵌入层上,Llama2使用了旋转字节对编码(RoBPE)技术。
在残差和归一化层的处理上,Llama2采用了RMSNorm进行归一化处理。
分组查询注意力(GQA)是Llama2的一个改进,对多头注意力(MHA)技术与多查询注意力(MQA)技术之间进行了折衷处理。
Llama2的transformer架构图,如下:
四、transformer模型实现
1、原生transformer实现
书中给出了利用PyTorch实现原生transformer模型的代码。
2、纯解码transformer实现
书中针对Meta发布的Llama2中transformer模型的PyTorch实现代码,做了简单的修改(删掉注释、减少层数、将并行运算改为串行)来满足低配运行的需要,介绍了Llama2模型的transformer实现。这是实现代码的组成框架图:
首先导入类库,除了PyTorch,还导入了math和dataclasses、typing等第三方类库。
在Llama2中,定义了Transformer类、前向网络层类FeedForward、层归一化类RMSNorm、自注意力层类Attention、TransformerBlock类,它们都是由PyTorch的nn.Module类继承而来的。
这是RMSNorm类,它执行归一化处理。
这是Transformer类,调用nn.Embedding方法执行嵌入处理,调用RMSNorm执行层归一化处理,可以看出Llama2采用的是预归一化处理。
这是TransformerBlock类,它将自注意力层和前馈网络层组合在了一起。
这是Attention类,其中wq、wk、wv、wo是查询Q、键K、值V和输出O的权重,通过nn.Linear执行线性层调用。
这是FeedForward类,它完成两个线性层和一个激活函数的处理。
Llama2 7B的transformer模型是pth格式模型。
五、小结与讨论
通过这一主题的学习,我进一步了解了transformer模型的架构机理及内部实现技术,并对Llama2 7B模型中transformer架构及代码实现加深了理解和认识。
推荐阅读和参考资料:1、分组查询注意力机制(GQA)。
2、多头注意力机制(MHA)。
3、多查询注意力机制(MQA)