《大语言模型开发:用开源模型开发本地系统》分享3:理解Transformer:构建LLM核心
[复制链接]
在人工智能的快速发展中,Transformer模型以其卓越的性能在自然语言处理领域占据了核心地位。本次分享,将深入探讨《大语言模型开发:用开源模型开发本地系统》中关于Transformer模型的详解,以便通俗易懂地理解这一模型的构成和工作机制。
第一部分:大语言模型的简介和分类
1. 简介
大语言模型是人工智能领域的一个重要分支,它们通过深度学习和大量的数据训练,能够理解和生成自然语言。这些模型在机器翻译、文本摘要、问答系统等多个领域都有广泛的应用。
2. 分类
大语言模型可以根据不同的架构和训练方式进行分类。了解这些分类有助于我们选择合适的模型来解决特定的问题。
第二部分:Transformer模型详解
1. 模型构成
Transformer模型由编码器和解码器组成,它们通过自注意力机制来处理序列数据。这种结构使得模型能够捕捉到序列中的长距离依赖关系。
2. 因果解码器结构
因果解码器是Transformer模型中的一个重要特性,它确保了在生成序列时,每一步只依赖于之前的位置,这对于生成连贯的文本至关重要。
Transformer的主要组成部分有以下几个。
(1)自注意力机制:这是Transformer的核心组成部分,也是其能够处理序列数据的关键。自注意力机制能够计算序列中每个元素与其他元素之间的关系,并基于这些关系来更新元素的表示。这使 Transformer 能够捕捉到序列中长距离的依赖关系。
(2)多头自注意力(Multi-Head Attention):Transformer 并不只计算一次自注意力,而是同时计算多次,每次使用不同的参数,然后将这些结果合并起来。这使Transformer 能够捕捉到数据的多个不同方面的信息。
(3)位置编码(Positional Encoding):由于Transformer 并没有使用RNN或CNN,所以它无法直接处理序列的顺序信息。为了解决这个问题,Transformer 引人位置编码,通过给每个元素添加一个位置相关的向量,来向模型提供序列中元素的位置信息。
(4)前馈神经网络:除了自注意力机制,Transformer 的每一层还包括一个前馈神经网络。这个网络在每个位置上都是独立运行的,它能够增强模型的复杂性,而不会增强处理序列的复杂性。
(5)归一化层:Transformer 在每个子层(自注意力和前馈神经网络)的输出后都添加了一个归一化层,以防止模型的训练发散。
(6)残差连接:Transformer 在每个子层的输人和输出之间都添加了一个残差连接。这可以帮助模型更容易地学习深层网络。
第三部分:分词
1. 词汇表
词汇表是模型理解语言的基础,它包含了模型能够识别的所有单词或符号。
2. 分词算法
分词算法将文本分割成词汇表中的单元,这是模型处理文本的第一步。
以下是一些常见的分词算法。
(1)空格分词:这是最简单的分词方法,只需按空格将文本分割成单词。这种方法在处理英语等大部分西方语言时效果不错,但对于没有明确单词边界的语言(如中文)或者复合词丰富的语言(如德语 ),效果就不理想了。
(2)基于词典的分词:这种方法需要一个预先定义好的词典,然后根据词典将文本分割成单词。这种方法可以处理一些复杂的情况,但依赖于词典的质量,而且不能很好地处理词典中不存在的单词。
(3)基于统计的分词:这种方法使用机器学习算法从大量的文本数据中学习单词的边
界。常见的基于统计的分词算法包括 HMM、CRF 等。
(4)子词分词:这种方法将单词进一步分割为子词。这样可以处理词典中不存在的单词,因为即使一个单词在词典中不存在,其组成的子词也可能存在。常见的子词分词算法包括字节对编码、句子片段(SentencePiece)等。
3. 字节对编码
字节对编码(Byte Pair Encoding, BPE)是一种有效的分词方法,它通过合并频繁出现的字节对来构建词汇表。
BPE的操作步骤如下。
(1)初始化词汇表:开始时,词汇表中的每个符号都是语料库中的一个字符。
(2)统计符号对频率:在语料库中统计每对连续符号的出现频率。
(3)合并频率最高的符号对:将频率最高的符号对合并为一个新的符号,加到同表中。
(4)重复步骤(2)和(3):重复上述步骤,直到达到预定的词汇表大小或者没有可以合并的符号对。
4. 句子片段
句子片段是另一种分词方法,它将句子分割成更小的片段,以提高模型的处理效率。
5. 分词过程
分词过程是将文本转换为模型可以理解的格式,这是模型训练和推理的前提。
6. 词汇索引
词汇索引是将词汇表中的单词映射到唯一的索引,这些索引将用于模型的输入。
第四部分:词嵌入
1. 标词嵌入
词嵌入是将词汇索引转换为高维空间中的向量,这些向量能够捕捉单词的语义信息。
2. 位置编码
位置编码是为词嵌入向量添加位置信息,使得模型能够理解单词在序列中的顺序。
3. 词汇索引和词嵌入向量
词汇索引和词嵌入向量的结合,为模型提供了丰富的语义和位置信息。
第五部分:位置编码方法
1. 原生位置编码
原生位置编码是Transformer模型中的一种位置编码方式,它通过正弦和余弦函数来实现。
2. 旋转位置编码
旋转位置编码是另一种位置编码方法,它通过旋转矩阵来实现。
3. Llama位置编码
Llama位置编码是一种新型的位置编码方式,它在某些任务中表现出更好的性能。
import torch
def precompute_freqs cis(dim,seqlen,theta =10000.0):
fregs =1.0/(theta **(torch.arange(0,dim,2)[:(dim // 2)].float()/ dim))
t=torch.arange(seqlen) #顺序位置,0~ seqlen - 1
freqs =torch,outer(t,fregs).float()
return freqs
embedding_dim=8
sequence_length=5
#标记嵌入,全为1
token_embedding=torch.ones((sequence_length, embedding_dim))
freqs =precompute_freqs_cis(embedding_dim, sequence_length)
#标准位置编码
pe = torch.zeros(sequence_length,embedding_dim)
pe[:,0::2]= torch.sin(fregs)
pe[:,1::2]= torch.cos(fregs)
#标记嵌入 + 位置嵌入
pe_out=token_embedding + pe
Print(pe_out)
#旋转位置编码
freqs_cis =torch.polar(torch.ones_like(freqs),fregs)
token_embedding cis
torch.view_as_complex(token_embedding.reshape(sequence_length, -1,2))
rope_out =torch.view_as_real(token_embedding cis * fregs cis) .flatten(1)
print(rope_out)
第六部分:自注意力机制
1. 原理
自注意力机制是Transformer模型的核心,它允许模型在处理序列时同时关注序列的不同部分。
2. 注意力分数的计算
注意力分数的计算是自注意力机制的关键,它决定了模型在不同位置的关注度。
3. 多头注意力机制
多头注意力机制通过多个注意力头来捕捉不同子空间的信息。
4. 分组查询注意力
分组查询注意力是一种优化的自注意力机制,它通过分组来减少计算量。
5. Llama 2源代码分析
通过分析Llama 2的源代码,我们可以更深入地理解自注意力机制的实现细节。
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
# 对输入的张量进行重复操作,以满足多头注意力机制中多次使用同一个键-值对的需要
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: 'ModelArgs'):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = args.init_get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
# 构造注意力查询(Q)、键(K)和值(V)所需要的线性变换算子
# 这堂查蘀角一个交换算子文持了多头的场景,因为每个头实际上计算方式是完全一样的,只是参数不同
self.wq = ColumnParallelLinear(
args.dim, args.n_heads * self.head_dim, bias=False,
gather_output=False, init_method=lambda x: x
)
self.wk = ColumnParallelLinear(
args.dim, self.n_kv_heads * self.head_dim, bias=False,
gather_output=False, init_method=lambda x: x
)
self.wv = ColumnParallelLinear(
args.dim, self.n_kv_heads * self.head_dim,
bias=False, gather_output=False, init_method=lambda x: x
)
# 构造对最终输出进行线性变换的算子
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim, bias=False,
input_is_parallel=True, init_method=lambda x: x
)
self.cache_k = torch.zeros(
args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim
).cuda()
self.cache_v = torch.zeros(
args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim,
).cuda()
def forward(
self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]
):
bsz, seq_len = x.shape[:2]
# 对输入序列进行线性变换,分别得到查询(Q)、键(K)和值(V)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
# 对查询和键应用旋转嵌入(Rotary Embedding)操作
# 旋转嵌入是一种在注意力机制中引入周期性信息的技术,有助于模型捕捉序列的顺序关系
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# 更新缓存中的键(K)和值(V),将当前位置的键和值存储在缓存中以供后续的注意力计算使用
self.cache_k = self.cache_k.to(xq.device)
self.cache_v = self.cache_v.to(xq.device)
self.cache_k[:, start_pos:start_pos + seq_len] = xk
self.cache_v[:, start_pos:start_pos + seq_len] = xv
# 从缓存中获取用于注意力计算的键(K)和值(V),包括当前位置之前的所有位置
keys = self.cache_k[:, :start_pos + seq_len]
values = self.cache_v[:, :start_pos + seq_len]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep)
values = repeat_kv(values, self.n_rep)
# 对查询、键和值进行维度转置,以便进行矩阵乘法操作
xq = xq.transpose(1, 2) # (bs, n_local_heads, seq_len, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, seq_len, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, seq_len, head_dim)
# 计算查询和键之间的相似度得分,通过矩阵乘法计算得到,同时除以头的维度的平方根来进行缩放,以控制相似度的范围
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
# 如果存在掩码(mask),则将其加到相似度得分上,以屏蔽无效位置的影响
scores = scores + mask # (bs, n_local_heads, seq_len, cache_len + seq_len)
# 对相似度得分进行softmax操作,将其转换为注意力权重,使得权重在每个位置的分布总和为1
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 根据注意力权重对值进行加权求和,得到最终的注意力输出
output = torch.matmul(scores, values) # (bs, n_local_heads, seq_len, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
# 对注意力输出进行线性变换,得到最终的注意力机制的输出
return self.wo(output)
第七部分:残差连接和层归一化
1. 预归一化
预归一化是在应用激活函数之前进行归一化,这有助于模型的训练。
2. RMSNorm
RMSNorm是一种归一化方法,它通过除以均方根来实现。
3. Llama 2源代码分析
通过分析Llama 2的源代码,我们可以了解残差连接和层归一化在实际模型中的应用。
import torch
import torch.nn as nn
from typing import Optional
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: 'ModelArgs'):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(dim=args.dim,
hidden_dim=4 * args.dim,
ffn_dim_multiplier=args.ffn_dim_multiplier)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None):
h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
Llama 中实现 RMSNorm 的源代码为
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# dim参数表示输入张量的维度,即要在哪个维度上计算均方根并进行归一化。
# weight是一个可学习的权重参数,用于缩放标准化后的输入。
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
# 计算输入张量的均方根,并将每个元素除以均方根值。
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# 调用norm方法对输入张量进行标准化处理,并将标准化后的结果与权重参数相乘,以进一步缩放和调整输出。
output = self.norm(x.float()).type_as(x)
return output * self.weight
第八部分:前馈网络
1. 激活函数
激活函数为前馈网络提供了非线性能力,使得模型能够学习复杂的模式。
2. 前馈网络隐藏层维度
前馈网络隐藏层的维度决定了模型的表达能力。
3. Llama 2源代码分析
通过分析Llama 2的源代码,我们可以了解前馈网络在Transformer模型中的实现。
import torch
import torch.nn as nn
from typing import Optional
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = (hidden_dim + multiple_of - 1) // multiple_of * multiple_of # Ensure hidden_dim is a multiple of multiple_of
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.silu = nn.SiLU() # Assuming you want to use the SiLU activation function
def forward(self, x):
return self.w2(self.silu(self.w1(x)) * self.w3(x))
第九部分:损失函数及掩码
损失函数是衡量模型预测与实际值差异的指标,对于模型的训练至关重要。
Transformer 模型的损失函数通常采用交叉熵损失来计算预测结果与真实标签之间的差异。
掩码技术在处理序列数据时用于忽略某些位置,这对于某些特定的任务非常有用。
在 Transfonmer 的前向计算时,会计算一个掩码矩阵。然后,在计算注意力时,使用此掩码来遮蔽掉无效位置。
第十部分:Pytorch的nn.transformer模块
Pytorch的nn.transformer模块提供了构建Transformer模型的工具,使得开发者能够更容易地实现和训练模型。
以下是纯解码器模型的参考代码
import torch.nn as nn
class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, emb_size, hidden_size, num_layers, num_heads, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_size)
self.decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(emb_size, num_heads, hidden_size, dropout),
num_layers=num_layers
)
self.norm = nn.LayerNorm(hidden_size)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, trg, memory, trg_mask=None, memory_mask=None):
# trg: [trg len, batch size]
# memory: [src len, batch size, hidden size]
# trg mask: [trg len, trg len]
# memory mask: [trg len, src len]
trg_emb = self.embedding(trg) # [trg len, batch size, emb size]
trg_emb = trg_emb.transpose(0, 1) # [batch size, trg len, emb size]
output = self.decoder(trg_emb, memory, tgt_mask=trg_mask, memory_mask=memory_mask) # [batch size, trg len, hidden size]
output = self.norm(output) # [batch size, trg len, hidden size]
output = self.fc(output) # [batch size, trg len, vocab size]
return output
# Initialize the model
model = TransformerDecoder(vocab_size=32000, emb_size=512, hidden_size=1024, num_layers=6, num_heads=4, dropout=0.1)
print(model)
Liama 2模型的参考代码
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGlu hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
max_batch_size: int = 32
max_seq_len: int = 2048
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self.norm(x.float()).type_as(x)
return output * self.weight
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
xq_out = torch.view_as_complex(xq * freqs_cis)
xk_out = torch.view_as_real(xk * freqs_cis)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
"""Multi-head attention module."""
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = 1 if not hasattr(args, 'get_model_parallel_world_size') else args.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.cache_k = torch.zeros(
args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim,
)
self.cache_v = torch.zeros(
args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim,
)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
self.cache_k = self.cache_k.to(xq.device)
self.cache_v = self.cache_v.to(xq.device)
self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
keys = self.cache_k[:bsz, :start_pos + seqlen]
values = self.cache_v[:bsz, :start_pos + seqlen]
keys = repeat_kv(keys, self.n_rep)
values = repeat_kv(values, self.n_rep)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(torch.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.layer_id = layer_id
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
h = self.attention_norm(x)
h = self.attention.forward(h, start_pos, freqs_cis, mask)
h = h + x
out = self.feed_forward.forward(self.ffn_norm(h))
out = out + h
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.vocab_size = params.vocab_size
self.params = params
self.token_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.layers = nn.ModuleList([TransformerBlock(layer_id, params) for layer_id in range(params.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(params.dim // 2, params.max_seq_len * 2)
def forward(self, tokens: torch.Tensor, start_pos:int):
with torch.inference_mode():
bsz, seqlen = tokens.shape
tokens = self.token_embeddings(tokens)
freqs_cis = self.freqs_cis.to(tokens.device)
mask = None if seqlen <= 1 else torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)
mask = torch.tril(torch.ones((seqlen, start_pos), device=tokens.device))
for layer in self.layers:
h = layer(tokens, start_pos, freqs_cis, mask)
tokens = self.norm(h)
output = self.output(tokens)
return output.float()
model_args: ModelArgs=ModelArgs()
medel_args.vocab_size =32000
model_args.n_layers =6
model_args.max_seq_len = 2048
model = Transformer(model_args)
print (model)
|