294|4

163

帖子

3

TA的资源

一粒金砂(高级)

楼主
 

动手学深度学习(PyTorch版)- 【读书活动-心得分享】使用Pytorch实现线性回归 [复制链接]

  本帖最后由 御坂10032号 于 2024-10-30 22:42 编辑

简介

 

在上一个章节中我们手动的来实现了线形回归的训练和定义. 在本章节我们将借助PYtorch的API来实现相关的操作.

 


 

今天看昨天的代码又有了一点新的理解, 不得不感叹写书人的水平还是蛮高的.

 

 主要是体现在上述的代码上当我们随机生成训练数据的时候, 这个数据集的X(张量输入) 可以为多维, 训练数据集将会根据传递进去的 true_w 的权重值来创建多个x作为输入. 比如上述代码中的

true_w 的长度为3, 那么调用创建synthetic_data 后那么创建出来的features为长度为3的张量列表 

 

 

之后使用torch提供的数据加载器来加载训练数据到torch里. 并且定义了小批量的数据集.

 

109 行这里代码从torch中导入了神经网络的包.

 

110行,创建了一个输入输出为3(即X向量的长度) 和输出 为1 (即Y)的神经网络

 

111行来初始化了神经网络的第一层的权重和偏差.

 

112之后便定义了模型中的损失函数的计算方式即使用 nn.BCELoss()

 

113 行则是定义了模型的优化算法以及学习率等. 这里的优化算法采用的是 随机梯度下降(Stochastic Gradient Descent)

 

之后便对模型进行训练

 

num_epochs = 10
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        # 清空梯度
        trainer.zero_grad()
        # 反向传播
        l.backward()
        # 更新模型参数
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

 

一共训练了十轮.

 

 

到最后我们将完整的模型保存下来. 可以新建一个其他的文件 用于加载这个模型进行预测.

 

 

import numpy as np
import torch
from torch import nn
from torch.utils import data
from d2l import torch as d2l

# 加载模型参数
model = torch.load('model_complete.pth')  # 直接加载整个模型
model.eval()

# 使用已加载的模型进行预测
# 示例输入特征
X_new = torch.tensor([[-0.2552, -1.1491, -2.3036]])
y_pred = loaded_net(X_new)
print(f'预测结果: {y_pred.item()}')

 

大家可以实际对比一下预测是值和原本的数据. 如下所示

 

 

 

误差还是非常小的.

 

 

 

一千个数据训练出来的完成模型也非常小, 才2 KB, 接下来要尝试一下怎么把这个模型搞到micro python里了

 

 

 

02-simple-way.ipynb.zip (2.51 KB, 下载次数: 3)

 

最新回复

print(f'epoch {epoch + 1}, loss {l:f}') 不懂就问,有人知道{l:f}的意思吗?l是变量,这里:f起的是什么作用?   详情 回复 发表于 2024-10-30 23:12
点赞 关注(1)
 
 

回复
举报

278

帖子

0

TA的资源

一粒金砂(高级)

沙发
 
线形回归?不是线性回归?

点评

修改了, 谢谢指正  详情 回复 发表于 2024-10-30 22:45
 
 
 

回复

163

帖子

3

TA的资源

一粒金砂(高级)

板凳
 
zhoupxa 发表于 2024-10-30 22:39 线形回归?不是线性回归?

修改了, 谢谢指正

 
 
 

回复

149

帖子

2

TA的资源

一粒金砂(高级)

4
 

print(f'epoch {epoch + 1}, loss {l:f}')

不懂就问,有人知道{l:f}的意思吗?l是变量,这里:f起的是什么作用?

点评

格式化浮点数输出的, 跟java的stringformat或者C的 %d %f %h 差不多  详情 回复 发表于 2024-10-30 23:27
 
 
 

回复

163

帖子

3

TA的资源

一粒金砂(高级)

5
 
ljg2np 发表于 2024-10-30 23:12 print(f'epoch {epoch + 1}, loss {l:f}') 不懂就问,有人知道{l:f}的意思 ...

格式化浮点数输出的, 跟java的stringformat或者C的 %d %f %h 差不多

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表