本次测评将尝试搭建手写数字识别的第一部分:pytorch框架的搭建。当然百度的EasyDL可以进行神经网络的搭建和部署,但是pytorch作为当下较为流行的神经网络框架也是很值得我们学习的。LetNet网络一直被作为神经网络的当中的“Hello World”,本次实验也已利用PYNQ框架实现手写数字识别作为神经网络加速的入门教程。首先简要介绍一下LetNet网络:
其首先将输入大小为32*32的单通道图片,然后通过若干卷积层和全连接层得到最终的分量向量,从而得到最终的分类结果。可以看到LetNet网络结构非常的简单明了,然后使用pytorch框架对网络进行了搭建和训练,首先导入相关库函数:
然后加载pytorch自带的MNIST数据集,并将其可视化:
然后定义网络结构,并且使用优化器进行训练,可以看到,得到的训练结果如下所示:
准确率在98%到99%之间波动,效果良好,然后对权重进行保存。保存的pth文件,pynq难以直接读取,所以可以将其转换为其他形式的文件,如bin,txt等。这里为了让参数更加直观,将其保存为了txt文件,由于pth文件其为字典形式的文件,所以需要编写一个脚本文件,打开第一个卷积层的权重可以看到其有150个数据,符合6*5*5=150的预期,LetNet的网络训练部分告一段落。