463|0

70

帖子

1

TA的资源

一粒金砂(高级)

楼主
 

#AI挑战营第一站# PyTorch实现MNIST手写数字识别 [复制链接]

本帖最后由 aramy 于 2024-4-15 10:19 编辑

手写识别,基本是机器学习的hello world了。几乎每个机器学习入门编程都是这个手写识别。

1、 环境安装:这次使用的是pytorch框架,首先去pytorch官网按指南在PC上安装好pytorch环境。这里需要用到的库如下,需要预先引入

import torch
import torchvision
from torch.utils.data import DataLoader

2、 准备数据:手写识别的数据有现成的库了,不需要再去手工收集了。可以显示一下数据图片的内容。

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

 

3、构建一个网络。使用框架这点还是不错的,可以自由地选择框架提供的函数,构建自己想要的网络。这里使用两个2d卷积层,然后是两个全连接(或线性)层。作为激活函数,我们将选择整流线性单元(简称ReLUs),作为正则化的手段,我们将使用两个dropout层。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

4、模型训练。这里设置了一些超参。最后两行能够保证重复试验,产生相同的随机数。

n_epochs = 30
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01  # 设置超参数学习率为0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)

添加一个测试方法,用来验证每轮训练的结果。

def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

最后便是训练了。

for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

 

 

  

 

经过多轮的训练,就已经能够达到测试集97%的准确率(貌似也很难提高了)!使用PyTorch框架,使得我们很容易建立起神经网络,并且可以轻易地增加神经网络层数。

5 导出onnx文件。第一次接触onnx文件。B站上学习了一下,大概了解到ONNX是一个训练模型到终端应用的中间层,能够通明模型到终端应用的过程。值得好好学习一番。导出ONNX文件倒也挺容易的。

torch.onnx.export(network, dummy_input, onnx_path,
                  opset_version=11,  # ONNX 算子集版本
                  input_names=['input'], output_names=['output']  # 输入和输出节点名称
                  )  # 设置动态轴,允许在推理时调整batch_size

对导出的模型进行检查:

import onnx
# 读取 ONNX 模型
onnx_model = onnx.load('best_model.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)

print('无报错,onnx模型载入成功')
print(onnx.helper.printable_graph(onnx_model.graph))
无报错,onnx模型载入成功
graph torch-jit-export (
  %input[FLOAT, 1x1x28x28]
) initializers (
  %conv1.weight[FLOAT, 10x1x5x5]
  %conv1.bias[FLOAT, 10]
  %conv2.weight[FLOAT, 20x10x5x5]
  %conv2.bias[FLOAT, 20]
  %fc1.weight[FLOAT, 50x320]
  %fc1.bias[FLOAT, 50]
  %fc2.weight[FLOAT, 10x50]
  %fc2.bias[FLOAT, 10]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%input, %conv1.weight, %conv1.bias)
  %10 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%9)
  %11 = Relu(%10)
  %12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%11, %conv2.weight, %conv2.bias)
  %13 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%12)
  %14 = Relu(%13)
  %15 = Constant[value = <Tensor>]()
  %16 = Reshape(%14, %15)
  %17 = Gemm[alpha = 1, beta = 1, transB = 1](%16, %fc1.weight, %fc1.bias)
  %18 = Relu(%17)
  %19 = Gemm[alpha = 1, beta = 1, transB = 1](%18, %fc2.weight, %fc2.bias)
  %output = LogSoftmax[axis = 1](%19)
  return %output
}

Process finished with exit code 0

还可以去https://netron.app/网站上图形化看一下改模型结构.

 

mnist_train.py

4.34 KB, 下载次数: 0

训练

nn_net.py

843 Bytes, 下载次数: 0

神经网络

mnist_model.pth

87.98 KB, 下载次数: 0

check_onnx.py

380 Bytes, 下载次数: 0

onnx文件检查

best_model.onnx

86.5 KB, 下载次数: 2

2024041510191440.zip

165.21 KB, 下载次数: 0

打包上传

此帖出自编程基础论坛
点赞 关注
 

回复
举报
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表