854|10

47

帖子

1

TA的资源

一粒金砂(中级)

楼主
 

#AI挑战营第二站#onnx模型转rknn模型 [复制链接]

本帖最后由 空耳- 于 2024-5-15 14:00 编辑

1.环境搭建

1.1 conda环境安装与配置

    Conda是一个开源的软件包管理系统和环境管理系统,它可以用于安装、管理和升级软件 包和依赖项,我们这里使用conda的目的只是构建一个虚拟环境,所以选择轻量话的 miniconda。miniconda的官方链接如下所示: https://docs.conda.io/en/latest/miniconda.html

    下载安装之后,换源,提高下载速度。

pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

1.2 rknn环境创建

 

在rknn-toolkit git仓库:

链接已隐藏,如需查看请登录或者注册
 的release里下载1.4.0压缩包,解压后能看见多个平台的whl文件,这里因为我很早之前做毕设时使用过rknn,那时候版本还比较低,这里我也同样使用的是老版本1.4.0。

这里使用conda创建RKNN虚拟环境

conda create -n rknn python=3.8

命令安装numpy

pip install numpy==1.16.6 -i https://pypi.tuna.tsinghua.edu.cn/simple

然后安装瑞芯微提供的requirements_cp38-1.4.0.txt 文件所记录的依赖包

pip install -r requirements_cp38-1.4.0.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

最后安装瑞芯微提供的rknn_toolkit2-1.4.0 版本的软件包

pip install rknn_toolkit2-1.4.0_22dcfef4-cp38-cp38-linux_x86_64.whl -i https://pypi.tuna.tsinghua.edu.cn/simple

在这会提示这样一个东西

开始我没去在意,后边模型转换时报错了,通过查阅资料发现原因是:setuptool版本过高

从新安装

pip install setuptools==49.6.0

模型转换可以进行,但是哪个问题还在。我这里选择忽视,不影响正常使用。

至此,RKNN的虚拟环境就创建完成了。

2.模型转换

在模型转换时遇到过两个问题

(1)onnx模型版本要求为12

(2)onnx输入层的张量维度不对

为了解决以是两个问题我从新对手写数字模型进行了训练,以下为代码部分:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F

# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 获取数据加载器
def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", train=is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)

# 评估模型准确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x)
            for i, output in enumerate(outputs):
                if torch.argmax(output, dim=0) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

# 测试模型
def test(model, test_data):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_data:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('测试集准确率: %7.2f %%' % (100 * correct / total))

# 主函数
def main():
    # 获取训练集和测试集数据加载器
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()

    # 初始准确率
    print("initial accuracy:", evaluate(test_data, net))
    
    # 训练神经网络模型
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(10):
        for (x, y) in train_data:
            net.zero_grad()
            output = net.forward(x)
            loss = torch.nn.functional.cross_entropy(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    # 可视化预测结果
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x))
        plt.figure(n)
        plt.imshow(x[0].view(1, 28, 28).squeeze())
        plt.title("prediction: " + str(int(predict)))

    # 保存模型和模型状态字典
    torch.save(net, "mnist_model.pth")
    test(net, test_data)
    torch.save(net.state_dict(), "mnist_model_state_dict.pth")

    # 创建虚拟输入并导出ONNX模型(设置opset版本为12)
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(net, dummy_input, "model.onnx", opset_version=12)

if __name__ == "__main__":
    main()

从新训练后的模型准确度:

 

新模型结构

 

rknn模型转换

from rknn.api import RKNN
# Create RKNN object
rknn = RKNN(verbose=True)
# pre-process config
rknn.config(target_platform='rv1106', mean_values=[128], std_values=[128])
# Load model
rknn.load_onnx(model='./model.onnx')
rknn.build(do_quantization=True, dataset='./data.txt')  # 构建RKNN模型,可选参数量化
rknn.export_rknn('./mnist.rknn')  # 导出RKNN模型文件
# 释放 RKNN 对象
rknn.release()

onnx2rknn_1.py

model.onnx

mnist.rknn

 

 

 

mnist.rknn

130.84 KB, 下载次数: 0

model.onnx

414.67 KB, 下载次数: 1

onnx2rknn_1.py

412 Bytes, 下载次数: 0

最新回复

对于freebsd老哥而言,这就不淡定了。 刚有点思路写到脸庞这里,思路飘逸了 好不容易稳住心神,继续写代码,突然写到丝袜这里了,哎呀,思路又飘走了   所以,空耳大哥,你是为他好呢,还是害他呢,说不好了   详情 回复 发表于 2024-5-17 09:11
点赞 关注

回复
举报

7608

帖子

2

TA的资源

五彩晶圆(高级)

沙发
 

谢谢分享,头像很诱人!

点评

