4182|10

452

帖子

0

TA的资源

版主

楼主
 

#AI挑战营第一站#pytorch训练MNIST数据集实现手写数字识别 [复制链接]

邀请:@chenzhufly   @skywalker_lee   @wsdymg   @bigbat   参与回复

本帖最后由 LitchiCheng 于 2024-4-18 22:28 编辑

 

下载MNIST数据集

  • # MNIST数据集,用于训练,一次抓60 size
  • self._train_loader = torch.utils.data.DataLoader(
  • torchvision.datasets.MNIST('./data/', train=True, download=True,
  • transform=torchvision.transforms.Compose([
  • torchvision.transforms.ToTensor(),
  • torchvision.transforms.Normalize(
  • (0.1307,), (0.3081,))
  • ])),
  • batch_size=60, shuffle=True)
  • # 用于测试,一次抓500 size
  • self._test_loader = torch.utils.data.DataLoader(
  • torchvision.datasets.MNIST('./data/', train=False, download=True,
  • transform=torchvision.transforms.Compose([
  • torchvision.transforms.ToTensor(),
  • torchvision.transforms.Normalize(
  • (0.1307,), (0.3081,))
  • ])),
  • batch_size=500, shuffle=True)

编辑网络

  • # 连接序列
  • self._conv1_layer = nn.Sequential(
  • # 卷积
  • nn.Conv2d(1,15,5),
  • # 激活函数
  • nn.ReLU(),
  • # 最大池化,减少特征量,选特征最大的数,是一种下采样
  • nn.MaxPool2d(kernel_size=2, stride=2),
  • )
  • self._conv2_layer = nn.Sequential(
  • nn.Conv2d(15,30,5),
  • nn.ReLU(),
  • nn.MaxPool2d(kernel_size=2, stride=2),
  • )
  • self._full_layer = nn.Sequential(
  • # 卷积层都是四维张量,展平为二维张量给连接层用
  • nn.Flatten(),
  • nn.Linear(in_features=480, out_features=60),
  • nn.ReLU(),
  • nn.Linear(in_features=60, out_features=10),
  • )

判断是否可以是否GPU训练

  • if torch.cuda.is_available():
  • print("Use CUDA training!")
  • self._device = torch.device("cuda")
  • else:
  • print("Use CPU training!")
  • self._device = torch.device("cpu")

训练

  • def train(self):
  • loss_d = []
  • for epoch in range(1, self._epochs + 1):
  • self._cnn.train(mode=True)
  • for idx, (train_img, train_label) in enumerate(self._train_loader):
  • # 复制到device中
  • train_img = train_img.to(self._device)
  • train_label = train_label.to(self._device)
  • outputs = self._cnn(train_img)
  • # 清除梯度
  • self._optim.zero_grad()
  • loss = self._loss_func(outputs, train_label)
  • # 反向传播
  • loss.backward()
  • # 更新权重
  • self._optim.step()
  • # print('Train epoch {}: loss: {:.6f}'.format(epoch,loss.item()))
  • loss_d.append(loss.item())
  • plt.plot(range(0,len(loss_d)),loss_d)
  • plt.show()

Train的loss分布

 

Test的loss以及准确率

 

识别的结果

 

保存的pth和onnx模型

  • def savePthModel(self, pth_name:str):
  • torch.save(self._cnn.state_dict(), pth_name)
  • def saveOnnxModel(self, onnx_name:str):
  • input = torch.randn(1,1,28,28)
  • torch.onnx.export(self._cnn, input, onnx_name, verbose=True)

 

