4030|6

47

帖子

1

TA的资源

一粒金砂(中级)

楼主
 

#【AI挑战营第一站】#基于pytorch的MNIST手写数字识别 [复制链接]

本帖最后由 空耳- 于 2024-4-25 22:40 编辑

 大致流程

1.准备数据

2.构建模型

3.训练模型

4.模型评估

5.模型转换与保存

1.开发环境

本文使用的环境是:pycharm + conda

GPU:MX350

pytorch = 1.8.0

python = 3.8

2.导入相关的库

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

3.构建数据集

def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)

4.神经网络模型定义

 def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)


def forward(self, x):
    x = torch.nn.functional.relu(self.fc1(x))
    x = torch.nn.functional.relu(self.fc2(x))
    x = torch.nn.functional.relu(self.fc3(x))
    x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
   		 return x

5.评估函数

def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28 * 28))
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

6.主函数

def main():
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()

    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(2):
        for (x, y) in train_data:
            net.zero_grad()
            output = net.forward(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    plt.show()

7.模型保存与转换

    # 保存整个模型
    torch.save(net, "mnist_model.pth")

    # 保存模型的状态字典
    torch.save(net.state_dict(), "mnist_model_state_dict.pth")

    # 创建一个示例输入张量,并调整形状以匹配模型期望的形状
    dummy_input = torch.randn(1, 28 * 28)  # 调整示例输入数据的形状为(batch_size, 28 * 28)

    # 导出模型为ONNX格式
    torch.onnx.export(net, dummy_input, "mnist_model.onnx", verbose=True)

8.查看模型结构

Netron支持在线和离线的操作,可以直接在网页上进行展示,
在线运行地址:https://netron.app/

9.完整代码

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torch.onnx

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)  # 输入层到隐藏层的全连接层
        self.fc2 = torch.nn.Linear(64, 64)  # 隐藏层到隐藏层的全连接层
        self.fc3 = torch.nn.Linear(64, 64)  # 隐藏层到隐藏层的全连接层
        self.fc4 = torch.nn.Linear(64, 10)  # 隐藏层到输出层的全连接层

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)  # 输出层使用log_softmax激活函数
        return x

# 获取数据加载器
def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)  # 下载MNIST数据集
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 返回数据加载器

# 评估模型准确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28 * 28))
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

# 主函数
def main():
    train_data = get_data_loader(is_train=True)  # 获取训练数据加载器
    test_data = get_data_loader(is_train=False)  # 获取测试数据加载器
    net = Net()  # 初始化神经网络模型

    print("initial accuracy:", evaluate(test_data, net))  # 打印初始准确率
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # 使用Adam优化器
    for epoch in range(2):  # 训练两个epoch
        for (x, y) in train_data:
            net.zero_grad()  # 梯度清零
            output = net.forward(x.view(-1, 28 * 28))  # 前向传播
            loss = torch.nn.functional.nll_loss(output, y)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印每个epoch的准确率

    for (n, (x, _)) in enumerate(test_data):
        if n > 3:  # 只展示前4个样本的预测结果
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))  # 对样本进行预测
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))  # 显示预测结果

    # 保存整个模型
    torch.save(net, "mnist_model.pth")

    # 保存模型的状态字典
    torch.save(net.state_dict(), "mnist_model_state_dict.pth")

    # 创建一个示例输入张量,并调整形状以匹配模型期望的形状
    dummy_input = torch.randn(1, 28 * 28)  # 调整示例输入数据的形状为(batch_size, 28 * 28)

    # 导出模型为ONNX格式
    torch.onnx.export(net, dummy_input, "mnist_model.onnx", verbose=True)

    plt.show()  # 展示图像

if __name__ == "__main__":
    main()

10.训练结果

11.模型文件

 

补充内容 (2024-5-8 19:06): 预测结果如下:initial accuracy: 0.0941 epoch 0 accuracy: 0.9543 epoch 1 accuracy: 0.966 epoch 2 accuracy: 0.9673 epoch 3 accuracy: 0.9722 epoch 4 accuracy: 0.9729 epoch 5 accuracy: 0.975 epoch 6 accuracy: 0.9724 epoch 7 accuracy: 0.973 epoch 8 accuracy: 0.9738 epoch 9 accuracy: 0.979 测试集准确率: 97.90 %

my_MNISt.py

3.36 KB, 下载次数: 1

mnist_model.onnx

232.04 KB, 下载次数: 1

mnist_model.pth

234.86 KB, 下载次数: 0

最新回复

  cpu也可以玩简单的。   详情 回复 发表于 2024-5-13 13:29
点赞 关注

回复
举报

47

帖子

1

TA的资源

一粒金砂(中级)

沙发
 

加油

补充内容 (2024-5-8 19:23): # 测试模型 def test(model, test_data): correct = 0 total = 0 with torch.no_grad(): for data in test_data: images, labels = data outputs = model(images.view(-1, 28*28)) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('测试集准确率: %7.2f %%' % (100 * correct / total))
 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

板凳
 

 

 
 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

4
 

 

 
 
 

回复

7608

帖子

18

TA的资源

五彩晶圆(高级)

5
 

1.8 对于日新月异的AI来说有点年头了。

点评

我的显卡太菜了,mx350    详情 回复 发表于 2024-5-11 19:52
 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

6
 
freebsder 发表于 2024-5-11 17:30 1.8 对于日新月异的AI来说有点年头了。

我的显卡太菜了,mx350

 

点评

  cpu也可以玩简单的。  详情 回复 发表于 2024-5-13 13:29
 
 
 

回复

7608

帖子

18

TA的资源

五彩晶圆(高级)

7
 
空耳- 发表于 2024-5-11 19:52 我的显卡太菜了,mx350  

  cpu也可以玩简单的。

 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/9 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表