A PyTorch Neural Network for Recognizing the MNIST Dataset

This post was last edited by wangyoujie on 2024-4-19 10:50.

1. First, install the environment dependencies:

matplotlib>=3.2.2

numpy>=1.19.5

pandas>=1.1.5

Pillow>=9.0.0

torch>=1.10.1

torchvision>=0.11.2

tqdm>=4.62.3
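
If you are starting from a clean environment, these can typically be installed with pip, e.g. pip install matplotlib numpy pandas Pillow torch torchvision tqdm (the exact versions you get will depend on your Python and CUDA setup).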

2. Design the network model

import torch


# Network structure: two conv -> pool -> ReLU stages followed by a fully connected classifier
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28 -> 10x24x24
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)  # 10x12x12 -> 20x8x8
        self.pooling = torch.nn.MaxPool2d(2)                  # halves the spatial size
        self.fc = torch.nn.Linear(320, 10)                    # 20*4*4 = 320 features -> 10 classes
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        batch_size = x.size(0)

        # First conv -> pool -> ReLU stage
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.relu(x)

        # Second conv -> pool -> ReLU stage
        x = self.conv2(x)
        x = self.pooling(x)
        x = self.relu(x)

        # Flatten to (batch_size, 320) and classify
        x = x.view(batch_size, -1)
        x = self.fc(x)

        return x
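
As a quick sanity check (not part of the original post), you can push a dummy batch through the network and confirm the output shape is (batch, 10). The from model import CNN line assumes the class above is saved as model.py, matching the import used by the training code below.

import torch
from model import CNN

model = CNN()
dummy = torch.randn(4, 1, 28, 28)   # a fake batch of four 1x28x28 grayscale images
out = model(dummy)
print(out.shape)                     # expected: torch.Size([4, 10])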

3. Add a custom dataset

import torch
import os
from PIL import Image
from torch.utils.data import Dataset


class MyMnistDataset(Dataset):
    def __init__(self, root, transform):

        self.myMnistPath = root
        self.imagesData = []
        self.labelsData = []
        self.labelsDict = {}
        self.trans = transform

        self.loadLabelsData()
        self.loadImageData()

    # Read the labels txt file and build a {filename: label} dictionary
    def loadLabelsData(self):
        labelsPath = os.path.join(self.myMnistPath, "labels", "labels.txt")
        with open(labelsPath) as f:
            for line in f.readlines():
                name = line.split(' ')[0]
                label = line.split(' ')[1]
                self.labelsDict[name] = int(label)

    # Read the handwritten image files and pair each image with its label
    def loadImageData(self):
        imagesFolderPath = os.path.join(self.myMnistPath, 'images')
        imageFiles = os.listdir(imagesFolderPath)

        for imageName in imageFiles:
            imagePath = os.path.join(imagesFolderPath, imageName)
            image = Image.open(imagePath)
            grayImage = image.convert("L")

            imageTensor = self.trans(grayImage)
            self.imagesData.append(imageTensor)

            self.labelsData.append(self.labelsDict[imageName])

        # CrossEntropyLoss expects integer class indices, so store the labels as long
        self.labelsData = torch.tensor(self.labelsData, dtype=torch.long)

    # __getitem__ lets a DataLoader fetch a single (image, label) pair by index
    def __getitem__(self, index):
        return self.imagesData[index], self.labelsData[index]

    # __len__ reports the dataset size
    def __len__(self):
        return len(self.labelsData)
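
This dataset expects a folder layout of root/images/ (the image files) plus root/labels/labels.txt, where each line holds an image filename and its digit label separated by a space. A minimal usage sketch follows; the ./myMnist path is a placeholder, and saving the class as dataset.py is an assumption, not something stated in the original post.

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from dataset import MyMnistDataset  # assumes the class above is saved as dataset.py

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Point root at your own images/ + labels/ folders
dataset = MyMnistDataset(root="./myMnist", transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=False)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # e.g. torch.Size([8, 1, 28, 28]) torch.Size([8])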

 

4. Training code

import os

import torch
import argparse
import pandas as pd
import numpy as np
from torch.optim import lr_scheduler
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import CNN


def parse_opt():
    parser = argparse.ArgumentParser(description='CNN-MNIST')
    parser.add_argument('--epochs', type=int, default=30, help='input total epoch')
    parser.add_argument('--batch_size', type=int, default=64, help='dataloader batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.00001, help='optimizer weight_decay')

    args = parser.parse_args()
    return args


def train(epoch):
    total_loss = 0
    total_correct = 0
    total_data = 0
    global iteration

    model.train()
    train_bar = tqdm(train_loader)
    for data in train_bar:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        _, predicted = torch.max(outputs.data, dim=1)
        total_correct += torch.eq(predicted, labels).sum().item()
        # Compute the loss
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        # Backward pass
        loss.backward()
        # Update the weights
        optimizer.step()

        total_data += labels.size(0)
        iteration = iteration + 1

        train_bar.desc = "train epoch[{}/{}] loss:{:.3f} iteration:{}".format(epoch + 1,
                                                                               args.epochs,
                                                                               loss,
                                                                               iteration)
    # Update the learning rate
    scheduler.step()
    print(optimizer.state_dict()['param_groups'][0]['lr'])

    loss = total_loss / len(train_loader)
    acc = 100 * total_correct / total_data
    train_loss.append(loss)
    train_acc.append(acc)

    print('accuracy on train set:%d %%' % acc)


# Validation function
def validate(epoch):
    total_loss = 0
    total_correct = 0
    total_data = 0

    model.eval()  # evaluation mode, the counterpart of model.train() above
    with torch.no_grad():
        test_bar = tqdm(test_loader)
        for data in test_bar:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total_correct += torch.eq(predicted, labels).sum().item()

            # Compute the loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            total_data += labels.size(0)

            # Update the progress bar description
            test_bar.desc = "validate epoch[{}/{}]".format(epoch + 1,
                                                           args.epochs)

        loss = total_loss / len(test_loader)
        acc = 100 * total_correct / total_data
        validate_loss.append(loss)
        validate_acc.append(acc)

        print('accuracy on validate set:%d %%\n' % acc)


if __name__ == "__main__":
    args = parse_opt()
    device_type = "GPU" if torch.cuda.is_available() else "CPU"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # The network consumes tensors, so use transforms to convert the raw images to normalized tensors
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Training set: torchvision wraps the MNIST download, so the call below fetches it automatically
    train_dataset = datasets.MNIST(root='../dataset/', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)

    # Test set
    test_dataset = datasets.MNIST(root='../dataset/', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)

    # Instantiate the network
    model = CNN()
    model.to(device)

    # Loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Dynamic learning-rate schedule
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    train_loss = []
    train_acc = []
    validate_loss = []
    validate_acc = []
    iteration = 1

    for i in range(args.epochs):
        train(i)
        validate(i)

    os.makedirs(device_type, exist_ok=True)  # make sure the ./GPU or ./CPU output folder exists
    torch.save(model.state_dict(), "./{}/CNN.pth".format(device_type))
    epoch = np.arange(1, args.epochs + 1)
    dataframe = pd.DataFrame({'epoch': epoch,
                              'train loss': train_loss,
                              'train accuracy': train_acc,
                              'validate loss': validate_loss,
                              'validate accuracy': validate_acc
                              })
    dataframe.to_csv(r"./{}/loss&acc.csv".format(device_type))
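
matplotlib is in the dependency list but never used above, so here is a small sketch (not part of the original post) for plotting the curves from the CSV the training script writes; the ./GPU/loss&acc.csv path assumes a GPU run.

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("./GPU/loss&acc.csv")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Loss curves
ax1.plot(df['epoch'], df['train loss'], label='train loss')
ax1.plot(df['epoch'], df['validate loss'], label='validate loss')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend()

# Accuracy curves
ax2.plot(df['epoch'], df['train accuracy'], label='train accuracy')
ax2.plot(df['epoch'], df['validate accuracy'], label='validate accuracy')
ax2.set_xlabel('epoch')
ax2.set_ylabel('accuracy (%)')
ax2.legend()

plt.tight_layout()
plt.show()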

5. Training results: trained on GPU

 

6. Converting the trained model weights to ONNX

import torch
from model import CNN

# Load the pretrained weights (raw string avoids backslash-escape issues on Windows;
# map_location lets GPU-trained weights load on a CPU-only machine)
model = CNN()
model.load_state_dict(torch.load(r"E:\PythonProjects\MNIST-pytorch\CNN\GPU\Mnisty.pth",
                                 map_location="cpu"))
# Switch the model to evaluation mode before exporting
model.eval()

# Example input; it must match the model's expected shape (batch, channel, height, width)
dummy_input = torch.randn(1, 1, 28, 28)

# Export the model to ONNX format
torch.onnx.export(model, dummy_input, "Mnist.onnx", verbose=True, input_names=['input'], output_names=['output'])
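
To check that the exported model behaves like the PyTorch one, it can be run with onnxruntime. This is a sketch, and onnxruntime is an extra dependency not in the environment list above.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("Mnist.onnx")
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)

# Use the input/output names chosen during export
outputs = session.run(['output'], {'input': dummy})
print(outputs[0].shape)           # expected: (1, 10)
print(outputs[0].argmax(axis=1))  # predicted digit for the random input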


 

 

Attachments:
Mnisty.pth (35.15 KB, downloads: 0)
Mnist.onnx (45.96 KB, downloads: 0)


Reply (2024-4-23 07:29): Please add some comments to the code.

 
