3605|1

40

帖子

0

TA的资源

一粒金砂(中级)

楼主
 

#AI挑战营第一站# MNIST手写数字识别模型训练及优化 [复制链接]

本帖最后由 luyism 于 2024-4-17 16:37 编辑

#AI挑战营第一站# MNIST手写数字识别模型训练及优化

 

[背景介绍]

在人工智能的浪潮中,边缘计算作为推动AI技术落地的重要力量,其在嵌入式系统中的应用日益广泛。在这个项目中,我们致力于开发一个轻量级的手写数字识别模型,以便在资源受限的嵌入式设备上运行。我们将使用经典的MNIST数据集,这是一个包含手写数字图像的大型数据集,每个图像都标有相应的数字标签。我们的目标是训练一个模型,能够准确地识别这些手写数字,并将其应用于嵌入式设备,如RV1106开发板。由于嵌入式设备的计算能力和内存资源有限,需要特别注意模型的体积和计算效率,以确保在嵌入式系统中的实时性和稳定性。

 

[环境设置和数据准备]

首先,我们需要准备训练环境,需要确保已经安装了PyTorch和其他必要的库,包括matplotlib和onnxruntime。在安装了Python和PyTorch库后,我们使用 `torchvision` 来下载MNIST数据集,并通过 `DataLoader` 进行加载。这个数据集包含了60,000个训练样本和10,000个测试样本,每个样本都是28x28像素的灰度图像,每张图片都有相应的标签。

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

def get_data_loader(is_train):
    # 数据加载函数
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)

 

[模型设计]

构建一个简单的前馈神经网络模型,每个隐藏层包含64个神经元,使用ReLU激活函数。输出层是一个包含10个神经元的softmax层,用于对10个数字进行分类。基于模型的轻量级和计算效率考虑,以便在嵌入式设备上运行。

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 神经网络层定义
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # 前向传播过程
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

 

[训练和优化]

使用MNIST数据集进行模型的训练和优化。训练过程中采用Adam优化器和负对数似然损失函数。模型在3个周期内进行训练,并在每个周期结束后评估其在测试集上的准确率。通过反复训练和优化,我们使模型在保持准确性的同时尽可能减小了参数量和计算复杂度,以适应嵌入式设备的限制。

def evaluate(test_data, net):
    # 模型评估函数
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28 * 28))
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

