59|0

157

帖子

3

TA的资源

一粒金砂(高级)

楼主
 

动手学深度学习(PyTorch版)- 【读书活动-心得分享】多层感知机的实现 [复制链接]

 

简介

 

虽然读书活动结束了,但是真的想把这一本书里的东西学完。 之后也会继续更新这本书的内容。本章节学习了多层感知机。 多层感知机相对于单层感知机可以处理XOR问题。 每一层都具有每一层的权重和偏差。 第一层的输入数据, 输入到隐藏层。然后隐藏层对输入层的特征进行提取(可以指定隐藏层的大小)。然后将隐层的的特征作为下一层的输入。 到最后来获取到分类的数据。

 

我这里按照了书里的感知机运行了代码。通过调整隐藏层的大小获取到了不同的模型的训练曲线。

 

 

我们可以在上图看到,我这里一共截图了size 分别为 10 , 64 , 512 和 1024的隐藏层大小。 分别对比了不同的模型的训练曲线。

 

我们可以在上图看到如果隐藏层的层数增加(也不是绝对,512 和 1024 基本上没区别), 模型的训练正确度和测试正确度都比较平滑。 差别比较大的可以对比64 和 10 或者 10 和512. 我们发现除了损失,模型的精确度其实差不多。 对于损失函数而言我们可以发现 512 和 1024的损失基本上没区别。 但是 10 , 64,512 对比差距就比较大了。 通过上述我们发现一个规律。即:如果隐藏层的大小超过了某一个阈值即当前层无法为图像区别更多的特征(比如说一个图片根据任何的条件进行特征分类最多为10个特征, 但是这里隐藏层的长度为20). 那么隐藏层大小10 和20 将会没有区别。 但是如果隐藏层大小为小于特征分类的最大值的话。 那么损失函数的变化将会有所不同。

 

 

import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs, num_outputs, num_hiddens = 784, 10, 1024

W1 = nn.Parameter(torch.randn(
    num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(
    num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)

def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法
    return (H@W2 + b2)

loss = nn.CrossEntropyLoss(reduction='none')

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

d2l.predict_ch3(net, test_iter)

 

点赞 关注
 
 

回复
举报
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表