完整代码

  • import torch
  • import torch.nn as nn
  • import torchvision.datasets
  • import matplotlib.pyplot as plt
  • import numpy as np
  • class CNN(nn.Module):
  • def __init__(self):
  • super(CNN, self).__init__()
  • # 连接序列
  • self._conv1_layer = nn.Sequential(
  • # 卷积
  • nn.Conv2d(1,15,5),
  • # 激活函数
  • nn.ReLU(),
  • # 最大池化,减少特征量,选特征最大的数,是一种下采样
  • nn.MaxPool2d(kernel_size=2, stride=2),
  • )
  • self._conv2_layer = nn.Sequential(
  • nn.Conv2d(15,30,5),
  • nn.ReLU(),
  • nn.MaxPool2d(kernel_size=2, stride=2),
  • )
  • self._full_layer = nn.Sequential(
  • # 卷积层都是四维张量,展平为二维张量给连接层用
  • nn.Flatten(),
  • nn.Linear(in_features=480, out_features=60),
  • nn.ReLU(),
  • nn.Linear(in_features=60, out_features=10),
  • )
  • def forward(self, input):
  • # 层层连接,两个卷积层,最后全连接层
  • output = self._conv1_layer(input)
  • output = self._conv2_layer(output)
  • output = self._full_layer(output)
  • return output
  • class Test:
  • def __init__(self):
  • # MNIST数据集,用于训练,一次抓60 size
  • self._train_loader = torch.utils.data.DataLoader(
  • torchvision.datasets.MNIST('./data/', train=True, download=True,
  • transform=torchvision.transforms.Compose([
  • torchvision.transforms.ToTensor(),
  • torchvision.transforms.Normalize(
  • (0.1307,), (0.3081,))
  • ])),
  • batch_size=60, shuffle=True)
  • # 用于测试,一次抓500 size
  • self._test_loader = torch.utils.data.DataLoader(
  • torchvision.datasets.MNIST('./data/', train=False, download=True,
  • transform=torchvision.transforms.Compose([
  • torchvision.transforms.ToTensor(),
  • torchvision.transforms.Normalize(
  • (0.1307,), (0.3081,))
  • ])),
  • batch_size=500, shuffle=True)
  • # 训练次数
  • self._epochs = 3
  • self._cnn = CNN()
  • # 交叉熵损失函数,刻画的是两个概率分布的距离,交叉熵越小,概率分布越接近
  • self._loss_func = nn.CrossEntropyLoss()
  • # 优化器
  • self._optim = torch.optim.Adam(self._cnn.parameters(), lr=0.01)
  • if torch.cuda.is_available():
  • print("Use CUDA training!")
  • self._device = torch.device("cuda")
  • else:
  • print("Use CPU training!")
  • self._device = torch.device("cpu")
  • def train(self):
  • loss_d = []
  • for epoch in range(1, self._epochs + 1):
  • self._cnn.train(mode=True)
  • for idx, (train_img, train_label) in enumerate(self._train_loader):
  • # 复制到device中
  • train_img = train_img.to(self._device)
  • train_label = train_label.to(self._device)
  • outputs = self._cnn(train_img)
  • # 清除梯度
  • self._optim.zero_grad()
  • loss = self._loss_func(outputs, train_label)
  • # 反向传播
  • loss.backward()
  • # 更新权重
  • self._optim.step()
  • # print('Train epoch {}: loss: {:.6f}'.format(epoch,loss.item()))
  • loss_d.append(loss.item())
  • plt.plot(range(0,len(loss_d)),loss_d)
  • plt.show()
  • def test(self):
  • correct_num = 0
  • total_num = 0
  • loss_d = []
  • self._cnn.train(mode=False)
  • with torch.no_grad():
  • for idx, (test_img, test_label) in enumerate(self._test_loader):
  • test_img = test_img.to(self._device)
  • test_label = test_label.to(self._device)
  • total_num += test_label.size(0)
  • outputs = self._cnn(test_img)
  • loss = self._loss_func(outputs, test_label)
  • loss_d.append(loss.item())
  • predictions = torch.argmax(outputs, dim=1)
  • correct_num += torch.sum(predictions == test_label)
  • acc_num = ((correct_num.item()/total_num)*100)
  • title_str ="Accuracy:"+str(acc_num)+"%"
  • plt.title(title_str)
  • plt.plot(range(0,len(loss_d)),loss_d)
  • plt.show()
  • def plotTestResult(self):
  • iteration = enumerate(self._test_loader)
  • idx, (test_img, test_label) = next(iteration)
  • with torch.no_grad():
  • outputs = self._cnn(test_img)
  • fig = plt.figure()
  • for i in range(4 * 2):
  • plt.subplot(4, 2, i + 1)
  • plt.tight_layout()
  • plt.imshow(test_img[0], cmap='gray', interpolation='none')
  • plt.title('real: {}, predict: {}'.format(
  • test_label, outputs.data.max(1, keepdim=True)[1].item()
  • ))
  • plt.xticks([])
  • plt.yticks([])
  • plt.show()
  • def savePthModel(self, pth_name:str):
  • torch.save(self._cnn.state_dict(), pth_name)
  • def saveOnnxModel(self, onnx_name:str):
  • input = torch.randn(1,1,28,28)
  • torch.onnx.export(self._cnn, input, onnx_name, verbose=True)
  • if __name__ == "__main__":
  • mt = Test()
  • mt.train()
  • mt.test()
  • mt.plotTestResult()
  • mt.savePthModel("model.pth")
  • mt.saveOnnxModel("model.onnx")

