883|1

8

帖子

0

TA的资源

一粒金砂(中级)

楼主
 

#AI挑战营第一站# MNIST手写数字识别 [复制链接]

本实验使用了由两层卷积层和一层全连接层组成的卷积神经网络实现了MNIST手写数字识别,得到的模型准确率为98.76%。

由于网络规模很小,本实验没有使用GPU,本实验使用的CPU为R5-7530u(六核十二线程,2GHz)

1、本实验构建的模型如下,完整代码见附件main.py:

class ConvolutionalNeuralWork(torch.nn.Module):
    def __init__(self, num_classes=10):
        super(ConvolutionalNeuralWork, self).__init__()
        self.conv_layer1 = torch.nn.Sequential(torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.conv_layer2 = torch.nn.Sequential(torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.linear_layer = torch.nn.Linear(1568, num_classes)

    def forward(self, x):
        x = self.conv_layer1(x)
        x = self.conv_layer2(x)
        x = x.reshape(x.size(0), -1)
        x = self.linear_layer(x)
        return x

2、训练代码如下,共训练了一百轮,完整代码见附件main.py:

wb = openpyxl.Workbook()
ws = wb.active
ws.cell(1, 1).value = "epoch"
ws.cell(1, 2).value = "train loss"
ws.cell(1, 3).value = "val accuracy"

num_epochs = 100
batch_size = 100
num_class = 10
learning_rate = 0.0001
#加载数据
train_dataset = torchvision.datasets.MNIST(root="./data0", train=True, transform=torchvision.transforms.ToTensor(), download=False)
test_dataset = torchvision.datasets.MNIST(root="./data0", train=False, transform=torchvision.transforms.ToTensor(), download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

model = ConvolutionalNeuralWork(num_class)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_accuracy = 0
for epoch in range(num_epochs):
    loss_sum = 0
	#训练
    for image, label in train_loader:
        outputs = model(image)
        loss = criterion(outputs, label)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print("epoch:", epoch)
    print("train loss:", loss_sum/train_dataset.data.shape[0])
    ws.cell(epoch+2, 1).value = epoch + 1
    ws.cell(epoch+2, 2).value = loss_sum/train_dataset.data.shape[0]
    total = 0
    correct = 0
	#测试
    for image, label in test_loader:
        outputs = model(image)
        s, predicted = torch.max(outputs, 1)
        total = label.size(0) + total
        for i in range(label.size(0)):
            if label[i] == predicted[i]:
                correct = correct + 1
    print("val accuracy:", correct / total)
    ws.cell(epoch+2, 3).value = correct / total
    if correct / total > best_accuracy:
        best_accuracy = correct / total
        print("better model")
        torch.save(model, "best.pth")
        ws.cell(epoch+2, 4).value = "better model"
    print('\n')
    wb.save("result.xlsx")

3、训练过程中训练集损失函数和测试集准确率如下所示,详细信息见附件result.xlsx

   

4、最终生成的模型文件*.pth,以及转换的ONNX模型见附件

best.onnx

81.8 KB, 下载次数: 1

best.pth

84.92 KB, 下载次数: 0

main.py

3.38 KB, 下载次数: 2

result.xlsx

22.51 KB, 下载次数: 1

此帖出自ARM技术论坛

最新回复

最终不是都需要弄到RV1106开发板上吗?     详情 回复 发表于 2024-4-12 22:34
点赞 关注
 

回复
举报

6450

帖子

9

TA的资源

版主

沙发
 

最终不是都需要弄到RV1106开发板上吗?  

此帖出自ARM技术论坛
 
个人签名

在爱好的道路上不断前进,在生活的迷雾中播撒光引

 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/9 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表