#AI挑战营第一站#基于PyTorch,在PC上完成MNIST手写数字识别模型训练
[复制链接]
#AI挑战营第一站# 手写数字的识别
本人是AI菜鸟，跟着大佬后面跑，照葫芦画瓢。
首先导入相应的库:
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
其次,准备数据集:
# Dataset preparation: MNIST train/test splits (downloaded on first use) and
# their batched loaders.
batch_size = 64
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Single-channel normalization constants; presumably the MNIST
        # training-set pixel mean/std — the inference script reuses them.
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
_mnist_root = './dataset/mnist/'
train_dataset = datasets.MNIST(root=_mnist_root, train=True, download=True, transform=transform)
# Shuffle only the training data; evaluation order does not matter.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root=_mnist_root, train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
接下来定义网络结构:
class Net(torch.nn.Module):
    """Small CNN that maps 1x28x28 digit images to 10 class logits.

    Layout: conv(1->10, k=5) -> maxpool(2) -> ReLU -> conv(10->20, k=5) ->
    maxpool(2) -> ReLU -> flatten -> linear(320->10).
    The output is unnormalized scores, as expected by CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        # One pooling module is reused after both convolutions (it has no weights).
        self.pooling = torch.nn.MaxPool2d(2)
        # 320 = 20 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4
        # through the two conv(k=5) + pool(2) stages.
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Return logits of shape (n, 10) for an input of shape (n, 1, 28, 28)."""
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))  # -> (n, 10, 12, 12)
        x = F.relu(self.pooling(self.conv2(x)))  # -> (n, 20, 4, 4)
        # Flatten the feature maps from (n, 20, 4, 4) to (n, 320).
        x = x.view(batch_size, -1)
        return self.fc(x)
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves parameters in place and returns the module itself.
model = Net().to(device)
# Loss on raw logits, plus SGD with momentum as the optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
定义训练和测试函数:
def train(epoch, log_interval=300):
    """Run one optimization pass over ``train_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and ``device``.

    Args:
        epoch: Zero-based epoch index; it is printed 1-based in the log line.
        log_interval: Number of batches between loss reports. The printed
            value is the mean loss over that window. Defaults to 300.

    Raises:
        ValueError: If ``log_interval`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError('log_interval must be positive')
    # Make sure mode-dependent layers (dropout/batch-norm, if ever added) are
    # in training mode, even if an earlier step switched the model to eval.
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        # .item() converts to a Python float, so the autograd graph is not retained.
        running_loss += loss.item()
        if batch_idx % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_interval))
            running_loss = 0.0
def test():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Prints the accuracy as a percentage and returns it as a fraction in [0, 1].
    The model's previous train/eval mode is restored before returning.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    was_training = model.training
    # Eval mode so that any mode-dependent layers behave deterministically.
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                # The index of the highest logit is the predicted class.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError('test_loader produced no samples')
    accuracy = correct / total
    # %.2f instead of %d, so the reported percentage is not truncated.
    print('accuracy on test set: %.2f %% ' % (100 * accuracy))
    return accuracy
进行训练:
if __name__ == '__main__':
    # Train for a fixed number of epochs, evaluating after each one, then
    # plot test accuracy against the epoch index.
    num_epochs = 20
    accuracies = []
    for ep in range(num_epochs):
        train(ep)
        accuracies.append(test())
    epochs = list(range(num_epochs))
    plt.plot(epochs, accuracies)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.show()
模型导出为pth文件:
# Save the trained weights (state_dict only) in .pth format.
# Guarded so that importing this file (e.g. an inference script doing
# `from <this module> import Net`) does not overwrite the checkpoint with the
# untrained weights of a freshly constructed model.
if __name__ == '__main__':
    torch.save(model.state_dict(), 'mnist_101_model.pth')
导出ONNX文件:
# Export the model to ONNX, tracing it with one dummy 1x1x28x28 input.
# Guarded so that merely importing this file (e.g. to reuse the Net class)
# does not export an untrained model.
if __name__ == '__main__':
    # Create the dummy input directly on the model's device.
    dummy_input = torch.randn(1, 1, 28, 28, device=device)
    torch.onnx.export(model, dummy_input, "mnist_101_model.onnx", verbose=False)
用新数据测试:
测试脚本（需先将上面的训练代码保存为 mnist_gpu_version.py，测试脚本会从中导入 Net 类）:
import torch
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
from PIL import Image
from mnist_gpu_version import Net
#找到测试文件夹下面的png图片
import os, glob
# Image preprocessing helper.
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized model input.

    The image is converted to single-channel grayscale, resized to 28x28,
    converted to a tensor and normalized with the same constants the
    training script uses.

    Returns:
        A float tensor with a leading batch dimension of size 1.
    """
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    gray = Image.open(image_path).convert("L")
    # Match the 28x28 input size the network was trained on.
    gray = resize(gray, (28, 28))
    # unsqueeze(0) prepends the batch axis.
    return to_model_input(gray).unsqueeze(0)
def predict_image(image_path):
    """Classify one image and print the expected vs. predicted digit.

    The expected digit is taken from the file name: the base name (directory
    stripped) is cut at its first '.', then split on '_', and the second field
    is used, e.g. ``images/digit_7.png`` -> ``7``.

    Args:
        image_path: Path to the image file.
    """
    predicted_label = recognize_image(image_path)
    # Parse the label from this function's own argument rather than a global
    # loop variable. Use only the base name, so that '.' or '_' in directory
    # names cannot break the parsing.
    base_name = os.path.basename(image_path)
    actual_number = base_name.split('.')[0].split('_')[1]
    compare_result = "Matched" if int(actual_number) == int(predicted_label) else "DisMatch"
    print(f"Actual Number is: {actual_number} " + f" 预测的标签 = {predicted_label} " + compare_result)
# Choose the inference device first, so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the model and load the trained weights. The training script may have
# saved them from a CUDA device; map_location remaps those tensors so loading
# also works on CPU-only machines.
model = Net()
model.load_state_dict(torch.load('mnist_101_model.pth', map_location=device))
model.to(device)
# Eval mode for layers whose behaviour differs between training and inference.
model.eval()
# Image recognition.
def recognize_image(image_path):
    """Return the predicted digit (a Python int) for the image at ``image_path``.

    Uses the module-level ``model`` and ``device``.
    """
    image = preprocess_image(image_path).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(image)
    # The index of the highest logit is the predicted class.
    return output.argmax(dim=1).item()
# Folder containing the images to classify.
folder_path = 'images'
# Collect every .png in the folder. Sort the list so the processing and output
# order is deterministic (glob's order depends on the file system).
png_files_path = sorted(glob.glob(os.path.join(folder_path, "*.png")))
# Classify each file. The loop variable keeps the module-level name `png_file`,
# since other code may read that global.
for png_file in png_files_path:
    predict_image(png_file)
测试结果:
|