5375|22

553

帖子

3

TA的资源

纯净的硅(初级)

楼主
 

《机器学习算法与实现》7、神经网络训练 [复制链接]

 

 一、反向传播算法

上一篇中已经知道了神经网络的正向计算,那本篇将学习神经网络训练中非常重要的反向传播。

神经网络的每个连接上的权值如果知道,那么就可以将输入数据代入得到希望的结果。

神经网络是一个模型,那么这些权值就是模型的参数,也就是模型要学习的东西。

反向传播算法其实就是链式求导法则的应用。

按照机器学习的通用求解思路,我们先确定神经网络的目标函数,然后用随机梯度下降优化算法去求目标函数最小值时的参数值。

本篇会涉及到非常多的公式推导,比较枯燥无趣,在这里也就不照搬公式了,有兴趣的同学可以去学习反向传播的数学推导过程。

这里直接给出反向传播的关键的理论公式。

 

1)目标函数

其中, Ed 表示是样本 d 的误差, t是样本的标签值y是神经网络的输出值

目标函数是用来求梯度,然后根据梯度来更新模型参数的基础。

使用随机梯度下降算法对目标函数进行优化:

最后,推导出的权重w的最终的公式如下图:

其中, wji 是节点 i 到节点 j 的权重, η 是一个成为学习速率的常数, δj 是节点 j 的误差项, xji 是节点 i 传递给节点 j 的输入。

对于输出层来说,

其中, δi 是节点 i 的误差项, yi 是节点 i 的输出值, ti 是样本对应于节点 i 的目标值。

对于隐藏层来说:

其中, ai 是节点 i 的输出值, wki 是节点 i 到它的下一层节点 k 的连接的权重, δk 是节点 i 的下一层节点 k 的误差项。

 

由公式4可知,计算一个节点的误差项,需要先计算每个与其相连的下一层节点的误差项,这就要求误差项的计算顺序必须是从输出层开始,然后反向依次计算每个隐藏层的误差项,直到与输入层相连的那个隐藏层,这就是反向传播算法的名字的含义。

 

 二、神经网络算法实现

 

1、数据准备

import numpy as np
from sklearn import datasets, linear_model
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

# generate sample data
np.random.seed(0)
x, y = datasets.make_moons(200, noise=0.20)

y_true = np.array(y).astype(float)


# generate nn output target
t = np.zeros((x.shape[0], 2))
print(x.shape[0])
t[np.where(y==0), 0] = 1
t[np.where(y==1), 1] = 1


# plot data
plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.Spectral)
plt.show()

2、神经网络训练程序

首先生成神经网络模型,初始化权重数据,接着定义sigmoid和网络正向计算函数,具体如下:

# generate the NN model
class NN_Model:
    epsilon = 0.01               # learning rate
    n_epoch = 1000               # iterative number
    
nn = NN_Model()
nn.n_input_dim = x.shape[1]      # input size
nn.n_hide_dim = 8                # hidden node size
nn.n_output_dim = 2              # output node size

# initial weight array
nn.W1 = np.random.randn(nn.n_input_dim, nn.n_hide_dim) / np.sqrt(nn.n_input_dim)
nn.b1 = np.zeros((1, nn.n_hide_dim))
nn.W2 = np.random.randn(nn.n_hide_dim, nn.n_output_dim) / np.sqrt(nn.n_hide_dim)
nn.b2 = np.zeros((1, nn.n_output_dim))

# define sigmod & its derivate function
def sigmod(x):
    return 1.0/(1+np.exp(-x))

# network forward calculation       
def forward(n, x):
    n.z1 = sigmod(x.dot(n.W1) + n.b1)
    n.z2 = sigmod(n.z1.dot(n.W2) + n.b2)
    return n


# 使用随机权重进行训练
forward(nn, x)
y_pred = np.argmax(nn.z2, axis=1)

# plot data
plt.scatter(x[:, 0], x[:, 1], c=y_pred, cmap=plt.cm.Spectral)
plt.show()

程序中使用了未经训练的随机权重进行了预测,结果是所有的数据的分类结果都是同一类,接下来使用反向传播来进行训练,然后再看预测结果:

