【Silicon Labs 开发套件评测】基于深度学习的tensorflow应用代码分析
[复制链接]
基于深度学习的tensorflow应用代码分析
1. Silab这个版本的SDK最有趣的更新就是把深度学习框架tensorflow引入嵌入式开发中。本次还提供了一个范例程序,tensorflow_lite_micro_helloworld,非常值得分析和对比一下。
2. 按照前述流程创建范例程序,然后编译并下载调试。
代码的过程是按照正弦函数输出波形,对应在PWM控制板载LED的亮度逐步变换,串口输出x,y的对应数值如下
3. 代码简析
整个逻辑仍然采用了标准的框架分析方法,在main.c启动任务app.c,在app.c中才启动最终的核心代码。这里是在tensorflow子目录项目下,
main.c
int main(void)
{
// Initialize Silicon Labs device, system, service(s) and protocol stack(s).
// Note that if the kernel is present, processing task(s) will be created by
// this call.
sl_system_init();
// Initialize the application. For example, create periodic timer(s) or
// task(s) if the kernel is present.
app_init();
#if defined(SL_CATALOG_KERNEL_PRESENT)
// Start the kernel. Task(s) created in app_init() will start running.
sl_system_kernel_start();
#else // SL_CATALOG_KERNEL_PRESENT
while (1) {
// Do not remove this call: Silicon Labs components process action routine
// must be called from the super loop.
sl_system_process_action();
app_process_action();
#if defined(SL_CATALOG_POWER_MANAGER_PRESENT)
// Let the CPU go to sleep if the system allows it.
sl_power_manager_sleep();
#endif
}
#endif // SL_CATALOG_KERNEL_PRESENT
}
app.c
void app_init(void)
{
#if defined(SL_CATALOG_MVP_PRESENT)
sli_mvp_init_t init = { .use_dma = false };
sli_mvp_init(&init);
#endif
tensorflow_lite_micro_helloworld_init();
}
/***************************************************************************//**
* App ticking function.
******************************************************************************/
void app_process_action(void)
{
tensorflow_lite_micro_helloworld_process_action();
}
void tensorflow_lite_micro_helloworld_init(void)
{
sl_pwm_start(&sl_pwm_led0);
setup();
}
/***************************************************************************//**
* Ticking function.
******************************************************************************/
void tensorflow_lite_micro_helloworld_process_action(void)
{
// Delay between model inferences to simplify visible output
sl_sleeptimer_delay_millisecond(100);
loop();
}
核心代码都在main_functions.c
// The name of this function is important for Arduino compatibility.
void setup() {
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = µ_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// This pulls in all the operation implementations we need.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::AllOpsResolver resolver;
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
// Obtain pointers to the model's input and output tensors.
input = interpreter->input(0);
output = interpreter->output(0);
// Keep track of how many inferences we have performed.
inference_count = 0;
}
// The name of this function is important for Arduino compatibility.
void loop() {
// Calculate an x value to feed into the model. We compare the current
// inference_count to the number of inferences per cycle to determine
// our position within the range of possible x values the model was
// trained on, and use this to calculate a value.
float position = static_cast<float>(inference_count) /
static_cast<float>(kInferencesPerCycle);
float x_val = position * kXrange;
// Place our calculated x value in the model's input tensor
input->data.f[0] = x_val;
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x_val: %f\n",
static_cast<double>(x_val));
return;
}
// Read the predicted y value from the model's output tensor
float y_val = output->data.f[0];
// Output the results. A custom HandleOutput function can be implemented
// for each supported hardware target.
HandleOutput(error_reporter, x_val, y_val);
// Increment the inference_counter, and reset it if we have reached
// the total number per cycle
inference_count += 1;
if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
这里其实直接调用了tensorflow的API
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
首先读取计算出的模型
model = tflite::GetModel(g_model);
然后,导入输入数据,
input->data.f[0] = x_val;
启动推理
TfLiteStatus invoke_status = interpreter->Invoke();
直接读出输出,
float y_val = output->data.f[0];
按照x,y的数值进行串口输出和led亮度控制,
HandleOutput(error_reporter, x_val, y_val);
4. 深度计算的开发应用流程。
4.1 先读取一下model的构型
#include "tensorflow/lite/micro/examples/hello_world/model.h"
// Keep model aligned to 8 bytes to guarantee aligned 64-bit accesses.
alignas(8) const unsigned char g_model[] = {
0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x12, 0x00,
0x1c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00,
0x00, 0x00, 0x18, 0x00, 0x12, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x60, 0x09, 0x00, 0x00, 0xa8, 0x02, 0x00, 0x00, 0x90, 0x02, 0x00, 0x00,
0x3c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
........
}
这个是谷歌提供的tensorflow-lite的标准模式,通常的计算方式,是生成一组深度学习模型的参数数据,但是tensorlow-lite压缩成整形数据,但是对于计算精度损失不大。压缩工具可以从谷歌网站上下载和使用。
4.2 这个首先,需要离线训练模型,搭建深度学习的模型,然后用谷歌的AI引擎训练,收敛后压缩模型,并尝试嵌入到开发板的代码中,不少模型是数M起步,一般的嵌入式开发用不了。
这个过程还是挺痛苦的,这个范例已经完成了这个过程,直接把model模型转换成了数据格式,方便评测。
5. 深度学习的移植开发。
虽然这个范例程序使用起来很容易,但是,深度学习的移植是非常困难的,至少模型的收敛和剪裁就需要非常多的经验。学习曲线比较陡。在轻型的系统上,计算性能有限,速度也比较慢,如果能用常规的算法,比深度学习的循环计算效率要高不少。
不过,这个也是一个非常有意义的范例,可以拓展项目开发能力和想象力。
|