1063|2

365

帖子

0

TA的资源

版主

楼主
 

一起读《深度强化学习实战》- 构建Q函数神经网络,计算Loss损失 [复制链接]

定义

  • 模型,需要抽象出游戏的状态向量、输入层、隐藏层、输出层(方便计算机表达的形式)。Gridworld游戏中,玩家p、目标+、坑-、墙w,这四个位置分别对应4x4的可能,4个维度再组合,则成为4x4x4大小的状态向量(64)。输出层则是给定状态下每个动作的Q值向量(移动的方向、上下左右)。
  • PyTorch的nn模块是一个专门为深度学习设计的模块,它提供了创建和训练神经网络的各种工具。其核心的数据结构是Module,它是一个抽象的概念,既可以表示神经网络中的某个层,也可以表示一个包含很多层的神经网络。此模块包含了各种神经网络层,例如卷积层、池化层、全连接层等等。同时,用户可以根据具体的问题对这些模块进行定制,从而充分发挥这些模块的功能。
  • MSE = Σ(yi - f(xi))^2 / n,均方差损失函数的值越小,代表模型的预测值越接近真实值,因此模型的鲁棒性就越好。此外,均方差损失函数还可以评价数据的变化程度,MSE的值越小,说明预测模型描述实验数据具有更好的精确度。

 

  • argmax函数用于返回数组中最大值的索引
  • PyTorch中的no_grad()是一个上下文管理器,用于在计算图中禁用梯度计算。当你不需要计算梯度时,可以使用它来提高性能
  • PyTorch中的detach()方法用于将张量从计算图中分离,使其不再参与梯度计算。这在训练过程中非常有用
  • PyTorch中的squeeze()方法用于移除张量中维度为1的尺寸。如果张量在某个维度上的大小为1,则该维度会被移除,并且结果张量的形状会相应地改变
  • 添加噪声,是因为ReLU函数在0处不可微。
  • 在PyTorch中,计算图是由节点和边组成的。节点表示参与运算的变量,比如张量或者Function。例如,在我们前面提到的线性回归模型中,输入数据x和权重w就可以被视为节点。这些节点在计算图中通过边来表示它们之间的依赖关系,例如,x和w之间的相乘操作可以由torch.mul()表示。特别需要注意的是,在神经网络中还有一种特殊的节点,称为叶子节点。这些节点是由用户创建的Variable变量,不需要再依赖其他变量进行计算。在进行网络优化时,反向传播算法会从输出开始,通过链式求导法则计算出每一个叶子节点对最终输出结果的梯度。此外,PyTorch的计算图是动态图。这意味着在前向传播过程中,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。一旦进行了反向传播或者计算了梯度,创建的计算图就会立即销毁,释放存储空间。因此,每次执行前向传播过程时,都需要重新构建计算图。这种动态性使得PyTorch能够更有效地处理复杂的计算任务和大规模的数据。
  • 反向传播的主要目标是最小化损失函数。在反向传播过程中,误差是从输出层向输入层反向传播的,通过链式法则计算损失函数对各参数的梯度(该参数对损失函数的贡献(即梯度)),并更新参数。

实践

计算loss损失图

import numpy as np
import torch
from Gridworld import Gridworld
# from IPython.display import clear_output
import random
from matplotlib import pylab as plt

l1 = 64
l2 = 150
l3 = 100
l4 = 4

model = torch.nn.Sequential(
    torch.nn.Linear(l1, l2),
    torch.nn.ReLU(),
    torch.nn.Linear(l2, l3),
    torch.nn.ReLU(),
    torch.nn.Linear(l3,l4)
)
loss_fn = torch.nn.MSELoss()

learning_rate = 1e-3
gamma = 0.9
epsilon = 1.0
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


action_set = {
    0: 'u',
    1: 'd',
    2: 'l',
    3: 'r',
}

epochs = 1000
losses = [] #A
for i in range(epochs): #B
    game = Gridworld(size=4, mode='static') #C
    state_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0 #D
    # https://github.com/DeepReinforcementLearning/DeepReinforcementLearningInAction/blob/master/Chapter%203/Ch3_book.ipynb
    # 如下给错了,应该是state1
    state1 = torch.from_numpy(state_).float() #E
    status = 1 #F
    while(status == 1): #G
        qval = model(state1) #H
        qval_ = qval.data.numpy()
        if (random.random() < epsilon): #I
            action_ = np.random.randint(0,4)
        else:
            action_ = np.argmax(qval_)
        
        action = action_set[action_] #J
        game.makeMove(action) #K
        state2_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0
        state2 = torch.from_numpy(state2_).float() #L
        reward = game.reward()
        with torch.no_grad():
            newQ = model(state2.reshape(1,64))
        maxQ = torch.max(newQ) #M
        if reward == -1: #N
            Y = reward + (gamma * maxQ)
        else:
            Y = reward
        Y = torch.Tensor([Y]).detach()
        X = qval.squeeze()[action_] #O
        loss = loss_fn(X, Y) #P
        print(i, loss.item())
        #不适用IPython,这个可以删除
        # clear_output(wait=True)
        optimizer.zero_grad()
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        state1 = state2
        if reward != -1: #Q
            status = 0
    if epsilon > 0.1: #R
        epsilon -= (1/epochs)

plt.figure(figsize=(10,7))
plt.plot(losses)
plt.xlabel("Epochs",fontsize=22)
plt.ylabel("Loss",fontsize=22)

plt.show()

 视频讲解


 

最新回复

感谢分享,期待深度后续   详情 回复 发表于 2023-11-23 20:53
点赞 关注

回复
举报

755

帖子

4

TA的资源

纯净的硅(高级)

沙发
 

感谢楼主提供的技术分享,先收藏学习再发表个人意见,顶起来

 
 

回复

7671

帖子

2

TA的资源

五彩晶圆(高级)

板凳
 

感谢分享,期待深度后续

 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复
您需要登录后才可以回帖 登录 | 注册

查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/7 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表