
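#AI Challenge Camp, Stage 1# Training an MNIST handwritten digit recognition model on a PC with PyTorch

#AI Challenge Camp, Stage 1# Handwritten digit recognition

I'm an AI beginner, just following in the experts' footsteps and copying their approach.

First, import the required libraries: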

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

Second, prepare the dataset:

batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='./dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
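The two constants passed to `transforms.Normalize` are the commonly quoted mean (0.1307) and standard deviation (0.3081) of the MNIST training pixels, measured after `ToTensor()` has scaled 8-bit values into [0, 1]. A plain-Python sketch of the per-pixel arithmetic (no torchvision involved; `to_tensor_value` and `normalize_value` are hypothetical helper names):

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_tensor_value(pixel):
    """Mimic ToTensor() for one 8-bit grayscale pixel: map 0..255 to 0.0..1.0."""
    return pixel / 255.0

def normalize_value(x, mean=MNIST_MEAN, std=MNIST_STD):
    """Mimic Normalize((mean,), (std,)) for one value: (x - mean) / std."""
    return (x - mean) / std

# Black background, mid gray, and a fully white stroke
for pixel in (0, 128, 255):
    print(pixel, round(normalize_value(to_tensor_value(pixel)), 4))
```

After normalization a black background pixel becomes about -0.42 and a fully white stroke about 2.82, so the network sees roughly zero-centered inputs.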

Next, define the network structure:

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # (n,1,28,28) -> (n,10,24,24)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)  # (n,10,12,12) -> (n,20,8,8)
        self.pooling = torch.nn.MaxPool2d(2)                 # halves height and width
        self.fc = torch.nn.Linear(320, 10)                   # 20*4*4 = 320 features -> 10 classes

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))  # -> (n, 10, 12, 12)
        x = F.relu(self.pooling(self.conv2(x)))  # -> (n, 20, 4, 4)
        # flatten from (n, 20, 4, 4) to (n, 320)
        x = x.view(batch_size, -1)
        x = self.fc(x)  # raw logits; CrossEntropyLoss applies log-softmax itself

        return x
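The `320` passed to `torch.nn.Linear` is the number of features left after the two conv/pool stages. A plain-Python sketch of the standard output-size formula, assuming `padding=0` and `dilation=1` as in this network (`conv2d_out` and `maxpool2d_out` are hypothetical helper names):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool2d_out(size, kernel):
    """Output side length of MaxPool2d(kernel); its stride defaults to the kernel size."""
    return (size - kernel) // kernel + 1

side = 28                           # MNIST input: (n, 1, 28, 28)
side = conv2d_out(side, kernel=5)   # conv1 -> (n, 10, 24, 24)
side = maxpool2d_out(side, 2)       # pool  -> (n, 10, 12, 12)
side = conv2d_out(side, kernel=5)   # conv2 -> (n, 20, 8, 8)
side = maxpool2d_out(side, 2)       # pool  -> (n, 20, 4, 4)
flat_features = 20 * side * side    # view(batch_size, -1) -> (n, 320)
print(flat_features)
```

If the kernel sizes or channel counts are changed, this is the number that has to be updated in `self.fc`.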


model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
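What the loss and optimizer compute can be sketched without PyTorch. `CrossEntropyLoss` combines log-softmax and negative log-likelihood (and averages over the batch by default), and `optim.SGD` with `momentum` keeps a velocity buffer per parameter. The sketch below assumes the defaults used here (no dampening, no Nesterov, no weight decay); the function names are hypothetical:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target]).
    Subtracting the max before exp (log-sum-exp trick) avoids overflow."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sgd_momentum_step(params, grads, velocity, lr=0.01, momentum=0.5):
    """One SGD-with-momentum update: v = momentum * v + g, then p = p - lr * v.
    Starting from v = 0, the first step reduces to plain SGD."""
    for i, g in enumerate(grads):
        velocity[i] = momentum * velocity[i] + g
        params[i] -= lr * velocity[i]
    return params, velocity

# Ten equal logits (an untrained guess) give a loss of log(10), about 2.30
print(cross_entropy([0.0] * 10, target=3))
```

This is why the very first printed loss of an untrained 10-class model is typically close to 2.3.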

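Define the training and test functions:

def train(epoch):
    model.train()  # training mode (matters once layers such as dropout or batch norm are added)
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step

        # forward pass, loss, backward pass, parameter update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:  # print the average loss of every 300 batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0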
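With 60,000 training images and `batch_size = 64`, one epoch has ceil(60000 / 64) = 938 batches; the last one holds only 32 images, since `DataLoader` keeps the partial batch by default. The `batch_idx % 300 == 299` check therefore prints at batches 300, 600 and 900, and the loss of the final 38 batches is accumulated but never printed, because `running_loss` starts again at 0.0 when `train()` is called for the next epoch. A quick check of that arithmetic:

```python
import math

train_size = 60000   # MNIST training set size
batch_size = 64
log_every = 300

batches_per_epoch = math.ceil(train_size / batch_size)           # partial last batch included
last_batch_size = train_size - (batches_per_epoch - 1) * batch_size
log_points = [i + 1 for i in range(batches_per_epoch) if i % log_every == log_every - 1]
print(batches_per_epoch, last_batch_size, log_points)
```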


def test():
    model.eval()  # evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted digit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %.2f %%' % (100 * correct / total))
    return correct / total
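`test()` counts a prediction as correct when the index of the largest output for that sample (the maximum along `dim=1`) equals the label. The same bookkeeping in plain Python, with hypothetical helper names `argmax` and `accuracy`:

```python
def argmax(values):
    """Index of the largest value; this helper returns the first index on ties."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(argmax(row) == label for row, label in zip(batch_logits, labels))
    return correct / len(labels)
```

Note that formatting the result with `%d` (instead of `%.2f`) would truncate, for example, 98.9 % down to 98 %.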

Run the training:

if __name__ == '__main__':
    epoch_list = []
    acc_list = []

    for epoch in range(20):
        train(epoch)
        acc = test()
        epoch_list.append(epoch + 1)  # 1-based, matching the epoch numbers printed by train()
        acc_list.append(acc)

    plt.plot(epoch_list, acc_list)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.show()
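Since `acc_list` holds one test accuracy per epoch, the same data can also tell which epoch did best. A small sketch (the helper `best_epoch` is hypothetical):

```python
def best_epoch(epoch_list, acc_list):
    """Return (epoch, accuracy) of the highest test accuracy; the earliest epoch wins ties."""
    best_idx = max(range(len(acc_list)), key=lambda i: acc_list[i])
    return epoch_list[best_idx], acc_list[best_idx]
```

Inside the training loop, the same comparison could decide when to call `torch.save(model.state_dict(), ...)`, so that the saved file holds the best epoch's weights rather than the last epoch's.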

 

 

Export the model as a .pth file:

    # Save the model parameters (state_dict) in .pth format
    torch.save(model.state_dict(), 'mnist_101_model.pth')

Attachment: mnist_101_model.pth (35.31 KB)
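The file is small because the network is small. Counting parameters from the layer shapes in `Net` (the helper names below are hypothetical):

```python
# Layer shapes taken from Net: Conv2d(in, out, kernel) and Linear(in, out)
conv_layers = [(1, 10, 5), (10, 20, 5)]
linear_layers = [(320, 10)]

def conv_params(c_in, c_out, k):
    """Weights (c_out x c_in x k x k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out

def linear_params(n_in, n_out):
    """Weights (n_out x n_in) plus one bias per output feature."""
    return n_out * n_in + n_out

total_params = (sum(conv_params(*c) for c in conv_layers)
                + sum(linear_params(*lin) for lin in linear_layers))
bytes_float32 = total_params * 4  # float32 = 4 bytes per parameter
print(total_params, bytes_float32, round(bytes_float32 / 1024, 2))
```

This gives 8,490 parameters, or 33,960 bytes (about 33.16 KiB) of float32 weights. The remaining couple of KB in the 35.31 KB attachment is presumably serialization overhead from `torch.save`, which writes a zip archive with pickled metadata.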

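Export an ONNX file:

    # Export as an ONNX model, using a dummy input with the model's input shape
    # (batch 1, 1 channel, 28x28) created on the same device as the model
    model.eval()
    dummy_input = torch.randn(1, 1, 28, 28, device=device)
    torch.onnx.export(model, dummy_input, "mnist_101_model.onnx", verbose=False)

Attachment: mnist_101_model.onnx (34.27 KB)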

Testing with new data:

Test script:

import os
import glob

import torch
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
from PIL import Image

# Net comes from the training script, imported here as module mnist_gpu_version
from mnist_gpu_version import Net

# Set the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the model instance and load the trained parameters.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model = Net()
model.load_state_dict(torch.load('mnist_101_model.pth', map_location=device))
model.to(device)
model.eval()


# Image preprocessing: must match the transform used during training
def preprocess_image(image_path):
    image = Image.open(image_path).convert("L")  # single-channel grayscale
    image = resize(image, (28, 28))  # resize to the input size the model expects (28x28)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    image = transform(image).unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)
    return image


# Run the model on one image and return the predicted digit
def recognize_image(image_path):
    image = preprocess_image(image_path).to(device)
    with torch.no_grad():
        output = model(image)
    return output.argmax(dim=1).item()


def predict_image(image_path):
    # Recognize the image and compare with the digit encoded in the file name,
    # i.e. the part after the first underscore (a name like "digit_7.png" -> 7)
    predicted_label = recognize_image(image_path)
    file_stem = os.path.splitext(os.path.basename(image_path))[0]
    actual_number = int(file_stem.split('_')[1])
    compare_result = "Matched" if actual_number == predicted_label else "Mismatch"
    print(f"Actual number: {actual_number}   Predicted label: {predicted_label}   {compare_result}")


# Folder containing the images to recognize
folder_path = 'images'
# Use glob to find all PNG files in the folder
png_files_path = glob.glob(os.path.join(folder_path, "*.png"))

# Process every PNG file
for png_file in png_files_path:
    predict_image(png_file)
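One thing worth checking with self-made test images: MNIST digits are light strokes on a dark background, so a PNG of a dark digit on white paper would most likely need its colors inverted before normalization (with PIL, `ImageOps.invert(image)` does this for an "L" image). The framework-free sketch below shows the idea on a nested list of 8-bit gray values; the helper names and the mean-brightness threshold of 127 used to decide whether to invert are assumptions, not part of the original script:

```python
def looks_like_dark_on_light(pixels, threshold=127):
    """Heuristic: a bright average gray level suggests a dark digit on a light background."""
    flat = [p for row in pixels for p in row]
    return sum(flat) / len(flat) > threshold

def to_mnist_polarity(pixels, threshold=127):
    """Return a copy with MNIST polarity (light digit on dark background), inverting if needed."""
    if looks_like_dark_on_light(pixels, threshold):
        return [[255 - p for p in row] for row in pixels]
    return [row[:] for row in pixels]
```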

Test results:

Reply (posted 2024-4-23 07:30):

Even if it was done by following a template, completing MNIST handwritten digit recognition training on a PC is a job well done.
电子工程世界 (EEWorld). Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved.