4070|1

428

帖子

10

TA的资源

纯净的硅(初级)

楼主
 

#AI挑战营第一站#pytorch训练MNIST数据集 [复制链接]

提前3天看到消息,终于赶在deadline之前完成代码(好在很简单,找到旧硬盘里面以前的练习代码改改就能用)。

目标:

基于pytorch,使用mnist数据集训练识别数字0~9

 

实现思路:

首先是搭建一个神经网络模型,模型搭建说实话只能凑。这里介绍搭建二维模型常用几个模块吧:

1)nn.conv2d

执行2D卷积操作。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

其格式为:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

in_channels:输入的通道数,在mnist数据集中,只有灰度,因此in_channels选择1。

out_channels:输出的通道数,这里的输出可以作为下一层的输入使用。

kernel_size:卷积核的大小,一般我们会使用5x5、3x3这种左右两个数相同的卷积核,因此这种情况只需要写 kernel_size = 5这样的就行了。如果左右两个数不同,比如3x5的卷积核,那么写作kernel_size = (3, 5),注意需要写一个tuple,而不能写一个list。

stride:卷积核在图像窗口上每次平移的间隔,即所谓的步长,一般采用默认值1。

padding:指图像填充,后面的int型常数代表填充的多少(行数、列数),默认为0。需要注意的是这里的填充包括图像的上下左右,以padding=1为例,若原始图像大小为[32, 32],那么padding后的图像大小就变成了[34, 34]。

dilation:是否采用空洞卷积,默认为1(不采用)。

groups:决定了是否采用分组卷积。

bias:即是否要添加偏置参数作为可学习参数,默认为True。

padding_mode:即padding的模式,默认采用零填充。

2)nn.MaxPool2d

二维最大池化层。它用于在神经网络中执行最大池化操作,以减少特征图的空间尺寸并提取出主要特征。参考文档https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html

其格式为:

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

kernel_size:最大池化窗口大小,当最大池化窗口是方形的时候,只需要一个整数边长即可;最大池化窗口不是方形时,使用tuple输入高和宽。

stride:max pooling 的窗口移动的步长。默认值是 kernel_size

padding:输入的每一条边补充0的层数

dilation:一个控制窗口中元素步幅的参数

return_indices:如果等于 True,会返回输出最大值的序号,对于上采样操作会有帮助

ceil_mode:如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作

3)nn.AdaptiveMaxPool2d

二维自适应最大池化层。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool2d.html

格式为:

torch.nn.AdaptiveMaxPool2d(output_size, return_indices=False)

output_size: 输出信号的尺寸,可以使用整数边长或tuple输入高和宽。

return_indices: 如果设置为True,会返回输出的索引。

 

4)nn.AvgPool2d

二维平均池化层。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html

格式为:

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

 

5)nn.LPPool2d

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.LPPool2d.html

格式为:

torch.nn.LPPool2d(norm_type, kernel_size, stride=None, ceil_mode=False)

kernel_size:滑动窗口大小,可以为一个整数,也可以为一个元组

stride:步幅,默认等于kernel_size

 

6)nn.Dropout2d

以一定概率丢弃一部分数据,用于加速。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html

格式为:

torch.nn.Dropout2d(p=0.5, inplace=False)

p:对于input中各元素zero out的概率,如当p=1时,output为全0。

inplace:表示是否对tensor本身操作,若选择True,将会设置tensor为0。

经过丢弃层,输入输出的尺寸相同。

 

7)nn.ReLU

激活层。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html

格式为:

torch.nn.ReLU(inplace=False)

inplace:是否改变输入数据,如果设置为True,则会直接修改输入数据;如果设置为False,则不对输入数据做修改。

经过激活层,输入输出的尺寸相同。

 

8)nn.Linear

线性层(也称为全连接层或仿射层),用于构建神经网络中的线性变换。

参考文档:https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

格式为:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

in_features:输入的神经元个数

out_features:输出神经元个数

bias=True:是否包含偏置

 

9)nn.Flatten

拉平层,将连续的维度范围展平为张量。

参考文档:

https://pytorch.org/docs/stable/generated/torch.nn.Flatten.htm

格式为:

torch.nn.Flatten(start_dim=1, end_dim=-1)

 

数据源:

采用mnist数据源,直接从网上拉下,包含四个文件:

train-images-idx3-ubyte.gz

train-labels-idx1-ubyte.gz

t10k-images-idx3-ubyte.gz

t10k-labels-idx1-ubyte.gz

分别作为训练集和测试集。

 

训练过程:

经过10轮训练,可以看到结果看起来还是令人满意的。

 

 

 

 

 

 

查看本帖全部内容,请登录或者注册
点赞 关注

回复
举报

428

帖子

10

TA的资源

纯净的硅(初级)

沙发
 
527887255 发表于 2024-4-26 18:12 首先是搭建一个神经网络模型,模型搭建说实话只能凑。这里介绍搭建二维模型常用几个模块吧

哈哈哈,无非是卷积、拉平、池化、丢弃之类的……

 
 

回复
您需要登录后才可以回帖 登录 | 注册

查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条
立即报名 | 2025 瑞萨电子工业以太网技术日即将开启!
3月-4月 深圳、广州、北京、苏州、西安、上海 走进全国6城
2025瑞萨电子工业以太网技术巡回沙龙聚焦工业4.0核心需求,为工程师与企业决策者提供实时通信技术最佳解决方案。
预报从速,好礼等您拿~

查看 »

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

 
机器人开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表