# 安装 PyTorch（例如：pip install torch torchvision）
# 加载程序库
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# 定义超参数
# --- Data / model dimensions ---
input_size = 28 * 28 * 1  # one 1-channel 28x28 image, flattened
num_classes = 10          # width of the classifier's output layer

# --- Optimisation settings ---
num_epochs = 20           # number of full passes over the training data
batch_size = 256          # presumably used when the DataLoaders are built — confirm
learning_rate = 1e-4      # presumably passed to the optimizer — confirm
# 定义 CNN 模型
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Pipeline:
        [Conv5x5(1->32, pad 2) -> ReLU -> MaxPool 2x2]
        -> [Conv5x5(32->64, pad 2) -> ReLU -> MaxPool 2x2]
        -> flatten -> Dropout -> Linear(7*7*64, 1000) -> ReLU -> Linear(1000, n_classes)

    The 5x5 convolutions with padding 2 keep the spatial size, and each pooling
    halves it (28 -> 14 -> 7). That is why ``fc1`` expects 7 * 7 * 64 inputs.

    Submodule names (``layer1``, ``layer2``, ``drop_out``, ``fc1``, ``fc2``) are
    kept stable so existing ``state_dict`` checkpoints still load.

    Args:
        n_classes: Number of output scores. When omitted, the module-level
            ``num_classes`` hyperparameter is used.
    """

    def __init__(self, n_classes=None):
        super().__init__()
        if n_classes is None:
            n_classes = num_classes  # fall back to the global hyperparameter
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.drop_out = nn.Dropout()  # default p=0.5; active only in train() mode
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, n_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, n_classes).

        ``x`` is expected to be (batch, 1, 28, 28); no softmax is applied.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = torch.flatten(out, 1)  # (batch, 64, 7, 7) -> (batch, 3136)
        out = self.drop_out(out)
        # Without a nonlinearity here, fc1 followed by fc2 is equivalent to a
        # single linear layer and the 1000-unit hidden layer adds nothing.
        out = torch.relu(self.fc1(out))
        return self.fc2(out)
# 训练模型
# Training loop: each epoch does one optimisation pass over train_loader, then
# one evaluation pass over val_loader. The weights from the epoch with the
# highest validation accuracy are kept in `best_model_wts`.
# NOTE(review): model, criterion, optimizer, train_loader and val_loader are not
# created in this part of the file; they are assumed to be defined earlier.
# Losses are averaged per sample (batch loss weighted by batch size). This
# assumes `criterion` returns a per-batch mean (reduction="mean") — confirm.
device = next(model.parameters()).device  # send each batch to the model's device
best_acc = 0.0
# Deep copy: state_dict() returns references to the live parameter tensors.
# optimizer.step() updates those in place, so without a copy the "best"
# snapshot would silently track the latest weights.
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()  # training mode (enables dropout layers)
    running_loss = 0.0
    corrects = 0
    train_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_len = labels.size(0)
        running_loss += loss.item() * batch_len
        corrects += (outputs.argmax(dim=1) == labels).sum().item()
        train_seen += batch_len
    epoch_loss = running_loss / train_seen
    epoch_acc = corrects / train_seen

    # ---- validation pass ----
    model.eval()  # evaluation mode (disables dropout layers)
    val_loss = 0.0
    val_corrects = 0
    val_seen = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_len = labels.size(0)
            val_loss += criterion(outputs, labels).item() * batch_len
            val_corrects += (outputs.argmax(dim=1) == labels).sum().item()
            val_seen += batch_len
    val_loss = val_loss / val_seen
    val_acc = val_corrects / val_seen

    print(
        f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Snapshot the weights whenever validation accuracy improves.
    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
# 测试模型
def test():
network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = network(data)
pred = output.data.max(1, keepdim=True)[1]