3582|0

422

帖子

4

TA的资源

纯净的硅(初级)

楼主
 

5. 用pytorch设计计算图 [复制链接]

本帖最后由 北方 于 2018-9-29 11:05 编辑

1、在CMSIS-NN中,比较直观的有一个计算图,每个步骤对应一个函数。那么在pytorch是同等可以自己定义和设计的,下面的步骤就是说明如何建立计算图给后续进行移植和转换的。
2. 在pytorch中文本识别也是如下图,因为已经在CMSIS-NN中介绍过了,就不再对照卷积层,池化层和全连层的对比了。直接见图。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F


  4. class Net(nn.Module):

  5.     def __init__(self):
  6.         super(Net, self).__init__()
  7.         # 1 input image channel, 6 output channels, 5x5 square convolution
  8.         # kernel
  9.         self.conv1 = nn.Conv2d(1, 6, 5)
  10.         self.conv2 = nn.Conv2d(6, 16, 5)
  11.         # an affine operation: y = Wx + b
  12.         self.fc1 = nn.Linear(16 * 5 * 5, 120)
  13.         self.fc2 = nn.Linear(120, 84)
  14.         self.fc3 = nn.Linear(84, 10)

  15.     def forward(self, x):
  16.         # Max pooling over a (2, 2) window
  17.         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  18.         # If the size is a square you can only specify a single number
  19.         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  20.         x = x.view(-1, self.num_flat_features(x))
  21.         x = F.relu(self.fc1(x))
  22.         x = F.relu(self.fc2(x))
  23.         x = self.fc3(x)
  24.         return x

  25.     def num_flat_features(self, x):
  26.         size = x.size()[1:]  # all dimensions except the batch dimension
  27.         num_features = 1
  28.         for s in size:
  29.             num_features *= s
  30.         return num_features
  31. net = Net()
  32. print(net)
复制代码



简单对照一下,
在CMSIS-NN中的第一层,
卷积函数 arm_convolve_HWC_q7_RGB();
激活函数arm_relu_q7();
池化层函数arm_maxpool_q7_HWC();


分别对应
  1. self.conv1 = nn.Conv2d(1, 6, 5)
  2. x = F.relu(self.fc1(x))
  3. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
复制代码



实际执行也是按照如下顺序进行的
  1. input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
  2.       -> view -> linear -> relu -> linear -> relu -> linear
  3.       -> MSELoss
  4.       -> loss
复制代码



3. 训练这模型
3.1 典型的训练图按照以下步骤创建和执行
-定义NN网络的学习参数,维数
- 输入数据集并且迭代
- 用这个网络处理输入的数据
- 计算这个初始weight的损失,也就是偏差
- 扩散并计算参数,把输出的数据迭代计算并回溯
- 用定义的weight函数,更新weight,常见的规则是weight = weight -learning_rate * gradient

3.2 数据集需要独立采集和录入
- 加载并预处理数据,即数据清洗
- 定义CNN网络,如上程序所示,其中的数据和层数不是固定的,是需要自己定义的当然,结果会有比较大的不同。
- 定义损失函数和优化器,这里引入一个交叉熵的概念
  1. import torch.optim as optim

  2. criterion = nn.CrossEntropyLoss()
  3. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
复制代码


[size=1.125]- 用这个网络处理训练数据,一般是用循环语句,把全部数据集加载并计算一遍
  1. for i, data in enumerate(trainloader, 0):
  2.   ... ...
复制代码


- 测试这个网络和计算结果的正确率。这个和人工规则模拟的最大不同,是这个不会全对,在一定的正确率下就是最优。那么测试的原理就是用一组测试数据集(远远小于训练集),逐个输入新的weight和loss数值下的图中,得出新的分类结果。确定这个计算图的有效性。在不正确的参数选择和数据不干净的情况下,正确率会低至15%,这个和乱猜的结果差不多了,到底只有10个数字。
4. 这个是使用pytorch训练数据的简单流程,具体每个步骤都比较重要,也会影响这个结果。
分析这个主要是说明,这个训练是如何使得训练集用在GD32F3xxx上的。上述训练的结果,其实对应2步工作:
4.1 定义不同的layer,如
Conv2d(1, 6, 5)对应
arm_convolve_HWC_q7_fast(img_buffer2, CONV2_IM_DIM, CONV2_IM_CH, conv2_wt, CONV2_OUT_CH, CONV2_KER_DIM,
                           CONV2_PADDING, CONV2_STRIDE, conv2_bias, CONV2_BIAS_LSHIFT, CONV2_OUT_RSHIFT, img_buffer1,

                           CONV2_OUT_DIM, (q15_t *) col_buffer, NULL);
其中,需要对照的是
static q7_t conv1_wt[CONV1_IM_CH * CONV1_KER_DIM * CONV1_KER_DIM * CONV1_OUT_CH] = CONV1_WT;
等等,这里面已经定义了对应的维度的。
4.2 导入训练后的数据,
#include "arm_nnexamples_cifar10_weights.h"
这里的数据就是从这个pytorch模型中导出的数据weight数据。在torch.nn里面导出模型中就有。
4.3 可以看出,这个是一个不小的工程,需要很多新的概念,足够数据集,强大的计算平台和硬件。当然,这个训练集训练一次就可以反复用,对别人如果不解释也是看不明白的。
供大家参考指证。
下一步,就要开始搭建属于本项目的训练模型了。引入计算模型,可以避免自己使用dsp这样的编程和应用,直接用cmsisi-nn的引用就可以了。而且还可以尝试使用http://librosa.github.io/librosa/ 这样的专业的音频处理库进行音频的分割录入和处理,其实是降低了数据处理难度,从软件上辅助编程。

此帖出自GD32 MCU论坛
点赞 关注
 

回复
举报
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/10 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表