[attachimg]809003[/attachimg]  这是我的壁纸,送你    详情 回复 发表于 2024-5-15 22:32
啊哈。。。。。。这个老哥你也信呀。。。。。。。。 这不是都是Virtual的吗  详情 回复 发表于 2024-5-15 20:59
 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 

回复

1364

帖子

1

TA的资源

五彩晶圆(初级)

板凳
 
freebsder 发表于 2024-5-15 16:27 谢谢分享,头像很诱人!

啊哈。。。。。。这个老哥你也信呀。。。。。。。。

这不是都是Virtual的吗

点评

美丽的东西部分physical还是virtual,哈哈哈哈  详情 回复 发表于 2024-5-16 09:26
 
 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

4
 
freebsder 发表于 2024-5-15 16:27 谢谢分享,头像很诱人!

  这是我的壁纸,送你

 

点评

啊哈。。。。。。。。。。楼主大方的!! 话说看到大图,确实不错,freebsder老哥目光如炬呀! ,  详情 回复 发表于 2024-5-16 12:09
 
 
 

回复

7608

帖子

2

TA的资源

五彩晶圆(高级)

5
 
hellokitty_bean 发表于 2024-5-15 20:59 啊哈。。。。。。这个老哥你也信呀。。。。。。。。 这不是都是Virtual的吗

美丽的东西部分physical还是virtual,哈哈哈哈

点评

老哥,马斯克亲吻了猫女机器人,还和她跳舞。。。。 那皮肤看起来是非常逼真,确实从Virtual走向了Reality 慢慢确实界限越来越模糊了  详情 回复 发表于 2024-5-16 12:06
 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复

7608

帖子

2

TA的资源

五彩晶圆(高级)

6
 
空耳- 发表于 2024-5-15 22:32   这是我的壁纸,送你  

 
个人签名

默认摸鱼,再摸鱼。2022、9、28

 
 

回复

1364

帖子

1

TA的资源

五彩晶圆(初级)

7
 
freebsder 发表于 2024-5-16 09:26 美丽的东西部分physical还是virtual,哈哈哈哈

老哥,马斯克亲吻了猫女机器人,还和她跳舞。。。。

那皮肤看起来是非常逼真,确实从Virtual走向了Reality

慢慢确实界限越来越模糊了

 
 
 

回复

1364

帖子

1

TA的资源

五彩晶圆(初级)

8
 
空耳- 发表于 2024-5-15 22:32   这是我的壁纸,送你  

啊哈。。。。。。。。。。楼主大方的!!

话说看到大图,确实不错,freebsder老哥目光如炬呀!

点评

[attachimg]809324[/attachimg]  做壁纸nice的很  详情 回复 发表于 2024-5-16 17:36
 
 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

9
 
hellokitty_bean 发表于 2024-5-16 12:09 啊哈。。。。。。。。。。楼主大方的!! 话说看到大图,确实不错,freebsder老哥目光如炬呀! , ...

  做壁纸nice的很

 
 
 

回复

1364

帖子

1

TA的资源

五彩晶圆(初级)

10
 
空耳- 发表于 2024-5-16 17:36   做壁纸nice的很

对于freebsd老哥而言,这就不淡定了。

刚有点思路写到脸庞这里,思路飘逸了

好不容易稳住心神,继续写代码,突然写到丝袜这里了,哎呀,思路又飘走了

 

所以,空耳大哥,你是为他好呢,还是害他呢,说不好了

点评

劳逸结合,劳逸结合  详情 回复 发表于 2024-5-17 19:36
 
 
 

回复

47

帖子

1

TA的资源

一粒金砂(中级)

11
 
hellokitty_bean 发表于 2024-5-17 09:11 对于freebsd老哥而言,这就不淡定了。 刚有点思路写到脸庞这里,思路飘逸了 好不容易稳住心神,继 ...

劳逸结合,劳逸结合

 
 
 

回复
您需要登录后才可以回帖 登录 | 注册

随便看看
查找数据手册?

EEWorld Datasheet 技术支持

相关文章 更多>>
关闭
站长推荐上一条 1/9 下一条

 
EEWorld订阅号

 
EEWorld服务号

 
汽车开发圈

About Us 关于我们 客户服务 联系方式 器件索引 网站地图 最新更新 手机版

站点相关: 国产芯 安防电子 汽车电子 手机便携 工业控制 家用电子 医疗电子 测试测量 网络通信 物联网

北京市海淀区中关村大街18号B座15层1530室 电话:(010)82350740 邮编:100190

电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 电信业务审批[2006]字第258号函 京公网安备 11010802033920号 Copyright © 2005-2025 EEWORLD.com.cn, Inc. All rights reserved
快速回复 返回顶部 返回列表