def main():
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()

    print("模型初始准确率:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(3):
        for (x, y) in train_data:
            net.zero_grad()
            output = net.forward(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("第", epoch+1, "轮准确率:", evaluate(test_data, net))

 

[测试和结果]

在测试集上评估模型的准确率,确保其具有良好的泛化能力。

 

我们可以看到,在没有开始训练时,模型初始的准确率为0.1左右,但第一轮后准确率就直接达到了0.9471,第三轮就达到了0.9706,这个准确度相当不错了。

(忽略第2、3行的警告信息,原因是的电脑没有安装显卡)

 

 

[模型导出与测试]

将训练好的模型保存为.pth文件,并转换为ONNX格式,以便在RV1106开发板上进行部署。通过这些步骤,我们成功地开发并优化了一个适用于嵌入式设备的轻量级数字识别模型。

def save_model(net, filename):
    # 保存模型为.pth文件
    torch.save(net.state_dict(), filename)

def convert_to_onnx(net, input_size, filename):
    # 转换模型到ONNX格式
    model = net.eval()  # 设置模型为评估模式
    x = torch.randn(input_size, requires_grad=True)  # 创建一个随机输入
    torch_out = model(x)

    # 导出模型
    torch.onnx.export(model,               # 运行模型
                      x,                   # 模型输入 (或一个元组对于多个输入)
                      filename,            # where to save the model (can be a file or file-like object)
                      export_params=True,  # store the trained parameter weights inside the model file
                      opset_version=10,    # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

 

同时我们也可以重新加载导出的模型,测试导出模型是否能够正常工作;

import torch
import torch.onnx
import onnxruntime as ort
import numpy as np
from pytorch_number import Net, get_data_loader, evaluate

def load_model(filename):
    # 加载.pth文件中的模型
    net = Net()
    net.load_state_dict(torch.load(filename))
    net.eval()
    return net

def test_model(model, test_data):
    # 模型测试函数
    accuracy = evaluate(test_data, model)
    print(f".pth文件模型在测试集上的准确率: {accuracy * 100:.2f}%")

def load_onnx_model(filename):
    # 使用onnxruntime加载onnx模型
    ort_session = ort.InferenceSession(filename)
    return ort_session

def test_onnx_model(ort_session, test_data):
    n_correct = 0
    n_total = 0
    for (x, y) in test_data:
        x = x.view(-1, 28 * 28).numpy().astype(np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: x}
        ort_outs = ort_session.run(None, ort_inputs)
        predicted = np.argmax(ort_outs[0], axis=1)
        n_correct += np.sum(predicted == y.numpy())
        n_total += y.size(0)
    accuracy = n_correct / n_total
    print(f".onnx文件模型在测试集上的准确率: {accuracy * 100:.2f}%")

if __name__ == "__main__":
    # 获取测试数据
    test_data = get_data_loader(is_train=False)

    # 加载.pth文件中的模型并测试
    loaded_model = load_model("mnist_model.pth")
    test_model(loaded_model, test_data)

    # 加载.onnx文件中的模型并测试
    onnx_model = load_onnx_model("mnist_model.onnx")
    test_onnx_model(onnx_model, test_data)

 

输出内容如下所示,与我们在训练完成后的模型一致。

 

[代码与模型分享]
 
最后附上原始代码,以及导出的模型。
 
pytorch_number.py (5.29 KB, 下载次数: 6)

model_test.py (1.55 KB, 下载次数: 2)

mnist_model.pth (233.45 KB, 下载次数: 0)

mnist_model.onnx (232.07 KB, 下载次数: 3)

 
 
 

 

 

 

 

 

最新回复

模型的体积和计算效率是个注意点   详情 回复 发表于 2024-4-17 22:09
点赞 关注

回复
举报

6587

帖子

0

TA的资源

五彩晶圆(高级)

沙发
 

模型的体积和计算效率是个注意点

 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
推荐帖子
短信幽默

某人一心埋头在电脑中,其母推门进来递过一杯水,他茫然地看看水杯,若有所思地问:妈妈,你是拷贝过来的,还是移动过来的? 

PCB工艺流程

PCB工艺流程

实用电子秤(C)(山东省电子设计大赛)

摘 要 本系统采用单片机AT89S52为控制核心,实现电子秤的基本控制功能。系统的硬件部分包括最小系统板,数据采集、人机交 ...

【DIY新作】LED点阵时钟.3种字体.4种动画.自动旋屏.GPS校时【附原理图】

半年来一直在DIY辉光管、荧光管的时钟,其实算来,我最早DIY的时钟是LED点阵的。 时隔3年,再次汇集这段时间以来的想法,重新DI ...

FPGA经典100问之<仿真 20问>.pdf

FPGA经典100问之<仿真 20问>.pdf

【国产FPGA高云GW1N-4系列开发板测评】FLASH操作

本帖最后由 怀揣少年梦 于 2021-12-28 14:12 编辑 一、目的 使用高云GW1N-4系列开发板来擦除、页写FLASH,FLASH使用的型号 ...

求推荐PGA280的替代芯片?

公司采购说PGA280价格飞涨还买不上,求大家推荐一款无需改电路的替代型号? 具体型号是:PGA280AIPW

ADI如何查看数据手册资料分享

ADI如何查看数据手册资料分享

【2024 DigiKey创意大赛】家居气象台-备料

围绕项目需求购置了以下必要的附料,在此和大家进行一一分享。 第1个,NUCLEO-F411RE开发板的usb接口没有与时俱进采用type C ...

关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表