动手学深度学习(PyTorch版)- 【读书活动-心得分享】使用Softmax实现图像分类
[复制链接]
简介
前几个章节我们实现了线性回归来对某一个数据值进行预测,本章节我们学习下如何使用softmax进行数据的分类。
这个章节的数学知识有点难, 我不是很懂具体的数学是怎么计算的, 但是大概要实现什么效果我还是可以理解的。 我就我自己的理解进行代码的分析。
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
上述代码定义了一个数据集,即MNIST, 当执行上述的代码的时候会自动从http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/ 上下载对应的训练数据和测试数据
下载之后的训练数据集和测试数据集分别为
有图这个图像的数据只是为28 * 28 单通道。 所以每个图片的形状如下所示
之后定义了一个获取图像内所有的标签和显示图片的方法。
图像的标签数据其实在下载的时候就已经包含了。 这里不理解的是为什么这个Tshirt的和下载的时候的类型不一致
图像绘制的方法我这里也没有纠结, 因为看不懂是怎么绘制的。
之后我们可以加载小批量的测试数据进行数据的加载和展示。如果你使用的是pycharm 或者 dataspell, 记得勾选上述的文件为信任的 , 否则会出现图片无法绘制的异常
上图为信任模式下的正常绘制。 一共读取10 个数据, 两行 5 列。
后面的交叉熵和模型就开始有点听不懂了,于是对从0实现的softmax就没办法继续了。 之后我尝试了书里提供的使用pytorch实现的 softmax实现。
这里如果你使用的是pycharm 或者dataspell的话也会出现图片不加载的情况。 可以直接使用Jupyter来打开。
这里就很好理解了。 我只需要知道怎么使用torch的API即可, 我不需要去关注具体的算法实现。 然后通过torch的内部定义的损失函数等等来训练我的模型。
不过这里, 我全都是用CPU跑的, 我还尝试问GPT怎么转移到GPU上。但是没有成功。下图可以看到CPU的使用情况。E52666V3 十个核最高占用能跑到80%
通过下面的代码保存整个模型后,下次就不需要再训练了。
# 保存整个模型
torch.save(net, "softmax.pth")
模型大小大概33KB。
我们尝试加载这个模型然后随便找个图形来试一下是否能正常是被。
上图为一个chatgpt生成的28*28的图片。
它把我的T恤分类成包了。
加载训练数据的似乎没有问题。 不知道为什么没办法准确的识别除了训练和测试集外的数据。
import torch
from torch import nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
# 定义Fashion-MNIST类别标签
def get_fashion_mnist_labels(labels):
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 1. 定义模型结构并加载参数
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
net = torch.load("softmax.pth")
net.eval() # 将模型设置为评估模式
# 2. 定义图像预处理
transform = transforms.Compose([
transforms.Grayscale(), # 转换为灰度图像
transforms.Resize((28, 28)), # 调整大小为 28x28
transforms.ToTensor() # 转换为 tensor 并且归一化到 [0, 1]
])
# 3. 加载并预处理图像
image_path = "1.png" # 替换为你的图片路径
image = Image.open(image_path)
image = transform(image) # 应用预处理
image = image.unsqueeze(0) # 添加批量维度,使尺寸为 (1, 1, 28, 28)
# 4. 将图像传入模型进行预测
with torch.no_grad(): # 禁用梯度计算,节省内存
logits = net(image) # 获取输出的 logits
pred_idx = torch.argmax(logits, dim=1).item() # 找到概率最大的类别
pred_label = get_fashion_mnist_labels([pred_idx])[0] # 获取类别标签
# 5. 输出结果
print(f"预测类别: {pred_label}")
# 6. 显示图像和预测标签
plt.imshow(image.squeeze().numpy(), cmap="gray")
plt.title(f"预测类别: {pred_label}")
plt.axis("off")
plt.show()
鞋子的测试也错了, 不知道是因为我的图片不像一个鞋子还是怎么了
(上图一共训练了20轮)
|