#AI Challenge Camp, Stage 1# PyTorch environment + MNIST dataset training
Last edited by tinnu on 2024-4-13 18:25
Training
- First, import a few libraries:
  - torch
  - torchvision
  - matplotlib
```python
import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
```
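Before going further, you can quickly confirm that the PyTorch environment works. This minimal check assumes a standard pip or conda install of torch. It prints the installed version and whether a CUDA GPU is visible. Everything in this post also runs on the CPU.

```python
import torch

# Installed PyTorch version, e.g. "2.x.y" (exact value depends on the install).
print("torch:", torch.__version__)
# True only if a CUDA-capable GPU and matching driver are visible to PyTorch.
print("CUDA available:", torch.cuda.is_available())
```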
Loading the dataset
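- In Python, the torchvision library can download the dataset from the internet to a local folder and then load it automatically. The painfully slow download I mentioned earlier was this MNIST dataset.
- PyTorch has a standard pattern for loading data, usually built around torch.utils.data.DataLoader. It controls how many samples are fed to the model in each training iteration (the batch size).
- Taking the training set as an example: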
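```python
# Batch sizes are not shown in the post; these values are assumptions.
batch_size_train = 64
batch_size_test = 1000

# ToTensor: image -> float tensor in [0, 1] with shape (1, 28, 28);
# Normalize: subtract the MNIST mean (0.1307) and divide by its std (0.3081).
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training set: downloaded to ./data/ on the first run, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("./data/", train=True, download=True, transform=transform),
    batch_size=batch_size_train,
    shuffle=True,
)

# Test set (train=False), assumed to mirror the training loader; used by test() below.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("./data/", train=False, download=True, transform=transform),
    batch_size=batch_size_test,
    shuffle=False,
)
```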
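The sketch below shows what `DataLoader` hands the training loop, without downloading anything. It uses random tensors shaped like MNIST images. The dataset size of 100 and the batch size of 64 are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake "images" shaped like normalized MNIST samples (1 channel, 28x28)
# and 100 integer labels in 0..9. Random data, for illustration only.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

# batch_size sets how many samples each iteration yields;
# shuffle=True reorders the samples at the start of every epoch.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

batch_shapes = []
for data, target in loader:
    batch_shapes.append((tuple(data.shape), tuple(target.shape)))

print(batch_shapes)
# → [((64, 1, 28, 28), (64,)), ((36, 1, 28, 28), (36,))]
```

The last batch is smaller because 100 is not a multiple of 64. With the real MNIST training set (60,000 images) and a batch size of 64, the same logic gives 938 batches per epoch, and the last batch holds 32 images.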
Loss, training and testing
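- Loss function: no separate loss object is defined here. The training loop calls F.nll_loss (Negative Log Likelihood Loss) directly. For each sample, it takes the log-probability the model assigns to the correct class and negates it, then averages over the batch. It does not take the logarithm itself, so the network must output log-probabilities (for example via log_softmax).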
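`network` and `optimizer` are used from here on but never defined in this write-up. Below is a hypothetical stand-in, not the author's model. It is a small CNN (the class name `Net`, its layer sizes, and the SGD settings are all assumptions) whose last step is `log_softmax`, so `F.nll_loss` receives log-probabilities. The block also checks two things numerically. First, `nll_loss` equals the negated, averaged log-probability of each target class. Second, `log_softmax` followed by `nll_loss` matches `F.cross_entropy` applied to raw scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Hypothetical small CNN for 1x28x28 MNIST images (not the author's model)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # -> 10x24x24, pooled to 10x12x12
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # -> 20x8x8, pooled to 20x4x4
        self.fc1 = nn.Linear(320, 50)                  # 20 * 4 * 4 = 320 features
        self.fc2 = nn.Linear(50, 10)                   # one score per digit

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # log_softmax turns scores into log-probabilities, which nll_loss expects.
        return F.log_softmax(self.fc2(x), dim=1)


torch.manual_seed(0)
network = Net()
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.5)  # assumed settings

# What nll_loss computes, on a random batch of 4 fake images.
data = torch.randn(4, 1, 28, 28)
target = torch.tensor([3, 0, 7, 1])
log_probs = network(data)                               # shape (4, 10)

loss = F.nll_loss(log_probs, target)
manual = -log_probs[torch.arange(4), target].mean()     # negated target log-probs, averaged
print(torch.allclose(loss, manual))                      # → True

# log_softmax + nll_loss on raw scores is the same as cross_entropy on those scores.
scores = torch.randn(4, 10)
print(torch.allclose(F.nll_loss(F.log_softmax(scores, dim=1), target),
                     F.cross_entropy(scores, target)))   # → True
```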
- Training
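```python
def train(epoch):
    network.train()  # switch layers such as dropout / batch norm to training behaviour
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()              # clear gradients left over from the previous step
        output = network(data)             # forward pass: log-probabilities, shape (N, 10)
        loss = F.nll_loss(output, target)  # negative log-likelihood of the true labels
        loss.backward()                    # backward pass: compute gradients
        optimizer.step()                   # update the weights
        if batch_idx % 100 == 0:           # progress log (added for illustration)
            print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] loss: {loss.item():.4f}")
```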
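- Training is simple too. It is a completely standard loop: clear gradients -> forward pass -> compute loss -> backward pass -> optimizer step.
- Five statements, five jobs, done.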
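You can exercise the five-step loop without MNIST. This sketch uses a hypothetical tiny linear model and one fixed batch of random data in place of the real network and loader. It repeats the same five statements 20 times and checks that the loss on that batch goes down. The learning rate is kept small (0.01) so plain gradient descent stays stable on this data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in model: flatten 1x28x28 -> 10 log-probabilities.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
opt = optim.SGD(model.parameters(), lr=0.01)

# One fixed random batch of 32 "images" with random labels.
data = torch.randn(32, 1, 28, 28)
target = torch.randint(0, 10, (32,))

losses = []
model.train()
for step in range(20):
    opt.zero_grad()                    # 1. clear old gradients
    output = model(data)               # 2. forward pass
    loss = F.nll_loss(output, target)  # 3. compute loss
    loss.backward()                    # 4. backward pass
    opt.step()                         # 5. update parameters
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```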
- Testing
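```python
def test():
    network.eval()  # evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in test_loader:
            output = network(data)
            # reduction="sum" so the total can be divided by the dataset size afterwards
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.data.max(1, keepdim=True)[1]  # index of the largest log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    n = len(test_loader.dataset)
    print(f"Test set: avg loss {test_loss / n:.4f}, accuracy {correct}/{n} ({100.0 * correct / n:.2f}%)")
```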
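Here is a tiny hand-made example of the prediction and accuracy step (the probabilities are invented). `max(1, keepdim=True)[1]` returns the column index of the largest value in each row, which is the same as `argmax(dim=1)`. `eq` followed by `sum` then counts how many predictions match the labels.

```python
import torch

# Invented log-probabilities for 3 samples over 4 classes (rows = samples).
output = torch.log(torch.tensor([
    [0.70, 0.10, 0.10, 0.10],   # most likely class 0
    [0.05, 0.05, 0.80, 0.10],   # most likely class 2
    [0.25, 0.40, 0.15, 0.20],   # most likely class 1
]))
target = torch.tensor([0, 2, 3])

pred = output.max(1, keepdim=True)[1]       # shape (3, 1): [[0], [2], [1]]
correct = pred.eq(target.view_as(pred)).sum().item()

print(pred.squeeze(1).tolist())             # → [0, 2, 1]
print(correct)                              # → 2
print(torch.equal(pred.squeeze(1), output.argmax(dim=1)))  # → True
```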
- Display
```python
# Take the first batch of the test set and predict it without tracking gradients.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
    output = network(example_data)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    # example_data[i] has shape (1, 28, 28); [0] drops the channel axis for imshow.
    plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
    plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1][i].item()))
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()
```
Exporting the model
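```python
# Save only the learned parameters (the state_dict), not the Python model class itself.
torch.save(network.state_dict(), "./model.pth")
# The optimizer state (e.g. momentum buffers) is saved too, so training can resume later.
torch.save(optimizer.state_dict(), "./optimizer.pth")
```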
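A `state_dict` holds only tensors and settings, so loading it back requires building the same model class first. This minimal round-trip sketch uses a hypothetical `nn.Linear` stand-in for `network` and a temporary directory instead of `./model.pth`.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(28 * 28, 10)  # hypothetical stand-in for `network`
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pth")
    optim_path = os.path.join(tmp, "optimizer.pth")
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optim_path)

    # Rebuild the same architecture, then copy the saved tensors into it.
    restored = nn.Linear(28 * 28, 10)
    restored.load_state_dict(torch.load(model_path))
    restored_opt = optim.SGD(restored.parameters(), lr=0.01, momentum=0.5)
    restored_opt.load_state_dict(torch.load(optim_path))

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(same)  # → True
```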
ONNX conversion
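```python
def export_to_onnx(model, input_example, output_path="model.onnx"):
    model.eval()  # export the inference-mode graph
    torch.onnx.export(
        model,
        input_example,
        output_path,
        opset_version=11,
        do_constant_folding=True,  # pre-compute constant sub-expressions
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={             # let the batch dimension vary at inference time
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )
    print("Model exported to ONNX format successfully.")


# Dummy input in (batch, channel, height, width) layout; dim 0 is the dynamic batch axis.
input_example = torch.randn(1, 1, 28, 28)
export_to_onnx(network, input_example=input_example)
```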
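The dummy input for export should be 4-D. `torch.onnx.export` traces the model with the example tensor it is given, and `dynamic_axes` marks axis 0 as the batch size. For a single MNIST sample from `ToTensor()`, shaped `(1, 28, 28)`, axis 0 is the channel rather than the batch. This quick shape check uses a hypothetical single `Conv2d` layer in place of the real network. It shows that adding a batch axis gives `(N, 1, 28, 28)`, the same layout a `DataLoader` batch has.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)  # hypothetical first layer

single = torch.randn(1, 28, 28)          # one image: (channel, height, width)
batched = single.unsqueeze(0)            # add a batch axis: (1, 1, 28, 28)
batch_of_8 = torch.randn(8, 1, 28, 28)   # what a DataLoader batch looks like

print(tuple(batched.shape))              # → (1, 1, 28, 28)
print(tuple(conv(batched).shape))        # → (1, 10, 24, 24)
print(tuple(conv(batch_of_8).shape))     # → (8, 10, 24, 24)
```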