407|2

194

帖子

4

TA的资源

纯净的硅(初级)

楼主
 

动手学深度学习(PyTorch版)- 【读书活动-心得分享】使用Softmax实现图像分类 [复制链接]

 

简介

 

前几个章节我们实现了线性回归来对某一个数据值进行预测,本章节我们学习下如何使用softmax进行数据的分类。

 

这个章节的数学知识有点难, 我不是很懂具体的数学是怎么计算的, 但是大概要实现什么效果我还是可以理解的。 我就我自己的理解进行代码的分析。

 

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

上述代码定义了一个数据集,即MNIST, 当执行上述的代码的时候会自动从http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/ 上下载对应的训练数据和测试数据

 

下载之后的训练数据集和测试数据集分别为

 

 

有图这个图像的数据只是为28 * 28 单通道。 所以每个图片的形状如下所示

 

之后定义了一个获取图像内所有的标签和显示图片的方法。

 

 

图像的标签数据其实在下载的时候就已经包含了。 这里不理解的是为什么这个Tshirt的和下载的时候的类型不一致

 

 

 

图像绘制的方法我这里也没有纠结, 因为看不懂是怎么绘制的。

 

 

之后我们可以加载小批量的测试数据进行数据的加载和展示。如果你使用的是pycharm 或者 dataspell, 记得勾选上述的文件为信任的 , 否则会出现图片无法绘制的异常

 

 

上图为信任模式下的正常绘制。 一共读取10 个数据, 两行 5 列。

 

 

后面的交叉熵和模型就开始有点听不懂了,于是对从0实现的softmax就没办法继续了。 之后我尝试了书里提供的使用pytorch实现的 softmax实现。

 

 

 

这里如果你使用的是pycharm 或者dataspell的话也会出现图片不加载的情况。 可以直接使用Jupyter来打开。

 

这里就很好理解了。 我只需要知道怎么使用torch的API即可, 我不需要去关注具体的算法实现。 然后通过torch的内部定义的损失函数等等来训练我的模型。

 

不过这里, 我全都是用CPU跑的, 我还尝试问GPT怎么转移到GPU上。但是没有成功。下图可以看到CPU的使用情况。E52666V3 十个核最高占用能跑到80%

 

 

 

通过下面的代码保存整个模型后,下次就不需要再训练了。

# 保存整个模型
torch.save(net, "softmax.pth")

 

 

模型大小大概33KB。

 

我们尝试加载这个模型然后随便找个图形来试一下是否能正常是被。

 

 

上图为一个chatgpt生成的28*28的图片。

 

 

它把我的T恤分类成包了。

 

 

 

加载训练数据的似乎没有问题。 不知道为什么没办法准确的识别除了训练和测试集外的数据。

 

 

import torch
from torch import nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# 定义Fashion-MNIST类别标签
def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 1. 定义模型结构并加载参数
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
net = torch.load("softmax.pth")
net.eval()  # 将模型设置为评估模式

# 2. 定义图像预处理
transform = transforms.Compose([
    transforms.Grayscale(),  # 转换为灰度图像
    transforms.Resize((28, 28)),  # 调整大小为 28x28
    transforms.ToTensor()  # 转换为 tensor 并且归一化到 [0, 1]
])

# 3. 加载并预处理图像
image_path = "1.png"  # 替换为你的图片路径
image = Image.open(image_path)
image = transform(image)  # 应用预处理
image = image.unsqueeze(0)  # 添加批量维度,使尺寸为 (1, 1, 28, 28)

# 4. 将图像传入模型进行预测
with torch.no_grad():  # 禁用梯度计算,节省内存
    logits = net(image)  # 获取输出的 logits
    pred_idx = torch.argmax(logits, dim=1).item()  # 找到概率最大的类别
    pred_label = get_fashion_mnist_labels([pred_idx])[0]  # 获取类别标签

# 5. 输出结果
print(f"预测类别: {pred_label}")

# 6. 显示图像和预测标签
plt.imshow(image.squeeze().numpy(), cmap="gray")
plt.title(f"预测类别: {pred_label}")
plt.axis("off")
plt.show()

 

鞋子的测试也错了, 不知道是因为我的图片不像一个鞋子还是怎么了

  (上图一共训练了20轮)

 

softmax.pth (32.35 KB, 下载次数: 2)

最新回复

知道怎么使用torch的API,通过torch的内部定义的损失函数等等来训练模型,思路清晰,学习了   详情 回复 发表于 2024-11-6 07:29
点赞 关注
 
 

回复
举报

6828

帖子

0

TA的资源

五彩晶圆(高级)

沙发
 

知道怎么使用torch的API,通过torch的内部定义的损失函数等等来训练模型,思路清晰,学习了

点评

书中的原生实现的太难了. 看不懂  详情 回复 发表于 2024-11-6 12:20
 
 
 

回复

194

帖子

4

TA的资源

纯净的硅(初级)

板凳
 
Jacktang 发表于 2024-11-6 07:29 知道怎么使用torch的API,通过torch的内部定义的损失函数等等来训练模型,思路清晰,学习了

书中的原生实现的太难了. 看不懂

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/7 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表