495|8

439

帖子

0

TA的资源

版主

楼主
 

一起读《动手学深度学习(PyTorch版)》- RNN-sequence-model [复制链接]

 

sin 1000加上噪声

  • import torch
  • from torch import nn
  • import matplotlib.pyplot as plt
  • T = 1000
  • time = torch.arange(1, T + 1, dtype=torch.float32)
  • x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
  • plt.plot(time.tolist(), x.tolist())
  • plt.show()

 

训练后,进行单步预测

  • import torch
  • from torch import nn
  • from torch.utils import data
  • from torchvision import transforms
  • import matplotlib.pyplot as plt
  • def load_array(data_arrays, batch_size, is_train=True):
  • dataset = data.TensorDataset(*data_arrays)
  • return data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=6)
  • T = 1000
  • time = torch.arange(1, T + 1, dtype=torch.float32)
  • x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
  • # plt.plot(time.tolist(), x.tolist())
  • # plt.show()
  • tau = 4
  • # [996, 4]
  • features = torch.zeros((T - tau, tau))
  • for i in range(tau):
  • # pick up 996 elements of x and then slide 1 element every time
  • features[:, i] = x[i: T - tau + i]
  • # print(features[:, i])
  • # print(features[:, i].shape)
  • labels = x[tau:].reshape((-1, 1))
  • batch_size, n_train = 16, 600
  • train_iter = load_array((features[:n_train], labels[:n_train]),
  • batch_size, is_train=True)
  • def init_weights(m):
  • if type(m) == nn.Linear:
  • nn.init.xavier_uniform_(m.weight)
  • def get_net():
  • net = nn.Sequential(nn.Linear(4, 10),
  • nn.ReLU(),
  • nn.Linear(10, 1))
  • net.apply(init_weights)
  • return net
  • class Accumulator:
  • def __init__(self, n) -> None:
  • self.data = [0.0]*n
  • def add(self, *args):
  • # args is a tupe
  • self.data = [a + float(b) for a, b in zip(self.data, args)]
  • def reset(self):
  • self.data = [0.0] * len(self.data)
  • def __getitem__(self, idx):
  • return self.data[idx]
  • def evaluate_loss(net, data_iter, loss):
  • metric = Accumulator(2)
  • for X, y in data_iter:
  • out = net(X)
  • y = y.reshape(out.shape)
  • l = loss(out, y)
  • metric.add(l.sum(), l.numel())
  • return metric[0] / metric[1]
  • loss = nn.MSELoss(reduction='none')
  • def train(net, train_iter, loss, epochs, lr):
  • trainer = torch.optim.Adam(net.parameters(), lr)
  • for epoch in range(epochs):
  • for X, y in train_iter:
  • trainer.zero_grad()
  • l = loss(net(X), y)
  • l.sum().backward()
  • trainer.step()
  • print(f'epoch {epoch + 1}, '
  • f'loss: {evaluate_loss(net, train_iter, loss):f}')
  • net = get_net()
  • train(net, train_iter, loss, 5, 0.01)
  • onestep_preds = net(features)
  • plt.plot(time.tolist(), x.tolist())
  • plt.plot(time[tau:].tolist(), onestep_preds.tolist())
  • plt.show()

 

查看本帖全部内容,请登录或者注册

最新回复

添加一个注释,reshape((-1,1)),-1表示自动计算行数,1列;同理,对于(1,-1)只给定行数1,也可以自动计算出新数组的列数。   详情 回复 发表于 2024-11-4 09:34
点赞 关注
 
 

回复
举报

1490

帖子

1

TA的资源

五彩晶圆(初级)

沙发
 

不清不楚。。。。。。。。。。。。。。。。。

RNN,其实也有变种。。。。。。。。

原理和方法都没

点评

这个就书里面的,其他例子等你分享,hhh  详情 回复 发表于 2024-11-2 18:16
 
 
 

回复

439

帖子

0

TA的资源

版主

板凳
 
hellokitty_bean 发表于 2024-11-2 13:51 不清不楚。。。。。。。。。。。。。。。。。 RNN,其实也有变种。。。。。。。。 原理和方法都没

这个就书里面的,其他例子等你分享,hhh

点评

还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。  详情 回复 发表于 2024-11-3 08:21
 
 
 

回复

6876

帖子

10

TA的资源

版主

4
 

训练后的文件是什么类型的?C能调用吗?

点评

可以保存成模型,然后一般推理引擎都有各种语言的实现,C只是一种  详情 回复 发表于 2024-11-3 11:28
个人签名

在爱好的道路上不断前进,在生活的迷雾中播撒光引

 
 
 

回复

1490

帖子

1

TA的资源

五彩晶圆(初级)

5
 
LitchiCheng 发表于 2024-11-2 18:16 这个就书里面的,其他例子等你分享,hhh

还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。

点评

kkk  详情 回复 发表于 2024-11-3 11:27
 
 
 

回复

439

帖子

0

TA的资源

版主

6
 
hellokitty_bean 发表于 2024-11-3 08:21 还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。

kkk

 
 
 

回复

439

帖子

0

TA的资源

版主

7
 
秦天qintian0303 发表于 2024-11-2 23:18 训练后的文件是什么类型的?C能调用吗?

可以保存成模型,然后一般推理引擎都有各种语言的实现,C只是一种

 
 
 

回复

260

帖子

2

TA的资源

纯净的硅(初级)

8
 

作为初学者,这里,添加一个注释说明:x=(T,)表示一维数组,y=(T,1)表示二维数组,x==y.flatten()的值为True。

个人签名

波光潋滟.~

 
 
 

回复

260

帖子

2

TA的资源

纯净的硅(初级)

9
 

添加一个注释,reshape((-1,1)),-1表示自动计算行数,1列;同理,对于(1,-1)只给定行数1,也可以自动计算出新数组的列数。

个人签名

波光潋滟.~

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

开源项目 更多>>
    随便看看
    查找数据手册?

    EEWorld Datasheet 技术支持

    相关文章 更多>>
      关闭
      站长推荐上一条 1/10 下一条
      中星联华&ADI明日直播
      直播主题:大咖面对面,轻松玩转高速ADC性能测试
      直播时间:3月25日(周二)14:00
      活动奖励:京东卡、双肩包

      查看 »

       
      EEWorld订阅号

       
      EEWorld服务号

       
      汽车开发圈

       
      机器人开发圈

      About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

      站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网 12

      北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

      电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
      快速回复 返回顶部 返回列表