3328|1

111

帖子

0

TA的资源

一粒金砂(中级)

楼主
 

#AI挑战营第一站# pytorch环境+minst数据集训练 [复制链接]

本帖最后由 tinnu 于 2024-4-13 18:25 编辑

  • 安装显卡加速支持

    1. 英伟达
      • 安装cuda
      • cuda
      • cudnn
      • 装完之后在命令行输入 nvcc --version 没有报错即通过
        1. > nvcc --version
        2. nvcc: NVIDIA (R) Cuda compiler driver
        3. Copyright (c) 2005-2023 NVIDIA Corporation
        4. Built on Mon_Apr__3_17:36:15_Pacific_Daylight_Time_2023
        5. Cuda compilation tools, release 12.1, V12.1.105
        6. Build cuda_12.1.r12.1/compiler.32688072_0
    2. AMD
    3. 国产之光摩尔线程
    4. CPU
      • 完全不用管
  • 安装python

    • pip3 install python3
    • pip3 install matplotlib
  • 根据pytorch官网提示安装pytorch

    1. CPU安装
      • win下纯cpu是这个命令:
      • pip3 install torch torchvision torchaudio
    2. 英伟达安装
    3. AMD显卡安装
    4. 摩尔线程显卡安装
  • 因为我的显卡是N年前的MX130,装完之后发现跑得比CPU还慢,而且满三倍,于是选择纯CPU安装;安装完后一跑,发现下载贼久,因为minst的数据集在……你懂的,所以我觉得找个外面的云平台可能会比较快。

  • 找来找去,能用的云端免费平台也就大名鼎鼎的kaggle,不过虽然能用,但注册的时候还是存在你懂的环节,自行解决不再赘述。所以只要用云平台,上面这些繁琐的步骤其实统统不用……

训练

  • 首先加载一些程序库
    • torch
    • torchvision
    • matplotlib
      1. import torch
      2. import torchvision
      3. from torch.utils.data import DataLoader
      4. import torch.nn as nn
      5. import torch.nn.functional as F
      6. import torch.optim as optim
      7. import matplotlib.pyplot as plt
  1. 加载数据集

    • python 下可以通过 torchvision 库直接从网络下载数据集到本地,然后自动加载,前面说下载贼旧就是因为这个minst数据集。
    • pytorch 里面加载数据有一套范式,一般是通过 torch.utils.data.DataLoader ,可以控制在训练的时候控制每一轮输出的数据量
    • 以加载训练集为例:
      1. train_loader = torch.utils.data.DataLoader(
      2. torchvision.datasets.MNIST(
      3. "./data/",
      4. train=True,
      5. download=True,
      6. transform=torchvision.transforms.Compose(
      7. [
      8. torchvision.transforms.ToTensor(),
      9. torchvision.transforms.Normalize((0.1307,), (0.3081,)),
      10. ]
      11. ),
      12. ),
      13. batch_size=batch_size_train,
      14. shuffle=True,
      15. )
  2. 加载数据集

    • minst 可以应用非常典型的网络结构,比如 两个卷积->全连接层->relu->droupout->全连接层

      1. class Net(nn.Module):
      2. def __init__(self):
      3. super(Net, self).__init__()
      4. self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
      5. self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
      6. self.conv2_drop = nn.Dropout2d()
      7. self.fc1 = nn.Linear(320, 50)
      8. self.fc2 = nn.Linear(50, 10)
      9. def forward(self, x):
      10. x = F.relu(F.max_pool2d(self.conv1(x), 2))
      11. x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
      12. x = x.view(-1, 320)
      13. x = F.relu(self.fc1(x))
      14. x = F.dropout(x, training=self.training)
      15. x = self.fc2(x)
      16. return F.log_softmax(x, dim=1)
    • 配置一下优化器,这里可选收敛比较快的SGD,也可以用 Adam 反正minst训练也很快

  1. - Adam收敛速度:
  1. - 损失函数,这里没有定义损失函数,是直接在训练里面用 nll_loss (Negative Log Likelihood Loss) 说到底就是向量对数
  1. 训练
    1. def train(epoch):
    2. network.train()
    3. for batch_idx, (data, target) in enumerate(train_loader):
    4. optimizer.zero_grad()
    5. output = network(data)
    6. loss = F.nll_loss(output, target)
    7. loss.backward()
    8. optimizer.step()
    • 训练也很简单,非常标准的流程:数据对齐->前向传播->计算损失->反向传播->优化器计算
    • 5条语句,5个功能,完事
  2. 测试
    1. def test():
    2. network.eval()
    3. test_loss = 0
    4. correct = 0
    5. with torch.no_grad():
    6. for data, target in test_loader:
    7. output = network(data)
    8. pred = output.data.max(1, keepdim=True)[1]
  3. 显示
    1. examples = enumerate(test_loader)
    2. batch_idx, (example_data, example_targets) = next(examples)
    3. with torch.no_grad():
    4. output = network(example_data)
    5. fig = plt.figure()
    6. for i in range(6):
    7. plt.subplot(2, 3, i + 1)
    8. plt.tight_layout()
    9. plt.imshow(example_data<i>[0], cmap="gray", interpolation="none")
    10. plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1]<i>.item()))
    11. plt.xticks([])
    12. plt.yticks([])
    13. plt.show()
  1. 模型导出

    1. torch.save(network.state_dict(), "./model.pth")
    2. torch.save(optimizer.state_dict(), "./optimizer.pth")
  2. onnx转化

    1. def export_to_onnx(model, input_example, output_path="model.onnx"):
    2. model.eval()
    3. torch.onnx.export(
    4. model,
    5. input_example,
    6. output_path,
    7. opset_version=11,
    8. do_constant_folding=True,
    9. input_names=["input"],
    10. output_names=["output"],
    11. dynamic_axes={
    12. "input": {0: "batch_size"},
    13. "output": {0: "batch_size"}
    14. },
    15. )
    16. print("Model exported to ONNX format successfully.")
    17. input_example = torch.randn(1, 28, 28)
    18. export_to_onnx(network, input_example=input_example)

model.onnx

(88.09 KB, 下载次数: 2)

model.pth

(88.16 KB, 下载次数: 0)

optimizer.pth

(177.19 KB, 下载次数: 0)

最新回复

牛     详情 回复 发表于 2024-4-14 22:59
点赞 关注

回复
举报

64

帖子

0

TA的资源

一粒金砂(中级)

沙发
 

 

 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 2/8 下一条
Microchip 直播|利用motorBench开发套件高效开发电机磁场定向控制方案 报名中!
直播主题:利用motorBench开发套件高效开发电机磁场定向控制方案
直播时间:2025年3月25日(星期二)上午10:30-11:30
快来报名!

查看 »

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

 
机器人开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表