动手学深度学习(PyTorch版)- 【读书活动-心得分享】使用Softmax实现图像分类
<p><strong><span style="font-size:22px;">简介</span></strong></p><p> </p>
<p>前几个章节我们实现了线性回归来对某一个数据值进行预测,本章节我们学习下如何使用softmax进行数据的分类。</p>
<p> </p>
<p>这个章节的数学知识有点难, 我不是很懂具体的数学是怎么计算的, 但是大概要实现什么效果我还是可以理解的。 我就我自己的理解进行代码的分析。</p>
<p> </p>
<pre>
<code class="language-python">trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)</code></pre>
<p>上述代码定义了一个数据集,即MNIST, 当执行上述的代码的时候会自动从<a href="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com" target="_blank">http://fashion-mnist.s3-website.eu-central-1.amazonaws.com</a>/ 上下载对应的训练数据和测试数据</p>
<p> </p>
<p>下载之后的训练数据集和测试数据集分别为</p>
<p> </p>
<p> </p>
<p>有图这个图像的数据只是为28 * 28 单通道。 所以每个图片的形状如下所示</p>
<p> </p>
<p>之后定义了一个获取图像内所有的标签和显示图片的方法。</p>
<p> </p>
<p> </p>
<p>图像的标签数据其实在下载的时候就已经包含了。 这里不理解的是为什么这个Tshirt的和下载的时候的类型不一致</p>
<p> </p>
<p> </p>
<p> </p>
<p>图像绘制的方法我这里也没有纠结, 因为看不懂是怎么绘制的。</p>
<p> </p>
<p> </p>
<p>之后我们可以加载小批量的测试数据进行数据的加载和展示。<span style="color:#e74c3c;"><strong>如果你使用的是pycharm 或者 dataspell, 记得勾选上述的文件为信任的 , 否则会出现图片无法绘制的异常</strong></span></p>
<p> </p>
<p> </p>
<p>上图为信任模式下的正常绘制。 一共读取10 个数据, 两行 5 列。</p>
<p> </p>
<p></p>
<p> </p>
<p>后面的交叉熵和模型就开始有点听不懂了,于是对从0实现的softmax就没办法继续了。 之后我尝试了书里提供的使用pytorch实现的 softmax实现。</p>
<p> </p>
<p> </p>
<p> </p>
<p>这里如果你使用的是pycharm 或者dataspell的话也会出现图片不加载的情况。 可以直接使用Jupyter来打开。</p>
<p> </p>
<p>这里就很好理解了。 我只需要知道怎么使用torch的API即可, 我不需要去关注具体的算法实现。 然后通过torch的内部定义的损失函数等等来训练我的模型。</p>
<p> </p>
<p>不过这里, 我全都是用CPU跑的, 我还尝试问GPT怎么转移到GPU上。但是没有成功。下图可以看到CPU的使用情况。E52666V3 十个核最高占用能跑到80%</p>
<p> </p>
<p> </p>
<p> </p>
<p>通过下面的代码保存整个模型后,下次就不需要再训练了。</p>
<pre>
<code class="language-python"># 保存整个模型
torch.save(net, "softmax.pth")</code></pre>
<p> </p>
<p>模型大小大概33KB。</p>
<p> </p>
<p>我们尝试加载这个模型然后随便找个图形来试一下是否能正常是被。</p>
<p> </p>
<p> </p>
<p>上图为一个chatgpt生成的28*28的图片。</p>
<p> </p>
<p> </p>
<p>它把我的T恤分类成包了。</p>
<p> </p>
<p> </p>
<p> </p>
<p>加载训练数据的似乎没有问题。 不知道为什么没办法准确的识别除了训练和测试集外的数据。</p>
<p> </p>
<p> </p>
<pre>
<code class="language-python">import torch
from torch import nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
# 定义Fashion-MNIST类别标签
def get_fashion_mnist_labels(labels):
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return for i in labels]
# 1. 定义模型结构并加载参数
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
net = torch.load("softmax.pth")
net.eval()# 将模型设置为评估模式
# 2. 定义图像预处理
transform = transforms.Compose([
transforms.Grayscale(),# 转换为灰度图像
transforms.Resize((28, 28)),# 调整大小为 28x28
transforms.ToTensor()# 转换为 tensor 并且归一化到
])
# 3. 加载并预处理图像
image_path = "1.png"# 替换为你的图片路径
image = Image.open(image_path)
image = transform(image)# 应用预处理
image = image.unsqueeze(0)# 添加批量维度,使尺寸为 (1, 1, 28, 28)
# 4. 将图像传入模型进行预测
with torch.no_grad():# 禁用梯度计算,节省内存
logits = net(image)# 获取输出的 logits
pred_idx = torch.argmax(logits, dim=1).item()# 找到概率最大的类别
pred_label = get_fashion_mnist_labels()# 获取类别标签
# 5. 输出结果
print(f"预测类别: {pred_label}")
# 6. 显示图像和预测标签
plt.imshow(image.squeeze().numpy(), cmap="gray")
plt.title(f"预测类别: {pred_label}")
plt.axis("off")
plt.show()
</code></pre>
<p> </p>
<p>鞋子的测试也错了, 不知道是因为我的图片不像一个鞋子还是怎么了<br />
<br />
(上图一共训练了20轮)</p>
<p> </p>
<div></div>
<p>知道怎么使用torch的API,通过torch的内部定义的损失函数等等来训练模型,思路清晰,学习了</p>
Jacktang 发表于 2024-11-6 07:29
知道怎么使用torch的API,通过torch的内部定义的损失函数等等来训练模型,思路清晰,学习了
<p>书中的原生实现的太难了. 看不懂</p>
页:
[1]