动手学深度学习(PyTorch版)- 【读书活动-心得分享】使用Pytorch实现线性回归
本帖最后由 御坂10032号 于 2024-10-30 22:42 编辑<p><strong><span style="font-size:24px;">简介</span></strong></p>
<p> </p>
<p>在上一个章节中我们手动的来实现了线形回归的训练和定义. 在本章节我们将借助PYtorch的API来实现相关的操作.</p>
<p> </p>
<hr />
<p> </p>
<p>今天看昨天的代码又有了一点新的理解, 不得不感叹写书人的水平还是蛮高的.</p>
<p> </p>
<p> 主要是体现在上述的代码上当我们随机生成训练数据的时候, 这个数据集的X(张量输入) 可以为多维, 训练数据集将会根据传递进去的 true_w 的权重值来创建多个x作为输入. 比如上述代码中的</p>
<p>true_w 的长度为3, 那么调用创建synthetic_data 后那么创建出来的features为长度为3的张量列表 </p>
<p> </p>
<p> </p>
<p>之后使用torch提供的数据加载器来加载训练数据到torch里. 并且定义了小批量的数据集.</p>
<p> </p>
<p>109 行这里代码从torch中导入了神经网络的包.</p>
<p> </p>
<p>110行,创建了一个输入输出为3(即X向量的长度) 和输出 为1 (即Y)的神经网络</p>
<p> </p>
<p>111行来初始化了神经网络的第一层的权重和偏差.</p>
<p> </p>
<p>112之后便定义了模型中的损失函数的计算方式即使用 <strong>nn.BCELoss()</strong></p>
<p> </p>
<p>113 行则是定义了模型的优化算法以及学习率等. 这里的优化算法采用的是 <strong>随机梯度下降</strong>(Stochastic Gradient Descent)</p>
<p> </p>
<p>之后便对模型进行训练</p>
<p> </p>
<pre>
<code class="language-python">num_epochs = 10
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X), y)
# 清空梯度
trainer.zero_grad()
# 反向传播
l.backward()
# 更新模型参数
trainer.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')</code></pre>
<p> </p>
<p>一共训练了十轮.</p>
<p> </p>
<p></p>
<p> </p>
<p>到最后我们将完整的模型保存下来. 可以新建一个其他的文件 用于加载这个模型进行预测.</p>
<p> </p>
<p> </p>
<pre>
<code class="language-python">import numpy as np
import torch
from torch import nn
from torch.utils import data
from d2l import torch as d2l
# 加载模型参数
model = torch.load('model_complete.pth')# 直接加载整个模型
model.eval()
# 使用已加载的模型进行预测
# 示例输入特征
X_new = torch.tensor([[-0.2552, -1.1491, -2.3036]])
y_pred = loaded_net(X_new)
print(f'预测结果: {y_pred.item()}')</code></pre>
<p> </p>
<p>大家可以实际对比一下预测是值和原本的数据. 如下所示</p>
<p> </p>
<p> </p>
<p> </p>
<p>误差还是非常小的.</p>
<p> </p>
<p> </p>
<p> </p>
<p>一千个数据训练出来的完成模型也非常小, 才2 KB, 接下来要尝试一下怎么把这个模型搞到micro python里了</p>
<p> </p>
<p> </p>
<p> </p>
<div></div>
<p> </p>
线形回归?不是线性回归? zhoupxa 发表于 2024-10-30 22:39
线形回归?不是线性回归?
<p>修改了, 谢谢指正</p>
<p></p>
<p>print(f'epoch {epoch + 1}, loss {l:f}')</p>
<p></p>
<p>不懂就问,有人知道{l:f}的意思吗?l是变量,这里:f起的是什么作用?</p>
ljg2np 发表于 2024-10-30 23:12
print(f'epoch {epoch + 1}, loss {l:f}')
不懂就问,有人知道{l:f}的意思 ...
<p>格式化浮点数输出的, 跟java的stringformat或者C的 %d %f %h 差不多</p>
页:
[1]