420|8

365

帖子

0

TA的资源

版主

楼主
 

一起读《动手学深度学习(PyTorch版)》- RNN-sequence-model [复制链接]

 

sin 1000加上噪声

import torch
from torch import nn
import matplotlib.pyplot as plt

T = 1000 
time = torch.arange(1, T + 1, dtype=torch.float32)
x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
plt.plot(time.tolist(), x.tolist())
plt.show()

 

训练后,进行单步预测

import torch
from torch import nn
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt

def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=6)

T = 1000
time = torch.arange(1, T + 1, dtype=torch.float32)
x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
# plt.plot(time.tolist(), x.tolist())
# plt.show()

tau = 4
# [996, 4]
features = torch.zeros((T - tau, tau))
for i in range(tau):
    # pick up 996 elements of x and then slide 1 element every time
    features[:, i] = x[i: T - tau + i]
    # print(features[:, i])
    # print(features[:, i].shape)
labels = x[tau:].reshape((-1, 1))

batch_size, n_train = 16, 600
train_iter = load_array((features[:n_train], labels[:n_train]),
                            batch_size, is_train=True)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

def get_net():
    net = nn.Sequential(nn.Linear(4, 10),
                        nn.ReLU(),
                        nn.Linear(10, 1))
    net.apply(init_weights)
    return net

class Accumulator:
    def __init__(self, n) -> None:
        self.data = [0.0]*n
    
    def add(self, *args):
        # args is a tupe
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def evaluate_loss(net, data_iter, loss):
    metric = Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

loss = nn.MSELoss(reduction='none')

def train(net, train_iter, loss, epochs, lr):
    trainer = torch.optim.Adam(net.parameters(), lr)
    for epoch in range(epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.sum().backward()
            trainer.step()
        print(f'epoch {epoch + 1}, '
              f'loss: {evaluate_loss(net, train_iter, loss):f}')

net = get_net()
train(net, train_iter, loss, 5, 0.01)

onestep_preds = net(features)
plt.plot(time.tolist(), x.tolist())
plt.plot(time[tau:].tolist(), onestep_preds.tolist())
plt.show()

 

最新回复

添加一个注释,reshape((-1,1)),-1表示自动计算行数,1列;同理,对于(1,-1)只给定行数1,也可以自动计算出新数组的列数。   详情 回复 发表于 2024-11-4 09:34
点赞 关注
 
 

回复
举报

1364

帖子

1

TA的资源

五彩晶圆(初级)

沙发
 

不清不楚。。。。。。。。。。。。。。。。。

RNN,其实也有变种。。。。。。。。

原理和方法都没

点评

这个就书里面的,其他例子等你分享,hhh  详情 回复 发表于 2024-11-2 18:16
 
 
 

回复

365

帖子

0

TA的资源

版主

板凳
 
hellokitty_bean 发表于 2024-11-2 13:51 不清不楚。。。。。。。。。。。。。。。。。 RNN,其实也有变种。。。。。。。。 原理和方法都没

这个就书里面的,其他例子等你分享,hhh

点评

还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。  详情 回复 发表于 2024-11-3 08:21
 
 
 

回复

6450

帖子

10

TA的资源

版主

4
 

训练后的文件是什么类型的?C能调用吗?

点评

可以保存成模型,然后一般推理引擎都有各种语言的实现,C只是一种  详情 回复 发表于 2024-11-3 11:28
个人签名

在爱好的道路上不断前进,在生活的迷雾中播撒光引

 
 
 

回复

1364

帖子

1

TA的资源

五彩晶圆(初级)

5
 
LitchiCheng 发表于 2024-11-2 18:16 这个就书里面的,其他例子等你分享,hhh

还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。

点评

kkk  详情 回复 发表于 2024-11-3 11:27
 
 
 

回复

365

帖子

0

TA的资源

版主

6
 
hellokitty_bean 发表于 2024-11-3 08:21 还没开始仔细看呢。。。。。。。。等闲下来再看后分享了。。。。。。。。

kkk

 
 
 

回复

365

帖子

0

TA的资源

版主

7
 
秦天qintian0303 发表于 2024-11-2 23:18 训练后的文件是什么类型的?C能调用吗?

可以保存成模型,然后一般推理引擎都有各种语言的实现,C只是一种

 
 
 

回复

223

帖子

2

TA的资源

一粒金砂(高级)

8
 

作为初学者,这里,添加一个注释说明:x=(T,)表示一维数组,y=(T,1)表示二维数组,x==y.flatten()的值为True。

个人签名

.~

 
 
 

回复

223

帖子

2

TA的资源

一粒金砂(高级)

9
 

添加一个注释,reshape((-1,1)),-1表示自动计算行数,1列;同理,对于(1,-1)只给定行数1,也可以自动计算出新数组的列数。

个人签名

.~

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/9 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表