# 反向传播
def backpropagation(n, x, t):
    for i in range(n.n_epoch):
        # 正向计算每个节点的输出
        forward(n, x)
        
        # print loss, accuracy
        L = np.sum((n.z2 - t)**2)
        
        y_pred = np.argmax(nn.z2, axis=1)
        acc = accuracy_score(y_true, y_pred)
        
        if i % 100 == 0:
            print("epoch [%4d] L = %f, acc = %f" % (i, L, acc))
        
        # 计算误差
        d2 = n.z2*(1-n.z2)*(t - n.z2)
        d1 = n.z1*(1-n.z1)*(np.dot(d2, n.W2.T))
        
        # 更新权重
        n.W2 += n.epsilon * np.dot(n.z1.T, d2)
        n.b2 += n.epsilon * np.sum(d2, axis=0)
        n.W1 += n.epsilon * np.dot(x.T, d1)
        n.b1 += n.epsilon * np.sum(d1, axis=0)

nn.n_epoch = 2000
backpropagation(nn, x, t)


# plot data
y_pred = np.argmax(nn.z2, axis=1)

plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.Spectral)
plt.title("ground truth")
plt.show()

plt.scatter(x[:, 0], x[:, 1], c=y_pred, cmap=plt.cm.Spectral)
plt.title("predicted")
plt.show()

真实分类,和预测的结果如下:

预测精度:
93.5%

 

最新回复

比如c语言或者基于硬件的呢,这些我不太懂,就是感觉应该会比python快点吧   详情 回复 发表于 2024-8-29 10:31
点赞(1) 关注(1)

回复
举报

755

帖子

5

TA的资源

纯净的硅(高级)

沙发
 

感谢楼主分享的神经网络训练技术知识,希望在楼主的技术内容帮助下能自己进行神经网络的训练

点评

加油加油  详情 回复 发表于 2024-7-29 18:05
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

板凳
 

继续过来点赞送花。。。。。。。。。。。。Wit加油哟。。。。。。。。

点评

感谢hellokitty_bean大佬的持续关注和支持  详情 回复 发表于 2024-7-29 18:06
 
 
 

回复

7244

帖子

2

TA的资源

版主

4
 

什么样的数据适合用这个算法?

点评

我觉得这个是一个比较通用的神经网络解决一些分类问题应该是没什么问题的。  详情 回复 发表于 2024-7-29 18:09
 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

5
 
chejm 发表于 2024-7-28 11:40 感谢楼主分享的神经网络训练技术知识,希望在楼主的技术内容帮助下能自己进行神经网络的训练

加油加油

 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

6
 
hellokitty_bean 发表于 2024-7-28 14:25 继续过来点赞送花。。。。。。。。。。。。Wit加油哟。。。。。。。。

感谢hellokitty_bean大佬的持续关注和支持

 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

7
 
wangerxian 发表于 2024-7-29 09:16 什么样的数据适合用这个算法?

我觉得这个是一个比较通用的神经网络解决一些分类问题应该是没什么问题的。

点评

Wit,你的这个BP,epoch设置2000,那个打印是在epoch=1900时的输出, 应该打印所有epoch,看最后的2000时的acc才对吧?  详情 回复 发表于 2024-8-18 09:30
 
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

8
 
xinmeng_wit 发表于 2024-7-29 18:09 我觉得这个是一个比较通用的神经网络解决一些分类问题应该是没什么问题的。

Wit,你的这个BP,epoch设置2000,那个打印是在epoch=1900时的输出,

应该打印所有epoch,看最后的2000时的acc才对吧?

点评

对的,epoch是2000的时候是0~1999,最后一次没有打印出来,是有点问题。  详情 回复 发表于 2024-8-20 20:27
 
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

9
 

把 nn_n_epoch=2001,程序跑到epoch[2000],acc=0.945000

所以预测的acc(准确率)其实应该是94.5%

 

 

点评

是的,没错,  详情 回复 发表于 2024-8-20 20:28
 
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

10
 

