#AI挑战营第一站# 基于PyTorch的LeNet模型：MNIST数据集训练与测试
[复制链接]
本帖最后由 wakojosin 于 2024-4-15 20:39 编辑
环境
数据准备
# Shared preprocessing for both splits: resize every image to 32x32 (the
# 1x32x32 input size the network and the ONNX dummy input use), then convert
# it to a tensor. Both transforms are stateless, so one instance can be reused.
_mnist_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Training split (MNIST's default train=True), downloaded on first use.
data_train = MNIST('./data/mnist', download=True, transform=_mnist_transform)
# Held-out test split.
data_test = MNIST('./data/mnist', train=False, download=True, transform=_mnist_transform)

# Shuffle only the training data; evaluation uses larger, ordered batches.
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
LeNet模型构建
# Output size of a convolution / pooling layer along one spatial dimension:
#     N = floor((W - F + 2P) / S) + 1
# where
#     N: output size
#     W: input size
#     F: kernel size
#     P: padding
#     S: stride
#
# Shape trace for a 1x32x32 input (no padding anywhere):
#   Conv 5x5, stride 1     -> 6x28x28
#   AvgPool 5x5, stride 2  -> floor((28-5)/2)+1 = 12 -> 6x12x12
#   Conv 5x5, stride 1     -> 16x8x8
#   AvgPool 5x5, stride 2  -> floor((8-5)/2)+1 = 2   -> 16x2x2
#   Flatten                -> 16*2*2 = 64 features
# NOTE(review): classic LeNet-5 uses 2x2 / stride-2 pooling (giving 16x5x5 and
# a 400-wide first Linear layer). The 5x5 pooling here shrinks the feature map
# far more aggressively; confirm this deviation is intentional.
_lenet_layers = [
    # Feature extractor, stage 1: convolution extracts features.
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5)),
    nn.Sigmoid(),  # sigmoid activation supplies the non-linearity
    nn.AvgPool2d(kernel_size=(5, 5), stride=2),  # downsample
    # Feature extractor, stage 2.
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5)),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=(5, 5), stride=2),
    # Classifier: fully connected layers map the fixed-size feature vector
    # to 10 output logits.
    nn.Flatten(),
    nn.Linear(16 * 2 * 2, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
]
net = nn.Sequential(*_lenet_layers)
训练
def train(epoch):
    """Run one optimisation pass over ``data_train_loader``.

    Uses the module-level ``net``, ``optimizer``, ``criterion`` and
    ``data_train_loader`` defined elsewhere in this script.

    Args:
        epoch: Epoch index. Not used in the body; kept so existing callers of
            ``train(epoch)`` keep working.

    Returns:
        A ``(loss_list, batch_list)`` tuple: the loss of every batch as a
        Python float, and the matching 1-based batch numbers. (These lists
        were previously built and then discarded.)
    """
    net.train()  # put the model in training mode
    loss_list, batch_list = [], []
    for batch_num, (images, labels) in enumerate(data_train_loader, start=1):
        optimizer.zero_grad()             # clear gradients from the previous step
        output = net(images)              # forward pass
        loss = criterion(output, labels)  # compute the loss
        # .item() yields a plain float, so the history holds no autograd graph.
        loss_list.append(loss.item())
        batch_list.append(batch_num)
        loss.backward()                   # back-propagate
        optimizer.step()                  # update the parameters
    return loss_list, batch_list
模型验证及转换
def test(epoch):
    """Evaluate ``net`` on the test split, then checkpoint and export it.

    Uses the module-level ``net``, ``criterion``, ``data_test`` and
    ``data_test_loader`` defined elsewhere in this script, and ``onnx``,
    which must be imported at file level.

    Steps:
      1. Compute the per-sample average loss and the accuracy over
         ``data_test_loader`` without building an autograd graph.
      2. Save ``net.state_dict()`` to ``pth/epoch{epoch}_accuracy{acc}.pth``,
         creating the ``pth`` directory if needed.
      3. Export the model to ``lenet.onnx`` and validate it with
         ``onnx.checker.check_model``.

    Args:
        epoch: Epoch index; used only in the checkpoint filename.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats.
    """
    import os  # local import: file-level imports are outside this view

    net.eval()  # evaluation mode
    num_samples = len(data_test)
    total_loss = 0.0
    total_correct = 0
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            batch_loss = criterion(output, labels)
            # Accumulate a per-sample SUM so dividing by num_samples gives a true
            # per-sample mean. With reduction='mean' (the PyTorch loss default)
            # the criterion returns the batch mean, which must be scaled back by
            # the batch size. Summing raw batch means and dividing by the dataset
            # size would under-report the loss by roughly the batch size.
            if getattr(criterion, "reduction", "mean") == "mean":
                batch_loss = batch_loss * labels.size(0)
            total_loss += batch_loss.sum().item()
            pred = output.argmax(dim=1)  # predicted class = index of the largest logit
            total_correct += pred.eq(labels.view_as(pred)).sum().item()
    avg_loss = total_loss / num_samples
    accuracy = float(total_correct) / num_samples
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, accuracy))

    # torch.save raises if the target directory does not exist.
    os.makedirs("pth", exist_ok=True)
    torch.save(net.state_dict(), "pth/epoch{}_accuracy{}.pth".format(epoch, accuracy))

    # The export runs the model on a concrete input (batch_size=1, channels=1,
    # 32x32) and records a static graph. Gradients are not needed for tracing.
    dummy_input = torch.randn(1, 1, 32, 32)
    torch.onnx.export(net, dummy_input, "lenet.onnx")
    onnx_model = onnx.load("lenet.onnx")
    onnx.checker.check_model(onnx_model)  # validate the exported ONNX model
    return avg_loss, accuracy
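验证输出结果：Test Avg. Loss: 0.000170, Accuracy: 0.945300
（注：若 criterion 使用 PyTorch 损失函数默认的 reduction='mean'，上述代码先把各批次的平均损失直接相加，再除以测试样本总数。这样得到的 Avg. Loss 会比真实的单样本平均损失小约一个批大小（1024）的倍数。按此估算，实际平均损失约为 0.17。）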
模型图
小结
基于经典的LeNet在MNIST数据集上完成了训练及验证，学习了整体流程。
模型见附件.
|