4453|2

7244

帖子

2

TA的资源

版主

楼主
 

AI计算中显存占用分析 [复制链接]

了解Transformer架构的AI大模型显存占用是非常重要的,特别是在训练和推理过程中。以下是详细解释和分析这些组成部分及其影响的专业描述:

 

1 显存占用


1.1 模型本身参数
模型的参数包括所有的权重和偏置项,这些参数需要存储在显存中,以便在训练和推理过程中进行计算。

占用字节:每个FP32参数占用4个字节,每个FP16参数占用2个字节。
计算:模型参数数量(例如,BERT-base模型大约有110M参数)。如果使用FP32表示,则总显存占用为 110M * 4 bytes。


1.2 模型的梯度动态值
在训练过程中,每个模型参数都有对应的梯度值,这些梯度用于更新模型参数。梯度存储同样需要显存。

占用字节:梯度和模型参数类型相同,所以FP32梯度占用4个字节,FP16梯度占用2个字节。
计算:梯度存储显存占用与模型参数相同,例如,如果模型参数使用FP32,则梯度显存占用为 参数数量 * 4 bytes。


1.3 优化器参数
优化器(如Adam)在训练过程中需要存储额外的参数,如一阶动量和二阶动量。这些参数也需要显存来存储。

Adam优化器:存储m和v两个参数,即需要2倍的模型参数量。
占用字节:每个FP32参数占用4个字节,每个FP16参数占用2个字节。
计算:例如,使用Adam优化器和FP32表示,则优化器参数显存占用为 2 * 参数数量 * 4 bytes。


1.4 模型的中间计算结果
在前向传播和反向传播过程中,需要存储每一层的中间计算结果,这些结果用于反向传播的求导。这些中间结果的显存占用与批量大小(batch size)、序列长度(sequence length)和每层的输出维度(hidden size)有关。

前向传播:每一层的输入x和输出都需要存储。
反向传播:中间结果的计算图不会被释放,以便计算梯度。
占用字节:这部分的显存占用难以精确计算,但可以通过调整batch size和sequence length来估算显存差值。
计算方法:常用的方法是实验性地调整batch size和sequence length,观察显存变化来估算中间结果的显存占用。


1.5 KV Cache
在推理过程中,尤其是在自回归模型(如GPT)中,需要缓存先前计算的键和值(Key和Value)以加速计算。这些缓存需要显存来存储。

占用字节:这部分的显存占用与输入的序列长度、批量大小和注意力头数有关。
计算方法:具体计算公式取决于模型的架构和缓存策略。
不同的参数类型所占的字节对比表

          类型                             所占字节          
FP32 4
FP16 2
INT8 1



2 具体示例


假设我们有一个Transformer模型,其架构和超参数如下:

层数(layers):12
隐藏层大小(hidden_size):768
注意力头数(num_heads):12
词汇表大小(vocab_size):30522
最大序列长度(sequence_length):512
批量大小(batch_size):1
数据类型:FP32(每个参数4字节)

 

为了具体计算一个具有上述参数的Transformer模型在推理时的显存占用,我们需要考虑以下几个部分:

  1. 模型本身的参数
  2. 输入和输出激活值
  3. 中间计算结果
  4. KV Cache


2.1 模型本身的参数
嵌入层
词嵌入矩阵:vocab_size * hidden_size
[30522 \times 768 = 23440896 \text{ 个参数}]
位置嵌入矩阵:sequence_length * hidden_size
[512 \times 768 = 393216 \text{ 个参数}]
嵌入层总参数:
[23440896 + 393216 = 23834112 \text{ 个参数}]

Transformer 层
每层的主要参数包括:

注意力层的 Q, K, V 权重和偏置:
[3 \times (hidden_size \times hidden_size) = 3 \times (768 \times 768) = 1769472 \text{ 个参数}]
输出权重和偏置:
[hidden_size \times hidden_size = 768 \times 768 = 589824 \text{ 个参数}]
前馈网络(两层):
[2 \times (hidden_size \times 4 \times hidden_size) = 2 \times (768 \times 4 \times 768) = 4718592 \text{ 个参数}]
每层总参数:
[1769472 + 589824 + 4718592 = 7077888 \text{ 个参数}]

12层总参数:
[12 \times 7077888 = 84934656 \text{ 个参数}]

总参数数量
模型总参数数量:
[23834112 + 84934656 = 108768768 \text{ 个参数}]

每个FP32参数占用4个字节:
[108768768 \times 4 = 435075072 \text{ 字节} = 435.08 \text{ MB}]

 

2.2 输入和输出激活值
假设模型在推理时的输入和输出激活值为 batch_size * sequence_length * hidden_size,对于每个层的激活值也相同。

每层激活值:
[batch_size \times sequence_length \times hidden_size = 1 \times 512 \times 768 = 393216 \text{ 个元素}]

每个FP32激活值占用4个字节:
[393216 \times 4 = 1572864 \text{ 字节} = 1.57 \text{ MB}]

 

2.3 中间计算结果
由于反向传播不需要考虑推理时的显存占用,我们可以忽略这部分。

 

2.4 KV Cache
在推理过程中,需要缓存每一层的键和值(Key和Value):

每层的KV Cache占用:
[2 \times batch_size \times sequence_length \times hidden_size = 2 \times 1 \times 512 \times 768 = 786432 \text{ 个元素}]

每个FP32值占用4个字节:
[786432 \times 4 = 3145728 \text{ 字节} = 3.14 \text{ MB}]

12层的KV Cache总占用:
[12 \times 3.14 \text{ MB} = 37.68 \text{ MB}]

 

2.5 总显存占用
[\text{模型参数} + \text{输入和输出激活值} + \text{KV Cache}]

显存占用计算:

模型参数:435.08 MB
激活值:1.57 MB(每层)× 12层 = 18.84 MB
KV Cache:37.68 MB
[\text{总显存占用} = 435.08 \text{ MB} + 18.84 \text{ MB} + 37.68 \text{ MB} = 491.60 \text{ MB}]

 

在推理过程中,一个具有上述配置的Transformer模型大约需要491.60 MB的显存。这一估算没有包括额外的显存开销,例如模型加载时的一些临时数据结构和框架本身的开销。实际使用中,可能还需要一些额外的显存来处理这些开销。
 

最新回复

Transformer模型大约需要491.60 MB的显存,看来显存占的不算多   详情 回复 发表于 2024-7-3 07:25
点赞 关注

回复
举报

6822

帖子

0

TA的资源

五彩晶圆(高级)

沙发
 

Transformer模型大约需要491.60 MB的显存,看来显存占的不算多

点评

带图像的模型就占很大显存,12G不够用。。。  详情 回复 发表于 2024-7-3 09:11
 
 

回复

7244

帖子

2

TA的资源

版主

板凳
 
Jacktang 发表于 2024-7-3 07:25 Transformer模型大约需要491.60 MB的显存,看来显存占的不算多

带图像的模型就占很大显存,12G不够用。。。

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/7 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表