3305|0

81

帖子

3

TA的资源

一粒金砂(中级)

楼主
 

#AI挑战营第一站#基于pytorch模型mnist数据及训练测试 [复制链接]

本帖最后由 wakojosin 于 2024-4-15 20:39 编辑

环境

 

 

数据准备

data_train = MNIST('./data/mnist',
                   download=True,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),
                       transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
                  train=False,
                  download=True,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

LeNet模型构建

'''
卷积神将网络的计算公式为:
N=(W-F+2P)/S+1
其中
N:输出大小
W:输入大小
F:卷积核大小
P:填充值的大小
S:步长大小
'''
net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5,5)), #卷积层提特征,6*28*28
    nn.Sigmoid(), #激活函数:多项式提供非线性
    nn.AvgPool2d(kernel_size=(5,5), stride=2), #下采样,(28-5+2*0)/2+1=>6*12*12
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5,5)),#16*8*8
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=(5,5), stride=2),#(8-5+2*0)/2+1=>16*2*2
    nn.Flatten(),
    nn.Linear(16*2*2, 120), nn.Sigmoid(),#全连接:固定特征维度,激活函数:提供非线性网络
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

训练

def train(epoch):
    net.train() #训练模式
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()

        output = net(images) #喂数据
        loss = criterion(output, labels) #计算loss

        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i+1)

        loss.backward() #返传
        optimizer.step() #优化器参数迭代

模型验证及转换

def test(epoch):
    net.eval() #测试模式
    total_correct = 0
    avg_loss = 0.0
    for i, (images, labels) in enumerate(data_test_loader):
        output = net(images)
        avg_loss += criterion(output, labels).sum()
        pred = output.detach().max(1)[1]
        total_correct += pred.eq(labels.view_as(pred)).sum()

    avg_loss /= len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
    torch.save(net.state_dict(), "pth/epoch{}_accuracy{}.pth".format(epoch, float(total_correct) / len(data_test)))

    dummy_input = torch.randn(1, 1, 32, 32, requires_grad=True) #batch_size=1,channels=1,w*h=32*32
    torch.onnx.export(net, dummy_input, "lenet.onnx") #先推理后根据tensor把模型图变成静态

    onnx_model = onnx.load("lenet.onnx")
    onnx.checker.check_model(onnx_model) #检查onnx模型

验证输出结果: Test Avg. Loss: 0.000170, Accuracy: 0.945300

 

模型图

 

小结

基于经典的LeNet进行MNIST数据集训练及验证,学习了整体流程.

模型见附件.

 

epoch15_accuracy0.9453.pth

87.32 KB, 下载次数: 0

lenet.onnx

85.18 KB, 下载次数: 0

点赞 关注(1)

回复
举报
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/8 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表