视频讲解


 

查看本帖全部内容,请登录或者注册

最新回复

好好学习,天天向上,加油每一个人,加油自己,加油!!!   详情 回复 发表于 2024-10-29 21:11
点赞 关注

回复
举报

2943

帖子

4

TA的资源

五彩晶圆(中级)

沙发
 

我对python实在是没有兴趣,所以就不搞pytorch了,不过我的帖子AI到底在搞个“毛儿”希望兄弟能够积极参与。

点评

666,搞个"毛儿"  详情 回复 发表于 2024-4-19 15:09
 
 

回复

452

帖子

0

TA的资源

版主

板凳
 
bigbat 发表于 2024-4-19 11:21 我对python实在是没有兴趣,所以就不搞pytorch了,不过我的帖子AI到底在搞个“毛儿”希望兄弟能 ...

666,搞个"毛儿"

 
 
 

回复

6909

帖子

9

TA的资源

版主

4
 

这要是布局到MCU里,如何生成C文件  

点评

C应该有别的推理框架,pytorch肯定上不了  详情 回复 发表于 2024-4-21 10:29
 
个人签名

在爱好的道路上不断前进,在生活的迷雾中播撒光引

 
 

回复

452

帖子

0

TA的资源

版主

5
 
秦天qintian0303 发表于 2024-4-19 23:33 这要是布局到MCU里,如何生成C文件  

C应该有别的推理框架,pytorch肯定上不了

 
 
 

回复

701

帖子

5

TA的资源

纯净的硅(高级)

6
 

感谢楼主分享的技术内容信息,非常详实,实用价值非常大,值得学习

点评

感谢支持,一起进步啦  详情 回复 发表于 2024-4-21 21:54
 
 
 

回复

452

帖子

0

TA的资源

版主

7
 
chejm 发表于 2024-4-21 21:25 感谢楼主分享的技术内容信息,非常详实,实用价值非常大,值得学习

感谢支持,一起进步啦

 
 
 

回复

59

帖子

0

TA的资源

一粒金砂(中级)

8
 

心目中的最佳贡献给楼主...帮助非常大

点评

哈哈哈,感谢  详情 回复 发表于 2024-4-29 22:10
 
 
 

回复

452

帖子

0

TA的资源

版主

9
 
crimsonsnow 发表于 2024-4-25 11:12 心目中的最佳贡献给楼主...帮助非常大

哈哈哈,感谢

 
 
 

回复

417

帖子

0

TA的资源

纯净的硅(中级)

10
 

好好学习,天天向上,加油每一个人,加油自己,加油!!!

点评

加油  详情 回复 发表于 2024-10-29 21:27
 
 
 

回复

452

帖子

0

TA的资源

版主

11
 
通途科技 发表于 2024-10-29 21:11 好好学习,天天向上,加油每一个人,加油自己,加油!!!

加油

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条
有奖直播:当AI遇见仿真,会有什么样的电子行业革新之路?
首场直播:Simcenter AI 赋能电子行业研发创新
直播时间:04月15日14:00-14:50

查看 »

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

 
机器人开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网 13

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表