cc1989summer posted on 2024-10-4 00:20

Reading notes on "Dive into Deep Learning (PyTorch Edition)", Part 4: A first try at handwritten digit recognition (CNN-based)

This post was last edited by cc1989summer at 2024-10-4 00:29.

In the previous post we covered the concepts and principles of convolutional neural networks (CNNs):

https://bbs.eeworld.com.cn/thread-1295074-1-1.html

Below, we apply a convolutional neural network to the classic task of handwritten digit recognition.

The input is a 28*28-pixel grayscale image. After convolution, pooling, activation, and fully connected layers, the network outputs a probability for each of the ten digits 0-9.

Theory and practice are never quite the same thing, though, so let's run an example to deepen our understanding.
Before starting handwritten digit recognition with PyTorch, first install Python, Anaconda, and PyTorch, and set up a virtual environment.

For details, see:

https://bbs.eeworld.com.cn/thread-1294276-1-1.html
Next, install PyCharm (the free Community Edition):

https://www.jetbrains.com/pycharm/download/?section=windows

After installation, PyCharm needs to be linked to the virtual environment set up earlier.

The virtual environment configured in Anaconda:

(Screenshot: https://12.eewimg.cn/bbs/data/attachment/forum/202409/19/230136w55ppee38egpan3k.png)

The Python interpreter sits under that virtual environment's directory.

In PyCharm, link the project interpreter to the virtual environment set up earlier.

Once everything is configured, run a quick test:
import torch
print(torch.__version__)

The version prints as expected.

Next, install a few basic packages:

pip install torch torchvision matplotlib

If anything is missing, just run the corresponding install command in the PyCharm terminal.

Now let's get to the main topic: handwritten digit recognition.

We will need the MNIST dataset

The MNIST dataset comes from the US National Institute of Standards and Technology (NIST). The training set consists of digits handwritten by 250 different people, 50% of them high-school students and 50% staff of the Census Bureau. The test set contains handwritten digits in the same proportions, but the writers of the test set and the training set do not overlap.

MNIST contains 70,000 images in total: 60,000 for training and 10,000 for testing. Each image is a 28x28 handwritten digit from 0 to 9, rendered as white strokes on a black background: black pixels are 0 and the strokes are floating-point values between 0 and 1, with values closer to 1 appearing whiter.
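To sanity-check the shape and value range described above, here is a small sketch (my own addition, not from the book); it assumes the same ./data download directory used in the code below:

from torchvision import datasets, transforms

# Load the raw training set, converting each image to a tensor
mnist = datasets.MNIST(root='./data', train=True, download=True,
                       transform=transforms.ToTensor())
image, label = mnist[0]

print(image.shape)                              # torch.Size([1, 28, 28]): 1 channel, 28x28 pixels
print(image.min().item(), image.max().item())   # pixel values lie in [0, 1] after ToTensor()
print(label)                                    # the digit this image shows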
In practice the dataset is just a few files, and the example code below downloads them for us automatically.

Step 1: Load the MNIST dataset

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Preprocessing: convert images to tensors and normalize them.
# The arguments of Compose() were lost in the original post's formatting;
# ToTensor plus the commonly used MNIST mean/std is assumed here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download and load the MNIST training and test sets
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Wrap the datasets in data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Check the dataset sizes
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# Visualize a few samples
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
plt.figure(figsize=(10, 3))
for i in range(6):
    plt.subplot(1, 6, i + 1)
    plt.imshow(example_data[i][0], cmap='gray')
    plt.title(f"Label: {example_targets[i].item()}")
    plt.axis('off')
plt.show()
Here is the output.

You can visualize more samples the same way.

Step 2: Build the convolutional neural network
import torch.nn as nn
import torch.nn.functional as F

# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Convolution layer 1: 1 input channel (grayscale), 16 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        # Convolution layer 2: 16 input channels, 32 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        # Fully connected layer 1: input 32*5*5 (the flattened feature maps), output 128
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        # Fully connected layer 2: input 128, output 10 (one per class)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolution, then 2x2 max pooling, then ReLU
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten into a vector
        x = x.view(-1, 32 * 5 * 5)
        # Fully connected layer + ReLU
        x = F.relu(self.fc1(x))
        # Output layer
        x = self.fc2(x)
        return x

# Instantiate the model
model = CNN()
print(model)

Study the code closely and you will find exactly the pieces discussed earlier: convolution, pooling, activation, and fully connected layers. The short check below traces how the feature maps shrink and why fc1 expects 32 * 5 * 5 inputs.
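Here is that quick check (my own addition, not from the book): pass a dummy batch through the two convolution/pooling stages and print the intermediate shapes. ReLU is omitted because it does not change the shape.

import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 28, 28)          # a dummy batch with one 28x28 grayscale image
x = F.max_pool2d(model.conv1(x), 2)     # conv1: 28 -> 26, pooling: 26 -> 13
print(x.shape)                          # torch.Size([1, 16, 13, 13])
x = F.max_pool2d(model.conv2(x), 2)     # conv2: 13 -> 11, pooling: 11 -> 5
print(x.shape)                          # torch.Size([1, 32, 5, 5]), flattened to 32*5*5 = 800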
The output is as follows:

Step 3: Train the model
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Preprocessing: convert images to tensors and normalize them.
# (ToTensor plus the commonly used MNIST mean/std is assumed, as above.)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download and load the MNIST training and test sets
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Wrap the datasets in data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Convolution layer 1: 1 input channel (grayscale), 16 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        # Convolution layer 2: 16 input channels, 32 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        # Fully connected layer 1: input 32*5*5 (the flattened feature maps), output 128
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        # Fully connected layer 2: input 128, output 10 (one per class)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolution, then 2x2 max pooling, then ReLU
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten into a vector
        x = x.view(-1, 32 * 5 * 5)
        # Fully connected layer + ReLU
        x = F.relu(self.fc1(x))
        # Output layer
        x = self.fc2(x)
        return x

# Instantiate the model
model = CNN()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Train the model
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")
print("Training complete!")
Over the 5 training epochs the loss keeps decreasing, which shows the model is fitting the training data better and better. To measure accuracy directly, we can evaluate the model on the test set, as in the sketch below.
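The 98.93% accuracy quoted later is not computed by the post's code, so here is a minimal evaluation sketch (my own addition) using the model, test_loader, and device defined above:

# Evaluate accuracy on the test set
model.eval()                     # switch to evaluation mode
correct = 0
total = 0
with torch.no_grad():            # gradients are not needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)            # class with the highest score
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {100 * correct / total:.2f}%")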
Step 4: Test the model

In the testing stage, we can either draw a digit ourselves and have the network recognize it (a sketch of that workflow follows the results below), or take samples from the dataset, as in the following code, which draws a batch from the test set.

# Take one batch of samples from the test set
example_data, example_target = next(iter(test_loader))
example_data = example_data.to(device)

# Run the model on the batch
model.eval()
with torch.no_grad():
    output = model(example_data)

# Visualize the predictions for the first six samples
plt.figure(figsize=(10, 3))
for i in range(6):
    plt.subplot(1, 6, i + 1)
    plt.imshow(example_data[i][0].cpu(), cmap='gray')
    plt.title(f"Pred: {torch.argmax(output[i]).item()}")
    plt.axis('off')
plt.show()
The results are shown below. Every single prediction is correct!

The test accuracy is 98.93%.
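As mentioned above, you can also test with a digit you draw yourself. Here is a hedged sketch (my own addition; the file name digit.png is just an example) of how a hand-drawn image might be preprocessed to match MNIST and fed to the model:

from PIL import Image
import torchvision.transforms as T

# 'digit.png' is a hypothetical hand-drawn image.
# MNIST digits are white strokes on a black background, so a typical
# black-on-white drawing is inverted after conversion.
img = Image.open('digit.png').convert('L')            # grayscale
img = img.resize((28, 28))
x = T.ToTensor()(img)                                  # values in [0, 1]
x = 1.0 - x                                            # invert: dark strokes become bright
x = T.Normalize((0.1307,), (0.3081,))(x)               # same normalization as training
x = x.unsqueeze(0).to(device)                          # add the batch dimension

model.eval()
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()
print(f"Predicted digit: {pred}")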
If, however, we remove the training loop below, the accuracy collapses and almost every prediction is wrong: an untrained model has random weights and is essentially guessing among the ten classes.

# Train the model (this is the part removed in the experiment above)
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")
print("Training complete!")
That's all for this post.

Jacktang posted on 2024-10-6 09:29

So if we remove the training code, the accuracy immediately drops and almost everything is wrong. Fair enough.

cc1989summer posted on 2024-10-6 11:18

Jacktang posted on 2024-10-6 09:29
If we remove the training code, the accuracy immediately drops and almost everything is wrong. Fair enough.

I still haven't figured out how to save a trained model so it can be reused directly for later recognition, instead of retraining from scratch every time; training takes too long (>2 minutes).

MioChan posted on 2024-10-12 17:50

Save the whole model (serializes both the model and its weights; depends on the Python environment, less flexible):

torch.save(model, "model.pth")

Load the whole model:

model = torch.load("model.pth")

Save only the weights (the model is defined in code; you save and load the model and optimizer state dicts, which is more flexible):

save_dict = {}
save_dict['model_state_dict'] = model.state_dict()
save_dict['optimizer_state_dict'] = optimizer.state_dict()
torch.save(save_dict, "state_dict.pth")

Load the weight file:

checkpoint = torch.load("state_dict.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

xinmeng_wit posted on 2024-10-12 20:50

A trained model can be saved; next time just load it directly, no need to retrain.

cc1989summer posted on 2024-10-12 20:56

MioChan posted on 2024-10-12 17:50
Save the whole model (serializes both the model and its weights; depends on the Python environment, less flexible)

torch.save(model, "model.pth")

...

Thanks, a very detailed answer. I'll give it a try.

cc1989summer posted on 2024-10-12 20:56

xinmeng_wit posted on 2024-10-12 20:50
A trained model can be saved; next time just load it directly, no need to retrain.

OK, thanks for the explanation.

hellokitty_bean posted on 2024-10-14 09:19

Thanks for sharing!

cc1989summer posted on 2024-10-14 15:32

hellokitty_bean posted on 2024-10-14 09:19
Thanks for sharing!

Heh, let's all improve together.