发现,把epoch提高到4001时,迭代到3200准确率就到了96.5%,后面一直保持在96.5%了

也就是说,再怎么训练数据,其实没有提升了。

点评

epoch并不是越大越好,太大了可能会出现过拟合,太小了可能出现欠拟合。 当神经网络较复杂,但是训练数据较少的情况下,epoch太大容易出现过拟合。 反之,容易出现欠拟合  详情 回复 发表于 2024-8-20 20:33
 
 
 

回复

7671

帖子

2

TA的资源

五彩晶圆(高级)

11
 

谢谢分享,期待后续!

点评

感谢关注  详情 回复 发表于 2024-8-20 20:34
 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

12
 
hellokitty_bean 发表于 2024-8-18 09:30 Wit,你的这个BP,epoch设置2000,那个打印是在epoch=1900时的输出, 应该打印所有epoch,看最后的200 ...

对的,epoch是2000的时候是0~1999,最后一次没有打印出来,是有点问题。

 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

13
 
hellokitty_bean 发表于 2024-8-18 09:35 把 nn_n_epoch=2001,程序跑到epoch[2000],acc=0.945000 所以预测的acc(准确率)其实应该是94.5% & ...

是的,没错,

 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

14
 
hellokitty_bean 发表于 2024-8-18 09:38 发现,把epoch提高到4001时,迭代到3200准确率就到了96.5%,后面一直保持在96.5%了 也就是说,再怎么训 ...

epoch并不是越大越好,太大了可能会出现过拟合,太小了可能出现欠拟合。

当神经网络较复杂,但是训练数据较少的情况下,epoch太大容易出现过拟合。

反之,容易出现欠拟合

点评

欠拟合,那是没学好,acc会不好。 过拟合是什么现象?什么叫过拟合?  详情 回复 发表于 2024-8-20 20:39
 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

15
 
freebsder 发表于 2024-8-20 18:00 谢谢分享,期待后续!

感谢关注

 
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

16
 
xinmeng_wit 发表于 2024-8-20 20:33 epoch并不是越大越好,太大了可能会出现过拟合,太小了可能出现欠拟合。 当神经网络较复杂,但是训练 ...

欠拟合,那是没学好,acc会不好。

过拟合是什么现象?什么叫过拟合?

点评

过拟合是把训练数据的噪声当作特征拟合进取了。 现象就是,在训练集上预测结果很好,但是在测试集上表现不好  详情 回复 发表于 2024-8-20 21:20
 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

17
 
hellokitty_bean 发表于 2024-8-20 20:39 欠拟合,那是没学好,acc会不好。 过拟合是什么现象?什么叫过拟合?

过拟合是把训练数据的噪声当作特征拟合进取了。

现象就是,在训练集上预测结果很好,但是在测试集上表现不好

点评

嗯嗯嗯,透彻了。。。。 就是平时训练成绩都蛮好,等到奥运赛场,16强都进不了。。。  详情 回复 发表于 2024-8-21 11:14
 
 
 

回复

1388

帖子

1

TA的资源

五彩晶圆(初级)

18
 
xinmeng_wit 发表于 2024-8-20 21:20 过拟合是把训练数据的噪声当作特征拟合进取了。 现象就是,在训练集上预测结果很好,但是在测试集上表 ...

嗯嗯嗯,透彻了。。。。

就是平时训练成绩都蛮好,等到奥运赛场,16强都进不了。。。

点评

很形象,哈哈  详情 回复 发表于 2024-8-21 12:52
 
 
 

回复

553

帖子

3

TA的资源

纯净的硅(初级)

19
 
hellokitty_bean 发表于 2024-8-21 11:14 嗯嗯嗯,透彻了。。。。 就是平时训练成绩都蛮好,等到奥运赛场,16强都进不了。。。

很形象,哈哈

 
 
 

回复

10

帖子

0

TA的资源

一粒金砂(中级)

20
 
大佬有没有兴趣用更高速率的语言去实现比较base的算法来加速收敛呢?

点评

更高速率的语言是指?  详情 回复 发表于 2024-8-28 06:24
 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/6 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表