#AI挑战营第一站# 基于pytorch的MNIST手写数字识别
[复制链接]
本帖最后由 不语arc 于 2024-4-13 20:07 编辑
本机配置:CPU训练,torch2.0版本
训练代码与生成的模型见附件
训练过程可以简单概括为三步:1 数据集加载,2 实例化模型,3 循环训练
1. 数据集预处理
1.1 数据集的划分
在深度学习领域,通常来说,数据集会划分成训练集、验证集、测试集。
其中训练集用于训练模型;验证集用于训练过程中,在每训练一轮之后进行验证计算指标,判断模型的好坏;测试集用于已训练好模型的指标计算,同样用于判断模型好坏。
对MNIST数据集,有60000+10000共70000张图,其中有10000张图被指定为测试集。我将训练集和验证集划分比例为9:1,因此训练集有54000张,验证集有6000张。
即 train:val:test = 54:6:10。
1.2 MNIST数据集下载
在torchvision中内置了MNIST数据集的调用函数,具体来说是下面这条指令。其中root表示保存目录,train=True表示使用训练集。download表示如果本地没有MNIST就从服务器下载,transform是对图片的预处理工作(如尺寸resize、rgb转灰度图等)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
1.3 代码实现 数据集加载
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 将10%的数据用作验证集
percent = 0.1
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(percent * num_train))
train_idx, val_idx = random_split(indices, [num_train - split, split])
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
2. 实例化模型
这里定义了模型结构,输入图片(1,1,28,28)经过 两个卷积层、被卷平后经过dropout和两个全连接网络,会输出一个1x10的张量。10对应了10种可能结果,哪个数字最大代表这张图片属于哪一类,从而完成了手写数字的识别。
# 定义模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.drop_out = nn.Dropout()
self.fc1 = nn.Linear(7 * 7 * 64, 1000)
self.fc2 = nn.Linear(1000, num_classes)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.drop_out(out)
out = self.fc1(out)
out = self.fc2(out)
return out
3. 循环训练
指定训练次数epoch,对每一轮训练,都会进行:前向传播计算损失、反向传播梯度下降更新权重。在每一轮训练好之后,再计算 验证集的指标,当指标达到最优时保存该轮的权重。
# 训练循环
for epoch in range(num_epochs):
model.train() # 设置模型为训练模式
running_loss = 0.0
corrects = 0
for inputs, labels in train_loader:
optimizer.zero_grad() # 清空之前的梯度
outputs = model(inputs) # 获取模型输出
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播计算梯度
optimizer.step() # 使用优化器更新权重
_, preds = torch.max(outputs, 1) # 获取预测结果
corrects += torch.sum(preds == labels.data) # 计算正确预测的数量
running_loss += loss.item() # 累加损失
epoch_loss = running_loss / len(train_loader)
epoch_acc = corrects.double() / train_total
# 验证模型性能
model.eval() # 设置模型为评估模式
val_loss = 0.0
val_corrects = 0
with torch.no_grad(): # 不计算梯度
for inputs, labels in val_loader:
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, preds = torch.max(outputs, 1)
val_corrects += torch.sum(preds == labels.data)
val_loss = val_loss / len(val_loader)
val_acc = val_corrects.double() / val_total
# 打印训练结果和验证结果
print(
f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
# 如果当前验证准确率比之前的最佳准确率要好,保存当前权重
if val_acc > best_acc:
best_acc = val_acc
best_model_wts = model.state_dict() # 获取模型权重
4. 模型测试
训练结束后,还需要使用得到的最优模型在测试集上验证效果。并将权重保存成pth文件。
下图是训练过程,Epoch表示训练轮数,这里总次数被指定为20。Loss:训练集的损失;Train Acc:训练集的准确度;Val Loss:验证集的损失;Val Acc:验证集的准确度。
红框是最后的测试集表现,手写数字识别准确度达到99.22%。
5. 模型推理
推理过程,是指输入全新图片,导入训练的权重,对手写数字图片仅进行前向传播完成识别。
数据预处理:
自己手写的图片大致长这样:
数据集MNIST图片长这样:
区别:我是白底黑字,数据集是黑底白字。所以在图像预处理部分,除了要将图片resize、灰度化还需要取反。
6. pth模型转onnx
调用onnx接口,model为实例化的模型,dummy_imput为指定输入尺寸,onnx_file_path指定onnx输出路径。
# 导出模型为ONNX格式
torch.onnx.export(model, dummy_input, onnx_file_path,
input_names=['input'], output_names=['output'], # 输入和输出节点名称
opset_version=12, # 使用的ONNX操作集版本,选择与您PyTorch版本兼容的一个
do_constant_folding=True, # 是否执行常量折叠优化
export_params=True, # 是否包含权重
verbose=False, # 是否打印转换细节
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}) # 设置动态轴,允许在推理时调整batch_size
|