3583|1

183

帖子

12

TA的资源

一粒金砂(高级)

楼主
 

#AI挑战营第一站#编程一小时,在PC上基于PyTorch完成MNIST手写数字识别模型训练 [复制链接]

本帖最后由 nemon 于 2024-4-24 19:39 编辑

本文以囫囵吞枣的手法,希望读者能在一小时内走马观花地了解一下PyTorch干活的基本流程

首先要安装环境,不必说那3.8以上的 python,也不必说那  python.exe -m pip install --upgrade pip,更不必说 matplotlib 和 numpy ,单就是Torch,就有无限的乐趣。(此处致敬周老师)

可以在 https://pytorch.org/get-started/locally/ 看到一个工具,自助选择你的安装平台,就能生成需要的安装说明:

可以看到,我为我1810后的破无显卡电脑选了个windows pip安装,而分院帽在嘀咕了仅仅0.01秒之后,就大声喊出了结果——史莱特林!

抱歉串词了,应该是这个:

pip3 install torch torchvision torchaudio

但是因为人生苦短,所以我用五道口技术学院:

pip3 install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple

安装完之后,就可以直接进 python了

首先,要加载库,准备数据。torch自带MNIST数据集下载,你就偷着乐吧:

# 第一部分
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 准备数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

然后呢,看看数据及里都有啥好东东:

# 第二部分
# 画一下数据
import torchvision
from matplotlib import pyplot
import matplotlib.pyplot as plt
import numpy as np

def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    # img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        pyplot.imshow(npimg, cmap="Greys")
        pyplot.show()
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

for images,labels in train_loader:
	break

img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

如果你没敲错空格,那么会看到这个:

然后关掉,开始定义网络和损失函数、优化器。简单解释一下,relu是激活函数,负责决定要谁不要谁;Conv2d、max_pool2d是2维卷积,负责把图变小;Linear其实就是全连接,上一层的每一个都和下一层连上,很像二分图;损失函数是用来评估效果的;SGD优化器是一种爬坡策略。代码如下:

#第三部分
# 定义模型
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc = nn.Linear(20 * 4 * 4, 10)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2, 2)
        x = x.view(-1, 20 * 4 * 4)
        x = self.fc(x)
        return x

model = ConvNet()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

训练过程,就是把数据喂给模型,然后算一下“损失”,也就是和目标的差距,再告诉模型把这个差距反向传递回各层,顺便打印一下:

#第四部分
# 训练模型
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

测试结果和训练过程有点像,但是一定不能做反向传递,否则就成造假了。然后取出预测最大的那个,比较一下是不是和正确答案一样:

#第五部分
# 测试模型
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %7.2f %%' % (100 * correct / total))

过程很简单,跑10圈试试:

#第六部分
# 训练10个周期
epochs = 10
for epoch in range(epochs):
    train(epoch)
    test()

我这次的效果是这样的:

[1,   300] loss: 0.559
[1,   600] loss: 0.193
[1,   900] loss: 0.140
Accuracy on test set:   96.34 %
[2,   300] loss: 0.115
[2,   600] loss: 0.098
[2,   900] loss: 0.094
Accuracy on test set:   97.78 %
[3,   300] loss: 0.078
[3,   600] loss: 0.080
[3,   900] loss: 0.072
Accuracy on test set:   98.00 %
[4,   300] loss: 0.067
[4,   600] loss: 0.068
[4,   900] loss: 0.064
Accuracy on test set:   97.93 %
[5,   300] loss: 0.059
[5,   600] loss: 0.061
[5,   900] loss: 0.055
Accuracy on test set:   98.54 %
[6,   300] loss: 0.051
[6,   600] loss: 0.053
[6,   900] loss: 0.052
Accuracy on test set:   98.58 %
[7,   300] loss: 0.049
[7,   600] loss: 0.045
[7,   900] loss: 0.048
Accuracy on test set:   98.75 %
[8,   300] loss: 0.048
[8,   600] loss: 0.041
[8,   900] loss: 0.044
Accuracy on test set:   98.68 %
[9,   300] loss: 0.044
[9,   600] loss: 0.038
[9,   900] loss: 0.041
Accuracy on test set:   98.63 %
[10,   300] loss: 0.038
[10,   600] loss: 0.040
[10,   900] loss: 0.036
Accuracy on test set:   98.78 %

可以看出,收敛的那叫一个快,要不是这方子被炼过上万遍了,还以为过拟合了呢。

然后保存一下劳动成果,先存成pth:

#第七部分
# 保存模型
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),#'epoch': epoch,
}, "mnist-0423a.pth")

pth文件:

mnist-0423a.pth (69.93 KB, 下载次数: 0)

保存成的文件可以用以下代码加载使用(注意用前要定义模型):

#第八部分
# 加载模型
checkpoint = torch.load('mnist-0423a.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
# 继续训练或测试
test()

然后还可以按本次活动要求,转换ONNX模型:

#第九部分
import torch.onnx
# 导出模型
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model, 
    dummy_input,
    "mnist-0423a.onnx", 
    verbose=True, 
    input_names=['input'],
    output_names=['output'])

ONNX文件:

mnist-0423a.onnx (43.11 KB, 下载次数: 0)

现在可以到 https://netron.app/ 上看图形化的模型结构了:

规定套路完成。亢龙有悔,打完收工……

到一小时了吗?

 

最新回复

图形化的模型结构很酷炫哈   详情 回复 发表于 2024-4-26 07:22
点赞 关注

回复
举报

6587

帖子

0

TA的资源

五彩晶圆(高级)

沙发
 

图形化的模型结构很酷炫哈

 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表