- 2025-01-17
-
回复了主题帖:
《大语言模型开发:用开源模型开发本地系统》阅读报告(3)
风尘流沙 发表于 2025-1-17 17:30
楼主,您好,大语言模型开发:用开源模型开发本地系统
有中文版下载链接吗??
谢谢您
这个是论坛测评活动申请的试读,是实体书来的,不知道有没有电子版。
- 2025-01-15
-
回复了主题帖:
《大语言模型开发:用开源模型开发本地系统》阅读报告(2)
ljg2np 发表于 2024-12-26 11:28
个人感觉transformer模型在NLP得到成功应用之后,逐渐扩大其应用范围,用在电池周期充放电数据的分析研 ...
是的,因为本质上都是时序推理的问题。电池的数据相比于自然语言,还更加具有物理意义上的关联性。
-
发表了主题帖:
《大语言模型开发:用开源模型开发本地系统》阅读报告(3)
在前两周阅读《大语言模型开发》这本书时,发现书中提供了使用Llama2模型进行训练和微调的具体案例,具有很强的实践参考价值。书中说明Llama-2-7b模型的代码可通过在Hugging Face网站申请获得,于是我们在该网站提交了申请,认为其开源性质应该能够顺利使用。但由于期末考核任务繁重,一直没有关注申请进度。直到前两天准备更新报告时,才发现申请并未通过。
因此,本次报告我们只能基于书中的理论知识,纸上谈兵地向各位老师汇报我们的学习成果。另外,不知道其他朋友是否成功申请过,或者是否遇到过类似的情况,也请大家分享一下经验。
我们将按照书中介绍的模型训练过程对其中涉及的关键模块逐一介绍,整个训练流程的组成部分如下图所示:
1. 预训练模型
在模型训练过程中,预训练模型是不可或缺的一部分。我们通常不会从头开始构建一个大语言模型,因为这需要巨大的计算资源和数据。因此,我们主要通过Hugging Face的transformers库来引入预训练模型。transformers库是一个非常强大的工具,它提供了数以千计的预训练模型,支持多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成等任务。它提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在本地的数据集上微调。
在我们的案例中,我们使用了AutoModelForCausalLM类来加载预训练的因果语言模型(Causal Language Model)。这个类允许我们从预训练模型库中加载模型,并支持多种参数以自定义加载过程。具体到我们的代码示例中,我们使用了from_pretrained方法来加载'meta-llama/Llama-2-7b-hf'模型。这个方法通过from_pretrained字段配置我们需要的预训练模型,并且可以通过quantization_config参数来选择模型的量化精度。
model = AutoModelForCausalLM.from_pretrained(
'meta-llama/Llama-2-7b-hf',
cache_dir="/kaggle/working/",
trusr_remote_code=True,
device_map='auto',
quantization_config=bnb_config,
)
2. 量化技术
有关模型的量化技术,书中也有所介绍。量化技术是模型优化的重要手段之一,它通过减少模型参数的位宽来减少存储空间和加速计算,同时降低能耗。在我们的代码中,我们使用了load_in_4bit参数来决定是否使用4位量化精度的模型。
由此,我们便引入了第二个比较重要的成分,就是配置量化精度的BitsAndBytesConfig模型,用于设置模型的量化参数。通过这个类,我们可以灵活地配置模型的量化精度、分位数等超参数,从而在减少模型存储空间和加速计算的同时,尽量保持模型的性能。
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype='float16',
bnb_4bit_use_double_quant=True,
)
3. 分词器
除此之外,我们还需要使用分词器来处理自然语言任务。分词器(Tokenizer)在自然语言处理(NLP)中是一个关键组件,负责将文本字符串转换成模型可以处理的结构化数据形式,通常是将文本切分成“tokens”或单词、短语、子词等单位。这些tokens是模型理解文本的基础。分词器的类型和复杂性可以根据任务需求而变化,从简单的基于空格的分割到更复杂的基于规则或机器学习的分词方法。
分词器的工作流程通常包括两个关键步骤:
文本分解:将原始文本分割成更细的粒度单元,这些单元可以是单词级别、子词级别,甚至是字符级别。这一步的目标是将文本分解为可以被模型理解并处理的基本单元。
编码映射:将这些基本单元转换为模型可以理解的数值形式,最常见的形式是整数序列。这样,我们就可以将这些数值输入到模型中,让模型进行学习和预测。
此外,分词器还会在序列的开始和结束添加特殊标记,如BERT中的CLS和SEP用于特定任务的序列分类或区分输入片段。有时为了确保输入序列的一致长度,分词器可以对较短的序列进行填充,对较长的序列进行截断。
我们主要是通过AutoTokenizer这个类来进行调用。AutoTokenizer 是 Hugging Face transformers 库中的一个实用工具类,主要用于自动加载与特定预训练模型相匹配的分词器。AutoTokenizer 能够自动识别出模型应该使用哪种类型的分词器,并加载对应的配置文件和词汇表。这使得开发者不需要手动指定分词器类型,简化了代码编写过程。一旦分词器被正确加载,它可以对输入文本进行预处理,包括分词、添加特殊标记、编码转换、填充和截断等操作。
tokenizer = AutoTokenizer.from_pretrained(
'meta-llama/Llama-2-7b-hf',
trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
我们也可以输入一些文本内容查看其分词的效果:
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1]).unsqueeze(0)
outputs = model(**inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits
4. 训练方法
由于大模型参数量众多,我们即便在有预训练模型的基础上也不会去大量训练所有的模型参数,只会对其中部分进行微调。这里我们就会使用到LoRA这个方法,其只会对少量额外的模型参数进行训练,示意图如下:
LoRa(Low-Rank Adaptation)是一种用于对预训练大语言模型进行微调的高效方法。它通过引入两个低秩矩阵A和B来减少需要训练的参数数量,从而在保持模型性能的同时,显著降低训练成本。具体来说,如果原始参数矩阵W的大小为d×d,LoRA 会引入两个大小分别为d×r和r×d的矩阵,其中r是一个远小于d的秩。通过这种方式,LoRA大幅减少了需要训练的参数数量。
我们会从peft这个库来调用LoRa方法,peft(Parameter-Efficient Fine-Tuning)是一个用于高效微调大语言模型的库。它提供了多种参数高效的微调方法,包括 LoRa、Prefix-Tuning 等。peft 库的主要目标是通过减少需要训练的参数数量,提高微调的效率和效果。
以下是书中使用的LoRa的参数配置,其中lora_target_modules这个参数就是指出我们会对Llama模型的哪些部分进行训练,与上面的示意图也是对应的。
lora_alpha=16,
lora_dropout=0.1,
lora_r=16,
lora_bias='all',
model_type='llama',
lora_target_modules = [
'q_proj',
'k_proj',
'v_proj',
'o_proj',
'gate_proj',
'up_proj',
'down_proj'
]
我们可以通过LoRaConfig来应用这些参数设置:
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_r,
bias=lora_bias,
task_type='CAUSAL_LM',
inference_mode=False,
target_modules=lora_target_modules,
)
model = get_peft_model(model, peft_config)
5. 增量模型训练
接下来我们就可以使用LoRA方法来微调我们的模型参数,我们主要是通过调用tansforemres库中的TrainingArguments和Trainer这两个类。它们是Hugging Face中配置训练参数和用于训练和评估模型的类,封装了训练循环,支持多种功能,如分布式训练、混合精度训练、早停(early stopping)、学习率调度等。有关训练参数的配置主要包括:
output_dir = 'your_dir'
optim_type = 'adamw_8bit'
learning_rate = 0.0005
weight_decay = 0.002
per_device_train_batch_size = 1
per_device_eval_batch_size = 1
gradient_accumulation_steps = 16
warmup_steps = 5
save_steps = 100
logging_steps = 100
然后我们调用TraningArguments来对这些应用这些参数设置:
traing_args = TrainingArguments(
output_dir=output_dir,
evaluation_strategy='epoch',
optim=optim_type,
learning_rate=learning_rate,
weight_decay=weight_decay,
per_device_train_batch_size=per_device_train_batch_size,
per_device_eval_batch_size=per_device_eval_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
do_train=True,
warmup_steps=warmup_steps,
save_steps=save_steps,
logging_steps=logging_steps,
)
trainer=Trainer(
model=model,
args=training_args,
train_dataset=data_train,
eval_dataset=data_test,
tokenizer=tokenizer,
)
可以看到这些参数的命名也还是比较一目了然的。最后使用如下代码就可以使用LoRA进行训练了。
trainer.train()
6. 模型合并
在使用 LoRa 方法进行微调后,我们通常需要将微调后的增量模型参数合并到原始的预训练模型中。这样,我们可以得到一个完整的、微调后的模型,才可以用于推理或其他任务。整个完整的训练流程我们总结如下:
加载预训练模型和分词器:首先,我们需要加载原始的预训练模型和分词器。
应用 LoRa 配置:将 LoRa 配置应用到预训练模型中,进行微调。
训练模型:使用训练数据对模型进行微调。
合并模型:将微调后的 LoRa 增量参数合并到原始模型中,得到一个完整的微调模型。
示意图如下:
合并模型的代码如下:
model = get_peft_model(model, lora_config)
model.save_pretrained('path_to_save_lora_model')
model = PeftModel.from_pretrained(model, 'path_to_save_lora_model')
model = model.merge_and_unload()
model.save_pretrained('path_to_save_merged_model')
tokenizer.save_pretrained('path_to_save_merged_model')
merged_model = AutoModelForCausalLM.from_pretrained('path_to_save_merged_model', trust_remote_code=True)
merged_tokenizer = AutoTokenizer.from_pretrained('path_to_save_merged_model', trust_remote_code=True)
通过上述步骤,我们成功完成了基于预训练模型和LoRa方法的大模型训练尝试。
7. 全书阅读总结
本次《大语言模型开发》的试读活动已接近尾声。本书不仅提供了详尽的理论阐述,还辅以实用的代码示例,极大地助力了我们的学习与实践,我愿意给予高度评价。然而,正如古语所言,“纸上得来终觉浅,绝知此事要躬行”,在实际操作过程中,我们深知还会遭遇诸多挑战与问题。我们也会再去积极获得Llama2的调用权限,或获取更多开源项目资源,以便进一步深化我们的研究与实践。活动结束后,我们仍将不懈努力,持续探索与学习,并及时向各位老师汇报我们的最新进展。
- 2024-12-23
-
发表了主题帖:
《大语言模型开发:用开源模型开发本地系统》阅读报告(2)
按照我们的阅读计划,本次报告向各位老师汇报我们在Transformer架构领域的实践成果。在上一次的分享中我们解读了书中的代码,因为时间关系没有具体尝试。今天,我们将填补这一空白,具体展示Transformer在实际项目中的应用。
回顾过往,我们在《人工智能实践教程》中利用较为简单的多层卷积神经网络(CNN),完成了一个基于电池电压和电流特征预测电池当前周期容量的小型演示项目。在此基础上,我们想进一步探索了Transformer模型,利用其在时序信号上下文推理和特征提取方面的优势,尝试预测电池未来容量的变化趋势。
我们的核心原理是利用Transformer对上下文关联性的推理能力,将电池容量变化趋势视为具有前后关联的序列。通过输入一小段历史数据,模型能够编码这些信息,并预测下一周期的容量。随后,这些预测结果将作为新的输入,推动模型进行更进一步的预测。通过这种方式,我们可以不断扩展输出序列,生成具有预见性的输出结果。
在这一领域,已有众多研究工作为我们提供了宝贵的参考。本次分享,我们将以一篇经典的论文为蓝本,尝试复现这一过程。论文的链接和相关代码已附在报告末尾可以参考:【GitHub - XiuzeZhou/RUL: Transformer Network for Remaining Useful Life Prediction of Lithium-Ion Batteries】
1. 数据来源
我们选取了马里兰大学CALCE(Center for Advanced Life Cycle Engineering)提供的在线数据集中的CS2系列电池周期充放电数据作为我们的研究对象。这些数据可以通过访问CALCE的官方网站【Battery Research Data | Center for Advanced Life Cycle Engineering (umd.edu)】进行下载。
最终可以得到锂电池的容量与周期数的关系。
2. 网络架构
接下来,我们将详细介绍文章中实现电池容量预测功能所采用的具体网络架构。我们的模型由两部分组成:前级的去噪编码器和后级的Transformer解码器。幸运的是,论文作者已经将代码开源,我们在此基础上进行了一些修改,以测试不同参数组合的性能。因此,我们不会在文章中粘贴代码,而是提供一些关键的截图。感兴趣的读者可以访问作者的GitHub页面自行下载代码。
首先,我们来看编码器部分。因为电池容量时时间序列,并不涉及自然语言处理的词嵌入,所以这里的编码器本质上由两层全连接网络构成。在输入第一层网络之前,我们对输入信号叠加噪声信号,并进行线性变换和非线性激活,得到一个中间变量,该变量作为去噪编码的信号。然后,我们将这个中间变量再次进行线性变换,以得到输出。
因此,我们的前向传播计算流程如下:input + noise -> FC1 + ReLU -> FC2 -> output。在这里,"mask"指的是叠加噪声的部分,"encode"表示第一个全连接层和非线性激活,得到去噪编码的结果,而"decode"代表第二个全连接层,输出重新变化回来的输入数据。在训练过程中,我们的目标是最小化输入和输出之间的差距,从而实现去噪编码和解码的过程。中间去噪后的向量将作为Transformer解码器的输入。
接下来是位置编码,它用于标记位置信息。具体编码方式可以参考典型的Transformer论文,即通过正弦和余弦函数的编码过程,然后将计算得到的位置信息叠加在输入上。
最后,文章设计了Transformer解码器。根据经典的Transformer架构,解码器本质上是由多个注意力机制和残差网络组成的解码模块的串联(在我们的教程中,只使用了单个模块)。虽然这里直接调用的是encoder模块,但可以看到,由于Transformer架构中编码器和解码器的模块结构是类似的,我们直接调用了PyTorch库中的TransformerEncoderLayer。
将上述所有过程串联起来,前向计算的过程就变得非常清晰:首先分别计算去噪声的编码和位置编码,然后将它们相加作为Transformer的输入。调用Transformer中的编码器输出结果后,再进行一次全连接操作,以得到最终的预测结果。
这样的设计能够很好地对应文章中给出的网络架构。
3. 训练过程
在定义好网络架构之后,我们利用PyTorch库中提供的自动微分功能来执行反向传播,无需手动编写复杂的梯度计算代码。在我们的实验中,选择了Adam优化器来更新网络权重,这是一种广泛使用的优化算法,因其自适应学习率的特性而受到青睐。至于损失函数,我们选择了均方误差(MSELoss),它衡量预测结果与实际目标之间的差异,是回归问题中常用的损失函数。
训练的具体流程如下:
模型验证:在训练过程中,我们定期使用测试集中预先划分的最初window_size个数据点作为起点,逐个生成后续的容量衰退趋势,并与实际测试集数据进行对比,以此来评估模型的性能。
训练效率:得益于模型参数量适中以及架构的简洁性,训练过程的效率相当令人满意。参数量的控制和合理的架构设计,使得模型训练既高效又稳定。
4. 测试结果
在完成了模型的训练和验证之后,我们得到了一些初步的测试结果。虽然文章中提供的量化指标看似乐观,但实际上我们的预测结果并不如预期理想。
具体来说,我们注意到了几个关键问题:
量化指标的局限性:文章中提到的指标,如通过预测的剩余使用寿命(RUL)与实际RUL相减后除以实际RUL得到的比值,虽然数值上看似很小,但实际上误差可能高达40-50个周期。这表明我们的模型在实际应用中可能存在较大的偏差。
预测曲线与实际曲线的对比:通过对比预测的容量变化曲线和实际曲线,我们发现两者之间的误差相当明显。尤其是在选择70%的标称容量作为截止点时,虽然两者看似重合,但如果选择80%的标称容量作为截止点,误差则显著增加。
窗口长度的影响:我们初步考虑窗口长度可能过短,因为我们仅使用了64个数据点进行预测。增加窗口长度后,虽然模型性能有所改善,但普适性较差,部分曲线的预测效果有所提升,而其他曲线的预测效果仍然不尽人意。此外,位置编码的运算规则限制了输入维度,只能调整为2的次方。
注意力头和隐层规模的影响:增加注意力头对模型性能的提升并不明显。
增加隐层的规模以后,效果有了显著改善。然而,我们也意识到单纯增加隐层规模可能会导致过拟合。
后续改进方向:我们认为后续的提升空间主要在于调整网络架构和改进训练算法。教程中还涉及了优化超参数的部分,我们计划进一步探索其他参数的优化空间。
总的来说,虽然我们能够复现文章中的内容,但效果并不理想,且文章中采用的衡量指标存在一定的取巧。基于当前的网络结构,我们将继续探索可能的改进方法,以期达到更准确的预测效果。
5. 总结
我们尝试通过构建包含去噪编码器和Transformer解码器的网络架构,旨在捕捉电池容量变化的复杂模式,并预测其未来趋势。
在实践过程中,我们遇到了一些挑战。尽管初步的量化指标显示了乐观的结果,但深入分析后发现,模型的预测精度并不如预期。我们识别出几个关键问题,包括量化指标的局限性、预测曲线与实际曲线之间的显著误差、窗口长度的选择、以及模型结构对预测性能的影响。
针对这些问题,我们进行了一系列的调整和优化。我们尝试增加窗口长度、调整注意力头的数量、以及扩大隐层规模,以期提高模型的预测能力。虽然这些调整在一定程度上改善了模型性能,但我们认识到,为了实现更准确的预测,还需要在网络架构和训练算法上进行更深入的探索和改进。
总结来说,虽然我们能够复现文章中的内容,但实际效果表明,我们还有很长的路要走。我们将继续学习,探索优化超参数的方法,并尝试不同的网络结构,以期在未来的研究中取得更好的成果。我们相信,通过不断的努力和创新,我们能够克服当前的挑战,实现更精准的电池容量预测。
附上文章的引用信息:
@article{chen2022transformer,
title={Transformer network for remaining useful life prediction of lithium-ion batteries},
author={Chen, Daoquan and Hong, Weicong and Zhou, Xiuze},
journal={Ieee Access},
volume={10},
pages={19621--19628},
year={2022},
publisher={IEEE}
}
- 2024-12-07
-
发表了主题帖:
《大语言模型开发:用开源模型开发本地系统》阅读报告(1)
本帖最后由 Aclicee 于 2024-12-7 12:28 编辑
我们于在11月24日收到的《大语言模型开发:用开源模型开发本地系统》一书,书籍包装完整,印刷清晰,装帧精美,给人以良好的第一印象。
按照既定计划,我们今天向各位老师汇报对书本第一部分的阅读报告,这部分内容聚焦于Transformer模型的详细构成。由于时间限制,我们尚未将Transformer模型应用于具体任务的操作实践,敬请期待我们在后续报告中对实际操作的深入探讨。
在深入剖析Transformer模型之前,本书对Pytorch和深度学习的基础方法进行了全面的介绍。鉴于我们在之前的《人工智能实践教程》测评中已经详细汇报过这部分内容,并且论坛中已有众多专家老师的宝贵分享,我们将不再对此进行重复讨论。
Transformer架构以其独特的设计在自然语言处理领域占据着举足轻重的地位。一个典型的Transformer模型由以下几个核心部分组成:自注意力机制、多头自注意力、位置编码、前馈神经网络、归一化层以及残差连接。接下来的报告中,我们将按照Transformer的工作流程,逐一深入学习这些组成部分。
在这些组件中,位置编码和注意力机制尤为关键。位置编码的引入是为了使Transformer能够理解输入序列中各元素之间的前后关系,而注意力机制则赋予模型聚焦于更为重要的信息维度的能力,从而深刻理解输入数据的相关性和影响。这两个模块是Transformer模型能够准确捕捉序列特征并进行有效预测的关键。
1. 词嵌入
词嵌入是自然语言处理领域中的一项关键技术,它通过无监督学习算法从大规模文本数据中提取词汇的向量表示。这种表示能够捕捉词汇之间的语义和句法关系,为后续的模型训练和语言理解任务提供基础。
在Transformer模型中,词嵌入是与模型参数一同训练得到的,这样可以在特定任务的上下文中优化词嵌入,以更好地适应任务需求。
Transformer模型的词嵌入由两部分组成:标记嵌入(Token Embeddings)和位置编码(Positional Encoding)。
标记嵌入:这部分将输入序列中的每个标记(token)映射到一个高维空间中的向量。通过这种方式,语义上相似的标记在向量空间中的距离更近,从而反映出它们之间的相似性。
位置编码:由于Transformer模型本身不具备处理序列顺序的能力,位置编码被引入以赋予模型对输入序列中标记位置的感知能力。这使得模型能够区分相同标记在不同位置时的上下文含义。
Transformer模型中的位置编码通常采用正弦和余弦函数的组合来实现。这种方法利用了三角函数的周期性特性,为每个位置生成唯一的编码,从而蕴含位置信息。具体公式如下:
其中,pos表示位置, i表示维度,d_model表示模型的维度。这种编码方式确保了不同位置的编码向量具有不同的值,且随着位置的变化,编码向量呈现出周期性变化。
以下是实现正弦位置编码的参考代码:
import torch
import torch.nn as nn
def precompute_freq_cis(dim, seqlen, theta=10000.0):
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim//2)].float() / dim))
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs).float()
return freqs
embedding_dim = 8
sequence_length = 5
embeddings = torch.randn(sequence_length * embedding_dim).view(sequence_length, embedding_dim)
freqs = precompute_freq_cis(embedding_dim, sequence_length)
pe = torch.zeros((sequence_length, embedding_dim))
pe[:,0::2] = torch.sin(freqs)
pe[:,1::2] = torch.cos(freqs)
pe_out = token_embedding + pe
这段代码首先创建了一个位置编码矩阵,然后通过正弦和余弦函数为每个位置生成了一个唯一的编码。这样,即使是相同的词,在句子中不同位置的编码也会有所不同,从而帮助模型理解词序的重要性。
2. 自注意力机制
自注意力机制(Self-Attention Mechanism),也称为内部注意力机制,是一种允许输入序列中的每个元素根据其与其他元素的相似度动态分配权重的技术。这种机制对于处理长序列数据尤为重要,因为它能够捕捉序列内部的长距离依赖关系,同时提高了模型的计算效率和可解释性。
自注意力机制的计算过程涉及以下步骤:
映射到查询(Q)、键(K)和值(V)向量:输入序列首先被映射到查询向量(Q)、键向量(K)和值向量(V)。这一步骤通常通过线性变换实现,其中权重是可学习的。
计算注意力分数:通过计算Q和K的点积来衡量序列中各元素之间的相似度,得到一个原始的注意力分数矩阵。
Softmax归一化:对原始注意力分数应用Softmax函数进行归一化,使得每一行的和为1。这一步骤确保了模型在计算加权平均时,能够根据元素间的相似度合理分配权重。
加权平均:将归一化后的注意力分数与值向量V相乘,得到加权平均的结果。这一步骤反映了每个输入元素根据其与其他元素的相似度关系,调整其嵌入表示。
公式可以表示为:
其中d_k表示值向量的维度。以下是自注意力机制的初步实现代码:
query_matrix = nn.Linear(embedding_dim, embedding_dim)
key_matrix = nn.Linear(embedding_dim, embedding_dim)
value_matrix = nn.Linear(embedding_dim, embedding_dim)
query_vectors = query_matrix(embeddings)
key_vectors = key_matrix(embeddings)
value_vectors = value_matrix(embeddings)
scores = torch.matmul(query_vectors, key_vectors.transpose(-2,-1)) / torch.sqrt(torch.tensor(embedding_dim, dtype=torch.float32))
softmax = nn.Softmax(dim=-1)
attention_weights = softmax(scores)
output = torch.matmul(attention_weights, value_vectors)
在实际应用中,Transformer模型通常采用多头自注意力机制,即同一输入序列被分割成多个Q、K、V的组合,每个“头”独立计算注意力,最后将所有头的输出拼接起来。这种设计增强了模型对不同子空间信息的捕捉能力。
num_attention_heads = 2
output_copy = output.clone()
m_output = torch.concat((output, output_copy), dim=1)
output_matrix = nn.Linear(num_attention_heads*embedding_dim, num_attention_heads*embedding_dim)
out_vectors = output_matrix(m_output)
print(embeddings)
print(out_vectors)
为了加速自注意力机制的计算过程,可以采用ColumnParallelLinear和RowParallelLinear等方法进行并行运算。这些技术通过将矩阵分解为多个较小的子矩阵,利用现代硬件的并行处理能力,从而提高计算效率。
3. 残差连接和归一化
在深度学习中,随着网络层数的增加,梯度消失或梯度爆炸的问题会导致训练深层网络变得困难。残差连接(Residual Connection)和层归一化(Layer Normalization)是解决这些问题的两种有效机制,它们在Transformer模型中发挥着重要作用。
残差连接最初在ResNet中提出,用于缓解深层网络训练中的梯度消失问题。其核心思想是将输入直接添加到网络中某一层的输出上,这样网络就可以学习到恒等映射(Identity Mapping),从而使得深层网络的训练变得更加容易。在Transformer中,残差连接用于将多头自注意力层或前馈网络层的输入与输出相加,然后再进行后续的层归一化和下一层的计算。
公式表示为:
其中,FeedForward(X)表示经过一层或多层网络的变换。
层归一化(Layer Normalization)是一种归一化技术,它对每个样本的每个层的输出进行归一化处理,使得输出的分布具有稳定的均值和方差。与批量归一化(Batch Normalization)不同,层归一化是在单个样本的层级上进行归一化,而不是在整个批次上。这使得层归一化对于小批量大小或非独立同分布的数据更加鲁棒。
层归一化的计算公式为:
其中,x_i是第个i样本的输出,μ_i和σ_i分别是该样本输出的均值和方差,ε是一个很小的常数,防止分母为零,γ和β是可学习的参数,用于对归一化后的输出进行缩放和平移。
在Transformer中,层归一化通常在残差连接之后进行,用于减少层与层之间数据分布的差异,加速训练过程,并保持数据的分布稳定性。
4. 前馈网络
前馈网络(Feed-Forward Network,FFN)是Transformer模型中的关键组件之一,主要用于对自注意力模块的输出进行深度特征提取和非线性变换。这一步骤对于模型学习和表示复杂的语义信息至关重要,能够有效提升模型的性能。
在Transformer中,每个前馈网络由两个线性变换组成,中间夹着一个ReLU激活函数。具体来说,前馈网络的结构可以描述为:
第一个线性变换:输入序列首先通过一个线性层进行变换,这个线性层通常具有较多的神经元,允许模型捕捉更丰富的特征表示。
非线性激活:第一个线性变换的输出通过ReLU激活函数进行非线性映射,这有助于引入非线性因素,使模型能够学习和表示更复杂的函数。
第二个线性变换:经过非线性激活的输出再通过另一个线性层进行变换,最终产生前馈网络的输出。
数学上,前馈网络的计算可以表示为:
其中,W_1和W_2是可学习的权重矩阵,b_1和b_2是可学习的偏置项,激活函数为ReLU。
前馈网络在Transformer中的作用主要体现在以下几个方面:
特征提取:通过线性变换,前馈网络能够从自注意力模块的输出中提取更深层次的特征表示。
非线性建模:ReLU激活函数的引入使得模型能够捕捉输入数据中的非线性关系,增强模型的表达能力。
性能提升:前馈网络通过增加模型的深度和复杂度,有助于提升模型在各种自然语言处理任务上的性能。
适应性:前馈网络的两个可学习的线性变换使得模型能够根据不同的任务和数据特性进行适应性调整。
5. 损失函数
在Transformer模型中,损失函数的设计对于衡量模型预测的准确性和优化模型参数至关重要。交叉熵损失函数因其在处理分类问题中的有效性而被广泛采用。
交叉熵损失函数衡量的是模型预测的概率分布与真实标签分布之间的差异。在自然语言处理任务中,如语言模型或机器翻译,交叉熵损失函数用于计算模型输出的 logits 与真实标签之间的差异。损失函数定义如下:
其中,p(x)是真实标签的分布(通常是 one-hot 编码),q(x)是模型预测的概率分布。
为了评估模型对未来信息的预测能力,特别是在自回归任务中,我们使用上三角掩码(Mask)来屏蔽未来位置的信息。这种掩码确保模型在预测当前位置的输出时,只能依赖于当前位置和之前位置的信息,而不能利用未来位置的信息。
在实际的前向计算中,上三角掩码应用于注意力分数矩阵,将未来位置的分数设置为一个非常大的负数(例如,通过添加一个负无穷大的值),这样在应用 Softmax 函数时,这些位置的权重将接近于零,从而不会对模型的预测产生影响。
掩码的应用不仅局限于自回归任务,它还可以用于处理填充(Padding)问题,确保模型在处理不同长度的序列时不会将填充位置的无意义信息纳入考虑。此外,掩码还可以用于防止在序列生成任务中的信息泄露,例如在机器翻译中,掩码可以确保模型在生成当前词时不会看到未来的词。
6. Llama2对原始Transformer架构的改造
Llama2模型是由Meta AI开发的一系列大型语言模型,旨在通过大规模数据训练和复杂的模型结构提升自然语言处理任务的性能。
Llama2模型对原始Transformer架构进行了一些改造,我们根据书中的介绍总结如下:
旋转位置编码(RoPE): Llama2模型采用了旋转位置编码(RoPE)来替代传统的位置编码方法。RoPE通过将位置信息编码为旋转矩阵,使模型能够更有效地捕捉序列中元素之间的位置关系。这种方法利用了复数的旋转特性,通过预计算频率和复数的函数、重塑函数以及应用旋转嵌入的函数来实现。具体而言,RoPE通过将位置信息转换为旋转向量,并使用可学习的参数来调整旋转角度,提高了模型对位置信息的敏感性。
分组查询机制(Group Query Attention): Llama2模型在注意力机制上采用了分组查询机制。这种机制通过将输入序列分成若干组,并对每组进行独立的自注意力计算,提高了模型对序列中不同部分的关注度。同时,GQA技术还引入了查询(Query)的概念,通过将输入序列中的每个元素与查询进行匹配,使模型能够更好地理解输入序列中的重要信息。这种技术提高了Llama2模型在长序列处理任务中的性能和准确性。
预归一化(Pre-normalization): 与原始Transformer中的后置归一化不同,Llama2模型采用了前置层归一化策略,即在每个子层(自注意力层和前馈网络)的输入之前进行层归一化。这种策略有助于提高训练过程中的稳定性,尤其是在模型参数初始化阶段,可以降低梯度爆炸的风险。
SwiGLU激活函数: Llama2模型在前馈网络的激活函数上进行了创新,将原始的ReLU替换为SwiGLU。SwiGLU是基于Swish激活函数的GLU变体,它提供了更好的梯度流动和可能的性能提升。SwiGLU激活函数的公式为SwiGLU(x, W, V) = Swish_β(xW) ⊗ (xV),其中Swish_β(x) = x * sigmoid(βx),⊗为逐元素乘。SwiGLU的优势在于其处处可微的非线性特性,以及通过门机制控制信息通过的比例,来让模型自适应地选择哪些特征对预测下一个词有帮助。
这些改造使得Llama2模型在处理复杂的自然语言任务时,具有更高的效率和更好的性能。
- 2024-11-21
-
回复了主题帖:
共读入围:《大语言模型开发:用开源模型开发本地系统》
个人信息无误,确认可以完成评测计划
- 2024-10-13
-
发表了主题帖:
【Follow me第二期】任务汇总帖
本帖最后由 Aclicee 于 2024-11-16 14:26 编辑
大家好,我是Aclicee,很高兴能够参加EEworld和DigiKey的Follow me活动。今天向各位老师报告一下本期活动的任务汇总。在本期活动中,我们物料清单是:
Arduino UNO R4 WiFi、LTR329环境光传感器、SHT40温湿度传感器以及一条Qwiic缆线
项目简介:
本次活动主要围绕Arduino UNO R4 WiFi开发板进行学习和实践。项目涉及了开发环境的搭建、基础编程任务(如LED点亮、串口打印、驱动点阵LED、DAC生成正弦波、OPAMP放大DAC信号、ADC采样上传上位机显示),以及进阶任务(如Arduino Wifi连接并通过MQTT协议接入HomeAssistant平台)。此外,还包括了扩展任务,即多模态传感器(LTR-329环境光传感器和SHT40温湿度传感器)的使用并上传数据到HomeAssistant平台。
1. 入门任务:开发环境搭建
【【Follow me第二期】入门任务 - 开发环境配置+LED点亮+串口打印+LED点阵驱动 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
我们可以选择使用在线编辑器或者本地的IDE来编译和上传Arduino的代码。打开Arduino的官方网站就可以看到对应的使用方式。这里以使用在线的编辑器Arduino Cloud Editor为例,我们打开编辑器的网站,根据指示安装一个Arduino Cloud Agent插件,即可使用在线编辑器。与我们之前使用的一些单片机开发环境相比,Arduino的配置过程显得尤为简单直观。完成环境配置后,将Arduino开发板通过USB线连接至电脑,即可开始编写代码并进行烧录。
2. 入门任务:LED点亮以及串口打印
【【Follow me第二期】入门任务 - 开发环境配置+LED点亮+串口打印+LED点阵驱动 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
我们点击Create Sketch来编写第一个Arduino程序,可以看到我们的开发板可以被编辑器自动识别。我们打开官方提供的例程库,可以看到点亮LED的示例代码。不难发现,Arduino程序主要由两个部分组成:setup()和loop()。setup()函数负责初始化设置,如配置管脚模式和初始化外设,它仅在程序开始时执行一次。随后,程序将进入loop()函数,这是程序的主循环,所有持续执行的任务都在这里进行。
设计思路:在“Blink”例程中,我们首先将LED对应的管脚设置为输出模式。这样,我们就可以通过改变该管脚的电平状态(高电平或低电平),来控制LED的点亮与熄灭。在loop()函数中,我们通过编写代码,使LED以一定的频率闪烁。
将上述代码复制到我们的工程中,并烧录到Arduino开发板上。烧录过程中,开发板上的两盏LED指示灯会短暂点亮,以指示烧录过程正在进行。烧录完成后,只有我们指定的LED会按照预设的频率反复点亮和熄灭,实现闪烁效果。
串口打印功能的实现与LED点亮类似。在setup()函数中,我们需要初始化串口通信,并设置适当的波特率。在loop()函数中,我们编写代码,使开发板能够周期性地通过串口发送字符串信息。
3. 基础任务:驱动点阵LED
【【Follow me第二期】入门任务 - 开发环境配置+LED点亮+串口打印+LED点阵驱动 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
设计思路: 与单个LED的控制相似,点阵LED的控制原理也是通过设置对应位置的管脚电平来控制LED的点亮或熄灭。然而,由于涉及的LED数量较多,我们无法对每个LED进行单独配置,因此需要采用一种编码规则来实现有效的控制。我们可以参考官方提供的使用教程,有matrix.renderBitmap()和matrix.loadFrame()两种编码方式,分别使用单独每一个LED位置赋二进制数和将连续四个LED位置转换成16进制两种方式。
在浏览例程时,我们注意到“GameOfLife”例程提供了动态效果的实现。受此启发,我们制作了一个动态点亮笑脸图案的小动画。该动画的原理是随机选择矩阵中的行和列进行点亮或熄灭,然后在主体循环中不断刷新matrix.renderBitmap()以展示动态效果。
除了动态图案,我们还尝试了使用LED矩阵显示字符。这主要借助于ArduinoGraphics.h库。在“TextWithArduinoGraphics”例程中,我们看到可以使用matrix.beginDraw()和matrix.endDraw()构建显示框架,并设置了字体大小和内容。通过matrix.textScrollSpeed()函数,我们可以控制字符的滚动速度和方向。
综上所述,我们实现了LED矩阵的多种驱动方式和功能。我们将随机生成笑脸图案的代码放在setup()函数中,作为一次性的开屏动画。动画结束后,开发板将循环播放字符串滚动效果。
4. 基础任务:DAC生成正弦波
【【Follow me第二期】基础任务 - DAC+OPAMP+ADC以及上位机波形显示 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
设计思路:由于需要直观地展示波形,而我们目前缺乏便携式示波器,因此我们将采用本地Arduino IDE内置的串口绘图工具作为替代方案。我们的思路是使用analogWave库,并通过配置使其通过DAC输出波形。然后再调用ADC模块,将DAC输出的模拟信号再转回数字信号,并通过串口发送。通过配置串口的传输波特率并开启串口,发送的数据将基于ADC采集得到。
5. 基础任务:OPAMP放大DAC信号
【【Follow me第二期】基础任务 - DAC+OPAMP+ADC以及上位机波形显示 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
设计思路:我们参考官方文档中的Arduino内置运算放大器的管脚配置,输入信号连接至A1(运放的正输入端),A2作为负输入端,A3则作为输出端。我们的思路是设计如图的电压放大器,根据其增益计算公式,选择R1和R2两个均为10kΩ的电阻。电路连接方式如下:A0作为DAC的输出,与A1相连,作为运放的同相输入端;A2作为运放的反相输入端,通过R1和R2电阻分别连接至GND和A3。
在代码中,我们需要配置好OPAMP模块,调用OPAMP.h头文件,并设置运放的工作模式为高速模式。由于ADC读取的数值来源于运放的输出端口,我们将analogRead端口绑定为A3。我们设计的是一个二倍放大器,从ADC采样到的电压从原来的半量程(8位,约128)增加到了满量程(255左右),成功实现了输出电压的翻倍。
6. 基础任务:ADC采样上传上位机显示
【【Follow me第二期】基础任务 - DAC+OPAMP+ADC以及上位机波形显示 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
这一过程在前两个任务中已有展示,在此将对之前的DAC输出信号和经OPAMP放大后的信号进行汇总显示,从而直观地展示放大器的效果。在图中,蓝线代表DAC的输出信号,而红线代表经OPAMP放大后的信号,直观地展示了二倍电压放大的效果。
7. 进阶任务:ArduinoWifi连接并通过MQTT协议接入到HomeAssistant平台
【【Follow me第二期】进阶任务 - WiFi+MQTT协议连接智能家居HA平台 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
我们首先需要安装并配置Home Assistant(HA),考虑到便利性和易用性,我们选择使用Docker拉取HA容器。拉取完成后,我们需要配置Docker容器在本地的存储位置,并运行容器。运行上述命令后,我们可以在Docker Engine中看到名为homeassistant的容器。通过访问http://localhost:8123,可以进入Home Assistant的Web界面。
然后我们需要配置EMQX平台,EMQX平台是一个高性能、可扩展的MQTT消息服务器。类似的,我们还是采用docker拉取容器的方式。拉取完成后,我们可以使用以下命令来运行EMQX容器。此时可以看到docker的容器列表中新增了emqx容器。使用https://localhost:18083端口可以打开EMQX平台,输入默认的账号密码即可登录EMQX Dashboard。
接下来,我们需要在EMQX Dashboard中创建客户端认证,并添加用户作为HA平台的接入。我们需要在HA平台中配置MQTT服务,填写相关的配置信息,包括EMQX平台的IP、用户名和密码等。这样,HA平台就可以通过MQTT协议与EMQX平台进行通信了。EMQX平台的在线连接数来从0变为1,说明HA平台已经成功接入。
设计思路:在此基础上,我们使用Arduino的Wifi模块,使其接入HA平台。我们的思路是使用WiFiS3库来配置WiFi模块,使用WiFi.begin启用WiFi并连接手机热点。然后,我们需要安装MQTT和HomeAssistant相关的库文件,我们引入ArduinoHA.h头文件,以便使用HomeAssistant平台。我们使用HADevice和HAMqtt来创建HA设备启用MQTT协议。我们实例化了几个HA订阅的传感器,比如按钮、模拟传感器和更新时间传感器等,这些可以根据后面任务的需求进行调整。
完成后需要在程序的主循环中通过mqtt.loop()使其使用MQTT协议进行信息的发送。后面我们就仿照例程中发送传感器数据的方式,添加了两个每间隔1000ms进行一次更新采集数据和时间的代码。全部代码完成以后,将其烧录,同时启动docker中的HA容器和EMQX容器,我们可以看到Arduino首先成功连接了wifi,并输出wifi的相关信息,然后尝试建立MQTT的连接并且成功连接,Arduino通过串口发送了采集到的传感器电压,以及更新时间。
8. 扩展任务:多模态传感器的使用并上传数据到HA平台
【【Follow me第二期】扩展任务 - LTR329+SHT40传感器的使用并上传数据到HA平台 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】
任务使用的外部传感器为LTR-329环境光传感器和SHT40温湿度传感器,关键在于掌握这两种传感器的配置和数据传输的方法。为了简化连接过程,我们采用了Qwiic缆线连接方式,将传感器与Arduino开发板相连。
有关两个传感器的配置方法已经被分别封装在“Adafruit LTR329 and LTR303”和“Adafruit SHT4X”库中,我们都暂时保持默认即可。由于我们只采购了一条Qwiic缆线,这里以温湿度传感器为例,介绍如何实现数据采集和上传到HA平台。
设计思路:我们的思路是在此前实现的Arduino连接HA平台的基础上,将新添加的传感器实例化为模拟信号传感器,然后初始化SHT4X模块,我们把全部配置参数都设成默认即可。最后在主循环中,我们直接使用例程中的温湿度数据获取的函数,然后通过Sensor.setValue()函数显示到HA界面上。
完成代码编写后,我们启动Docker Engine,打开HA和EMQX平台的容器,并将代码烧录到Arduino。测试结果显示,我们可以看到网络连接和MQTT连接的信息,并确认了SHT40传感器的连接。当我们打开HA平台时,可以看到设备增加了新的传感器,并且更新了数据单位信息,数据也会随时间不断更新。
以上就是本次Follow me活动的所有任务实现过程的分享,感谢大家。
项目心得:
Arduino的开发环境配置过程简单直观,无论是使用在线编辑器还是本地IDE,都能快速上手。整个项目过程中,通过实际操作和问题解决,提升了编程能力、硬件操作技能以及对物联网平台的熟悉度。通过不断学习和实践,对Arduino平台和物联网技术有了更深入的理解,为未来在相关领域的进一步探索打下了坚实的基础。
演示视频:
【【Follow me第二季第二期】任务汇总(by Aclicee)-【Follow me第二季第二期】任务汇总(by Aclicee)-EEWORLD大学堂】
程序源码:
代码链接【Follow me第二季第二期Code by Aclicee-嵌入式开发相关资料下载-EEWORLD下载中心】,也随附在本帖下:
-
上传了资料:
Follow me第二季第二期Code by Aclicee
-
加入了学习《【Follow me第二季第二期】任务汇总(by Aclicee)》,观看 【Follow me第二季第二期】任务汇总(by Aclicee)
- 2024-10-08
-
发表了主题帖:
【Follow me第二期】扩展任务 - LTR329+SHT40传感器的使用并上传数据到HA平台
扩展任务为连接外部的环境光或者温湿度传感器,采集环境数据,然后通过Arduino上传到智能家居HA平台。在完成先前任务的基础上,本任务的难度相对较低。关键在于掌握如何配置环境光传感器LTR-329和温湿度传感器SHT40,并访问它们采集的数据。一旦这些传感器被正确配置,我们便可以利用在进阶任务中学到的代码,将模拟传感器的数据采集和上传流程应用于实际的传感器,从而实现数据的实时上传。
1. 环境光传感器LTR329的使用
为了实现环境光数据的采集,我们选用了LTR-329环境光传感器。关于该传感器的具体参数和管脚配置,可以参考官方提供的规格书【5591.pdf (digikey.com)】。为了简化连接过程,我们采用了Qwiic缆线连接方式,将传感器与Arduino开发板相连。
Arduino平台已经提供了LTR-329传感器的库支持。通过Arduino IDE左侧的Library manager搜索“Adafruit LTR329 and LTR303”,即可找到并安装对应的库。
安装库之后,我们可以在例程中查看库函数的具体使用方法。首先,我们需要在代码中包含Adafruit_LTR329_LTR303.h头文件,并实例化传感器对象ltr。后面是一些配置的代码,比如确认连接、设置传感器增益、采样时间间隔等等,我们都暂时保持默认即可。
在主循环中,我们通过调用ltr.readBothChannels(visible_plus_ir, infrared)来获取传感器采集的参数。LTR-329传感器提供两个通道的数据:可见光与红外光的叠加值,以及单独的红外光值。
由于我们使用了Qwiic缆线连接,因此在代码开头需要包含Arduino Wire.h头文件,并在初始化时确认缆线的连接可用。完整的代码如下:
#include "Adafruit_LTR329_LTR303.h"
#include <Wire.h>
Adafruit_LTR329 ltr = Adafruit_LTR329();
void setup() {
Serial.begin(9600);
Serial.println("Adafruit LTR-329 advanced test");
if ( ! ltr.begin(&Wire1) ) {
Serial.println("Couldn't find LTR sensor!");
while (1) delay(10);
}
Serial.println("Found LTR sensor!");
ltr.setGain(LTR3XX_GAIN_2);
Serial.print("Gain : ");
switch (ltr.getGain()) {
case LTR3XX_GAIN_1: Serial.println(1); break;
case LTR3XX_GAIN_2: Serial.println(2); break;
case LTR3XX_GAIN_4: Serial.println(4); break;
case LTR3XX_GAIN_8: Serial.println(8); break;
case LTR3XX_GAIN_48: Serial.println(48); break;
case LTR3XX_GAIN_96: Serial.println(96); break;
}
ltr.setIntegrationTime(LTR3XX_INTEGTIME_100);
Serial.print("Integration Time (ms): ");
switch (ltr.getIntegrationTime()) {
case LTR3XX_INTEGTIME_50: Serial.println(50); break;
case LTR3XX_INTEGTIME_100: Serial.println(100); break;
case LTR3XX_INTEGTIME_150: Serial.println(150); break;
case LTR3XX_INTEGTIME_200: Serial.println(200); break;
case LTR3XX_INTEGTIME_250: Serial.println(250); break;
case LTR3XX_INTEGTIME_300: Serial.println(300); break;
case LTR3XX_INTEGTIME_350: Serial.println(350); break;
case LTR3XX_INTEGTIME_400: Serial.println(400); break;
}
ltr.setMeasurementRate(LTR3XX_MEASRATE_200);
Serial.print("Measurement Rate (ms): ");
switch (ltr.getMeasurementRate()) {
case LTR3XX_MEASRATE_50: Serial.println(50); break;
case LTR3XX_MEASRATE_100: Serial.println(100); break;
case LTR3XX_MEASRATE_200: Serial.println(200); break;
case LTR3XX_MEASRATE_500: Serial.println(500); break;
case LTR3XX_MEASRATE_1000: Serial.println(1000); break;
case LTR3XX_MEASRATE_2000: Serial.println(2000); break;
}
}
void loop() {
bool valid;
uint16_t visible_plus_ir, infrared;
if (ltr.newDataAvailable()) {
valid = ltr.readBothChannels(visible_plus_ir, infrared);
if (valid) {
Serial.print("CH0 Visible + IR: ");
Serial.print(visible_plus_ir);
Serial.print("\t\tCH1 Infrared: ");
Serial.println(infrared);
}
}
delay(100);
}
运行上述代码后,我们可以通过串口监视器观察到光照强度信息的输出。随着环境光线的变化,例如手动遮挡光源,示数也会相应变化。
动态的数据接收过程请参考随附的视频。
2. 温湿度传感器SHT40的使用
有了环境光传感器的经验,温湿度传感器只需要如法炮制即可。与环境光传感器类似,我们可以通过Arduino的库来简化集成过程。首先,通过Arduino IDE的Library manager搜索并安装“Adafruit SHT4X”库。这个库为SHT40传感器提供了封装好的接口。
打开其中的例程,其实现的过程和环境光传感器是一样的。我们可以在代码中包含Adafruit_SHT4X.h头文件,并实例化一个SHT40传感器对象sht4。
主循环部分,主要是通过sht4.getEvent(&humidity, &temp)来获取温湿度的读数,然后通过串口分别发送。同样,我们因为通过Qwiic缆线来连接开发板和传感器,包含对应的头文件并检查确认其连接即可,具体的代码如下:
#include "Adafruit_SHT4x.h"
#include <Wire.h>
Adafruit_SHT4x sht4 = Adafruit_SHT4x();
void setup() {
Serial.begin(9600);
while (!Serial)
delay(10); // will pause Zero, Leonardo, etc until serial console opens
Serial.println("Adafruit SHT4x test");
if (! sht4.begin(&Wire1)) {
Serial.println("Couldn't find SHT4x");
while (1) delay(1);
}
Serial.println("Found SHT4x sensor");
Serial.print("Serial number 0x");
Serial.println(sht4.readSerial(), HEX);
// You can have 3 different precisions, higher precision takes longer
sht4.setPrecision(SHT4X_HIGH_PRECISION);
switch (sht4.getPrecision()) {
case SHT4X_HIGH_PRECISION:
Serial.println("High precision");
break;
case SHT4X_MED_PRECISION:
Serial.println("Med precision");
break;
case SHT4X_LOW_PRECISION:
Serial.println("Low precision");
break;
}
// You can have 6 different heater settings
// higher heat and longer times uses more power
// and reads will take longer too!
sht4.setHeater(SHT4X_NO_HEATER);
switch (sht4.getHeater()) {
case SHT4X_NO_HEATER:
Serial.println("No heater");
break;
case SHT4X_HIGH_HEATER_1S:
Serial.println("High heat for 1 second");
break;
case SHT4X_HIGH_HEATER_100MS:
Serial.println("High heat for 0.1 second");
break;
case SHT4X_MED_HEATER_1S:
Serial.println("Medium heat for 1 second");
break;
case SHT4X_MED_HEATER_100MS:
Serial.println("Medium heat for 0.1 second");
break;
case SHT4X_LOW_HEATER_1S:
Serial.println("Low heat for 1 second");
break;
case SHT4X_LOW_HEATER_100MS:
Serial.println("Low heat for 0.1 second");
break;
}
}
void loop() {
sensors_event_t humidity, temp;
uint32_t timestamp = millis();
sht4.getEvent(&humidity, &temp);// populate temp and humidity objects with fresh data
timestamp = millis() - timestamp;
Serial.print("Temperature: "); Serial.print(temp.temperature); Serial.println(" degrees C");
Serial.print("Humidity: "); Serial.print(humidity.relative_humidity); Serial.println("% rH");
Serial.print("Read duration (ms): ");
Serial.println(timestamp);
delay(1000);
}
运行上述代码后,我们可以通过串口监视器观察到温湿度信息的输出。随着环境温湿度的变化,示数也会相应变化。温度传感器的结果将和扩展任务一同展示。动态的数据接收过程请参考随附的视频。
3. 扩展任务(必做):通过外部传感器上传信息到HA,并通过HA面板显示数据
在成功集成环境光传感器LTR-329和温湿度传感器SHT40之后,我们的目标是将这些传感器的数据上传至Home Assistant (HA) 平台,并在HA界面上显示这些数据。由于我们只有一条Qwiic缆线,我们将以温湿度传感器SHT40为例进行演示。
我们删除了上个任务中打印wifi信息的相关函数,因为我们暂时不需要显示这些,来精简一下代码。我们包含了所有必要的头文件,并开始配置HA平台。我们将新添加的传感器实例化为模拟信号传感器,具体代码如下:
HASensorNumber analogSensor("AnalogInput", HASensorNumber::PrecisionP1);
HASensorNumber uptimeSensor("Uptime");
HASensorNumber TempSensor("Temperature", HASensorNumber::PrecisionP2);
HASensorNumber HumidSensor("Humid", HASensorNumber::PrecisionP3);
//HASensorNumber LightSensor("Light", HASensorNumber::PrecisionP2);
HAButton buttonA("myButtonA");
HAButton buttonB("myButtonB");
然后我们在初始化完wifi模块和MQTT的连接以后,具体方法参考上一则报告【【Follow me第二期】进阶任务 - WiFi+MQTT协议连接智能家居HA平台 - DigiKey得捷技术专区 - 电子工程世界-论坛 (eeworld.com.cn)】,给HA界面上各传感器的显示增加更多细节,包括进行命名的区分以及增加具体的单位,代码如下:
analogSensor.setName("Analog voltage");
analogSensor.setUnitOfMeasurement("V");
uptimeSensor.setName("Update Time");
uptimeSensor.setUnitOfMeasurement("s");
TempSensor.setName("Temperature");
TempSensor.setUnitOfMeasurement("degrees C");
HumidSensor.setName("Humidity");
HumidSensor.setUnitOfMeasurement("% rH");
//LightSensor.setUnitOfMeasurement("lkx");
最后初始化SHT4X模块,我们把全部配置参数都设成默认,然后合并成一个SHT40_init()函数。
在主循环中,我们直接使用例程中的温湿度数据获取的函数,然后通过Sensor.setValue()函数显示到HA界面上。具体代码如下:
sensors_event_t humidity, temp;
sht4.getEvent(&humidity, &temp);
Serial.print("Temperature: "); Serial.print(temp.temperature); Serial.println("degrees C");
Serial.print("Humidity: "); Serial.print(humidity.relative_humidity); Serial.println("% rH");
TempSensor.setValue(temp.temperature);
HumidSensor.setValue(humidity.relative_humidity);
最后完整的代码如下:
#include <WiFiS3.h>
#include <ArduinoHA.h>
#include <Wire.h>
#include "arduino_secrets.h"
#include "analogWave.h"
#include "Adafruit_SHT4x.h"
#include "Adafruit_LTR329_LTR303.h"
///////please enter your sensitive data in the Secret tab/arduino_secrets.h
char ssid[] = SECRET_SSID; // your network SSID (name)
char pass[] = SECRET_PASS; // your network password (use for WPA, or use as key for WEP)
int status = WL_IDLE_STATUS; // the WiFi radio's status
unsigned long lastUpdateAt = 0;
int freq = 1;
WiFiClient client;
HADevice device(MQTT_CLIENT_ID);
HAMqtt mqtt(client, device);
HASensorNumber analogSensor("AnalogInput", HASensorNumber::PrecisionP1);
HASensorNumber uptimeSensor("Uptime");
HASensorNumber TempSensor("Temperature", HASensorNumber::PrecisionP2);
HASensorNumber HumidSensor("Humid", HASensorNumber::PrecisionP3);
//HASensorNumber LightSensor("Light", HASensorNumber::PrecisionP2);
HAButton buttonA("myButtonA");
HAButton buttonB("myButtonB");
analogWave wave(DAC);
Adafruit_SHT4x sht4 = Adafruit_SHT4x();
Adafruit_LTR329 ltr = Adafruit_LTR329();
void setup() {
//Initialize serial and wait for port to open:
Serial.begin(9600);
while (!Serial) {
; // wait for serial port to connect. Needed for native USB port only
}
// check for the WiFi module:
if (WiFi.status() == WL_NO_MODULE) {
Serial.println("Communication with WiFi module failed!");
// don't continue
while (true);
}
String fv = WiFi.firmwareVersion();
if (fv < WIFI_FIRMWARE_LATEST_VERSION) {
Serial.println("Please upgrade the firmware");
}
// attempt to connect to WiFi network:
while (status != WL_CONNECTED) {
Serial.print("Attempting to connect to WPA SSID: ");
Serial.println(ssid);
// Connect to WPA/WPA2 network:
status = WiFi.begin(ssid, pass);
// wait 10 seconds for connection:
delay(10000);
}
// you're connected now, so print out the data:
Serial.print("You're connected to the network");
Serial.println("\nStart connecting to MQTT server");
if (!mqtt.begin(MQTT_SERVER, MQTT_PORT, MQTT_USERNAME, MQTT_PASSWORD)){
Serial.print("Connection falied");
Serial.print(mqtt.getState());
Serial.println("Try again in 5 seconds");
delay(5000);
}
wave.sine(freq);
wave.amplitude(0.5);
analogReadResolution(14);
device.setName("Arduino");
device.setSoftwareVersion("1.0.0");
buttonA.setIcon("mdi:fire");
buttonA.setName("Click me A");
buttonB.setIcon("mdi:home");
buttonB.setName("Click me B");
analogSensor.setName("Analog voltage");
analogSensor.setUnitOfMeasurement("V");
uptimeSensor.setName("Update Time");
uptimeSensor.setUnitOfMeasurement("s");
TempSensor.setName("Temperature");
TempSensor.setUnitOfMeasurement("degrees C");
HumidSensor.setName("Humidity");
HumidSensor.setUnitOfMeasurement("% rH");
//LightSensor.setUnitOfMeasurement("lkx");
SHT40_init();
//LTR329_init();
}
void loop() {
// check the network connection once every 10 seconds:
mqtt.loop();
if ((millis() - lastUpdateAt) > 1000) { // 1000ms debounce time
uint16_t reading = analogRead(A0);
float voltage = reading * 5.f / 16383.f; // 0.0V - 5.0V
Serial.print("Volt:");
Serial.println(voltage);
analogSensor.setValue(voltage);
unsigned long uptimeValue = millis() / 1000;
Serial.print("Uptime:");
Serial.println(uptimeValue);
uptimeSensor.setValue(uptimeValue);
sensors_event_t humidity, temp;
sht4.getEvent(&humidity, &temp);
Serial.print("Temperature: "); Serial.print(temp.temperature); Serial.println("degrees C");
Serial.print("Humidity: "); Serial.print(humidity.relative_humidity); Serial.println("% rH");
TempSensor.setValue(temp.temperature);
HumidSensor.setValue(humidity.relative_humidity);
// bool valid;
// uint16_t visible_plus_ir, infrared;
// if (ltr.newDataAvailable()) {
// valid = ltr.readBothChannels(visible_plus_ir, infrared);
// if (valid) {
// Serial.print("CH0 Visible + IR: ");
// Serial.print(visible_plus_ir);
// Serial.print("\t\tCH1 Infrared: ");
// Serial.println(infrared);
// }
// LightSensor.setValue(visible_plus_ir);
// }
lastUpdateAt = millis();
}
}
void SHT40_init() {
Serial.println("Adafruit SHT4x test");
if (! sht4.begin(&Wire1)) {
Serial.println("Couldn't find SHT4x");
while (1) delay(1);
}
Serial.println("Found SHT4x sensor");
Serial.print("Serial number 0x");
Serial.println(sht4.readSerial(), HEX);
sht4.setPrecision(SHT4X_HIGH_PRECISION);
switch (sht4.getPrecision()) {
case SHT4X_HIGH_PRECISION:
Serial.println("High precision");
break;
case SHT4X_MED_PRECISION:
Serial.println("Med precision");
break;
case SHT4X_LOW_PRECISION:
Serial.println("Low precision");
break;
}
sht4.setHeater(SHT4X_NO_HEATER);
switch (sht4.getHeater()) {
case SHT4X_NO_HEATER:
Serial.println("No heater");
break;
case SHT4X_HIGH_HEATER_1S:
Serial.println("High heat for 1 second");
break;
case SHT4X_HIGH_HEATER_100MS:
Serial.println("High heat for 0.1 second");
break;
case SHT4X_MED_HEATER_1S:
Serial.println("Medium heat for 1 second");
break;
case SHT4X_MED_HEATER_100MS:
Serial.println("Medium heat for 0.1 second");
break;
case SHT4X_LOW_HEATER_1S:
Serial.println("Low heat for 1 second");
break;
case SHT4X_LOW_HEATER_100MS:
Serial.println("Low heat for 0.1 second");
break;
}
}
void LTR329_init() {
Serial.println("Adafruit LTR-329 advanced test");
if ( ! ltr.begin(&Wire1) ) {
Serial.println("Couldn't find LTR sensor!");
while (1) delay(10);
}
Serial.println("Found LTR sensor!");
ltr.setGain(LTR3XX_GAIN_2);
Serial.print("Gain : ");
switch (ltr.getGain()) {
case LTR3XX_GAIN_1: Serial.println(1); break;
case LTR3XX_GAIN_2: Serial.println(2); break;
case LTR3XX_GAIN_4: Serial.println(4); break;
case LTR3XX_GAIN_8: Serial.println(8); break;
case LTR3XX_GAIN_48: Serial.println(48); break;
case LTR3XX_GAIN_96: Serial.println(96); break;
}
ltr.setIntegrationTime(LTR3XX_INTEGTIME_100);
Serial.print("Integration Time (ms): ");
switch (ltr.getIntegrationTime()) {
case LTR3XX_INTEGTIME_50: Serial.println(50); break;
case LTR3XX_INTEGTIME_100: Serial.println(100); break;
case LTR3XX_INTEGTIME_150: Serial.println(150); break;
case LTR3XX_INTEGTIME_200: Serial.println(200); break;
case LTR3XX_INTEGTIME_250: Serial.println(250); break;
case LTR3XX_INTEGTIME_300: Serial.println(300); break;
case LTR3XX_INTEGTIME_350: Serial.println(350); break;
case LTR3XX_INTEGTIME_400: Serial.println(400); break;
}
ltr.setMeasurementRate(LTR3XX_MEASRATE_200);
Serial.print("Measurement Rate (ms): ");
switch (ltr.getMeasurementRate()) {
case LTR3XX_MEASRATE_50: Serial.println(50); break;
case LTR3XX_MEASRATE_100: Serial.println(100); break;
case LTR3XX_MEASRATE_200: Serial.println(200); break;
case LTR3XX_MEASRATE_500: Serial.println(500); break;
case LTR3XX_MEASRATE_1000: Serial.println(1000); break;
case LTR3XX_MEASRATE_2000: Serial.println(2000); break;
}
}
完成代码编写后,我们启动Docker Engine,打开HA和EMQX平台的容器,并将代码烧录到Arduino。测试结果显示,我们可以看到网络连接和MQTT连接的信息,并确认了SHT40传感器的连接。传感器开始通过串口向上位机发送温湿度信息(截图时我已经把连接线断开了,所以显示了未连接)。
当我们打开HA平台时,可以看到设备增加了新的传感器,并且更新了数据单位信息,数据也会随时间不断更新。动态的展示请参考随附的视频。
至此我们完成了本次Follow me活动的全部任务,汇总帖和视频正在全力整合中。
[localvideo]ca5798636437d85f40691a7ef6521192[/localvideo]
- 2024-10-07
-
发表了主题帖:
【Follow me第二期】进阶任务 - WiFi+MQTT协议连接智能家居HA平台
本帖最后由 Aclicee 于 2024-10-7 22:41 编辑
进阶任务主要使用的是Arduino UNO R4 WiFi板子的ESP32S3 WiFi模块。尽管我对这一领域的知识尚浅,但通过与Arduino、上位机软件、HomeAssistant(HA)这一开源智能家居平台,以及MQTT协议和EMQX消息服务器的深入交互,我得以构建了一个多平台联动的智能系统。这一过程充满了挑战,但也充满了成就感。在本文中,我将分享我的探索之旅,包括遇到的困难和解决方案,希望能为同样走在这条道路上的你提供一些启示和帮助。
1. HA平台的安装和配置
为了实现我们智能家居项目的核心功能,我们首先需要安装并配置Home Assistant(HA),这是一个流行的开源智能家居平台。详细的安装和使用教程可以在【Home Assistant (home-assistant.io)】找到。按照“Get Started”的指引,我们可以进入官方教程,了解HA支持的多种安装方式。
考虑到便利性和易用性,我们选择使用Docker进行安装。Docker可以从【Docker: Accelerated Container Application Development】下载。安装并注册完成后,确保Docker Engine在Docker Desktop中启动。如果你需要科学上网来访问Docker服务,请确保相应的配置已经完成。
接下来,我们将在Docker中拉取HA的镜像。根据【Linux - Home Assistant (home-assistant.io)】中的“Installation with Docker”部分,我们首先在命令行界面(CMD)中执行docker search home-assistant命令,以查看可用的容器目录。以下是具体的命令输出:
我们选择列表中的第一个镜像,并通过执行docker pull homeassistant/home-assistant命令来拉取它。请注意,拉取过程可能需要一些时间,并且有时会因为网络问题而中断。如果遇到失败,只需重新尝试即可。
拉取完成后,我们需要配置Docker容器在本地的存储位置,并运行容器。与官方教程稍有不同,由于我们已经提前拉取了容器,因此不需要再次从镜像源下载。我们可以直接使用以下命令来运行容器:
docker run -d --name homeassistant -v /path/to/local/config:/config -p 8123:8123 homeassistant/home-assistant
请将/path/to/local/config替换为你本地存放容器配置文件的实际路径,并将端口号8123替换为你希望使用的端口。
运行上述命令后,你可以在Docker Engine中看到名为homeassistant的容器。通过访问http://localhost:8123,你可以进入Home Assistant的Web界面。
在首次访问时,系统会提示你创建你的智能家居。按照提示完成一系列信息注册后,重新登录Home Assistant,你就可以进入主界面了。
完成后重新登录HomeAssistant,主页显示的相关信息如下:
2. EMQX平台的安装和配置
MQTT协议是一种轻量级的、基于发布/订阅模式的消息传输协议,广泛用于物联网和分布式系统中。它具有简单易实现、支持多种服务质量(QoS)、报文精简、基于TCP/IP等特性,特别适合于带宽有限和网络不稳定的环境。EMQX平台是一个高性能、可扩展的MQTT消息服务器,支持大规模分布式物联网设备的连接,能够实时处理和移动大量消息和事件流数据。
EMQX的具体信息可以在其官方网站【EMQX 文档】查看,里面也提供了使用MQTT协议的相关教程。我们还是采用docker拉取容器的方式,参考【通过 Docker 运行 EMQX | EMQX文档】的文档说明。我们在指令栏中输入docker search emqx来查找可供拉取的容器列表。
没有看到说明文档中提到的带版本号的emqx容器,就直接拉取第一个根目录,使用docker pull emqx来进行拉取。
拉取完成后,我们可以使用以下命令来运行EMQX容器:
docker run -d --name emqx -p 1883:1883 -p 8083:8083 -p 8084:8084 -p 8883:8883 -p 18083:18083 -v /path/to/local/config:/opt/emqx/data emqx/emqx:5.8.0
请将/path/to/local/config替换为您本地存放容器配置文件的实际路径。此时可以看到docker的容器列表中新增了emqx容器。
使用https://localhost:18083端口可以打开EMQX平台,参考以下文档进行初步的测试【快速开始 | EMQX文档】,平台默认的账号是admin,密码是public,即可登录EMQX Dashboard,界面如下:
接下来,我们需要在EMQX Dashboard中创建客户端认证,并添加用户作为HA平台的接入。
完成后,我们需要在HA平台中配置MQTT服务,填写相关的配置信息,包括EMQX平台的IP、用户名和密码等。EMQX节点的名称可以在集群概览中查看,比如这边是emqx@172.17.0.2。
通过HA平台的设置菜单,配置新的设备与服务,选择MQTT。
配置完成后,我们可以通过HA平台的设置菜单,配置新的设备与服务,选择MQTT,并填写相关的配置信息。这样,HA平台就可以通过MQTT协议与EMQX平台进行通信了。
最后,我们可以通过检查EMQX平台的在线连接数来验证HA平台是否已经成功接入EMQX集群。如果连接数从0变为1,说明HA平台已经成功接入。接下来,我们可以通过MQTT协议将Arduino接入EMQX平台,实现Arduino和HA之间的通信和数据传输。
3. 进阶任务(必做):ArduinoWifi连接并通过MQTT协议接入到HomeAssistant平台
我们将探讨如何将Arduino UNO R4 WiFi开发板通过WiFi连接,并利用MQTT协议接入HomeAssistant平台。这一过程不仅涉及到硬件的配置,还包括了软件的集成,是对我们技能的一次全面考验。
首先,我们需要启动Arduino的WiFi模块。可以参考官方文档【docs.arduino.cc/tutorials/uno-r4-wifi/wifi-examples】,里面提供了Arduino连接Wifi的示例。我们创建一个arduino_secret.h的保密文件,用于存储连接WiFi路由器的敏感信息,如网络名称(SSID)和密码等。
由于我在学校连接校园网需要特定的账号和密码,这一部分在Arduino中实现较为复杂,因此我们选择了一个更为简便的方法:使用手机作为热点,让Arduino通过热点进行连接。
我们直接采用官方例程中的代码,配置WiFi模块。主要包括以下几个部分:引入必要的头文件WiFiS3.h;使用WiFi.begin(ssid, pass)启用WiFi并连接;确认WiFi连接状态并打印相关信息,如网络名称、IP地址、MAC地址等。
#include <WiFiS3.h>
#include "arduino_secrets.h"
///////please enter your sensitive data in the Secret tab/arduino_secrets.h
char ssid[] = SECRET_SSID; // your network SSID (name)
char pass[] = SECRET_PASS; // your network password (use for WPA, or use as key for WEP)
int status = WL_IDLE_STATUS; // the WiFi radio's status
void setup() {
//Initialize serial and wait for port to open:
Serial.begin(9600);
while (!Serial) {
; // wait for serial port to connect. Needed for native USB port only
}
// check for the WiFi module:
if (WiFi.status() == WL_NO_MODULE) {
Serial.println("Communication with WiFi module failed!");
// don't continue
while (true);
}
String fv = WiFi.firmwareVersion();
if (fv < WIFI_FIRMWARE_LATEST_VERSION) {
Serial.println("Please upgrade the firmware");
}
// attempt to connect to WiFi network:
while (status != WL_CONNECTED) {
Serial.print("Attempting to connect to WPA SSID: ");
Serial.println(ssid);
// Connect to WPA/WPA2 network:
status = WiFi.begin(ssid, pass);
// wait 10 seconds for connection:
delay(10000);
}
// you're connected now, so print out the data:
Serial.print("You're connected to the network");
printCurrentNet();
printWifiData();
}
void loop() {
// check the network connection once every 10 seconds:
delay(10000);
printCurrentNet();
}
void printWifiData() {
// print your board's IP address:
IPAddress ip = WiFi.localIP();
Serial.print("IP Address: ");
Serial.println(ip);
// print your MAC address:
byte mac[6];
WiFi.macAddress(mac);
Serial.print("MAC address: ");
printMacAddress(mac);
}
void printCurrentNet() {
// print the SSID of the network you're attached to:
Serial.print("SSID: ");
Serial.println(WiFi.SSID());
// print the MAC address of the router you're attached to:
byte bssid[6];
WiFi.BSSID(bssid);
Serial.print("BSSID: ");
printMacAddress(bssid);
// print the received signal strength:
long rssi = WiFi.RSSI();
Serial.print("signal strength (RSSI):");
Serial.println(rssi);
// print the encryption type:
byte encryption = WiFi.encryptionType();
Serial.print("Encryption Type:");
Serial.println(encryption, HEX);
Serial.println();
}
void printMacAddress(byte mac[]) {
for (int i = 0; i < 6; i++) {
if (i > 0) {
Serial.print(":");
}
if (mac[i] < 16) {
Serial.print("0");
}
Serial.print(mac[i], HEX);
}
Serial.println();
}
将代码烧录到Arduino中,我们可以看到WiFi模块成功连接到了手机热点,并且打印出了网络信息。同时手机上也显示ESP32S3模块的连接。至此完成了wifi模块的使用。
接下来,我们需要在Arduino IDE中安装MQTT和HomeAssistant相关的库文件。可以在Library管理器中搜索并安装home-assistant-integration库。
安装好所有所需的库文件和插件,我们可以查看一下其中涉及MQTT相关操作的例程,因为后面可选任务涉及光照和温湿度传感器的使用,所以我们以模拟传感器的例程为参考【arduino-home-assistant/examples/sensor-analog/sensor-analog.ino at main · dawidchyrzynski/arduino-home-assistant · GitHub】,我们使用Arduino内置的DAC来暂时模拟一下传感器的输出。
由于我们已经配置好了EMQX平台,接下来需要在arduino_secret.h文件中增加EMQX平台的用户名、密码等信息,具体如下:
//arduino_secrets.h header file
#define SECRET_SSID "Galaxy S24+ 6927"
#define SECRET_PASS "LYQ1234567890"
#define MQTT_SERVER "192.168.40.250"
#define MQTT_PORT 1883
#define MQTT_CLIENT_ID "arduino"
#define MQTT_USERNAME "admin"
#define MQTT_PASSWORD "admin"
#define TOPIC_SUBSCRIBE "UNO/arduino/sensor"
因为EMQX平台是基于电脑的IP搭建的,其中MQTT_SERVER后的地址就是当前电脑无线网络的IP地址,可以在指令栏中通过ipconfig/all来查询,如下:
将相关信息全部录入arduino_secret.h头文件,然后仿照上面的例程。因为例程中使用的Ethernet以太网,所以我们只需要学习其中涉及MQTT相关的部分,将其合并到上面的Wifi连接的程序中。
现在,我们将WiFi连接代码与MQTT代码整合在一起。首先,我们引入ArduinoHA.h头文件,以便使用HomeAssistant平台。
我们使用HADevice device(MQTT_CLIENT_ID)和HAMqtt mqtt(client, device)来创建HA设备启用MQTT协议。然后我们实例化了几个HA订阅的传感器,比如按钮、模拟传感器和更新时间传感器等,这些可以根据后面任务的需求进行调整。也可以通过device.XX或者某个具体定义的传感器的字段进行修改,来自定义传感器在HA平台的UI上呈现的形式(比如图标、是否附加单位等)。
以上的初始设置完成以后,我们在Wifi代码的基础上增加以下代码:
Serial.println("\nStart connecting to MQTT server");
if (!mqtt.begin(MQTT_SERVER, MQTT_PORT, MQTT_USERNAME, MQTT_PASSWORD)){
Serial.print("Connection falied");
Serial.print(mqtt.getState());
Serial.println("Try again in 5 seconds");
delay(5000);
}
核心为mqtt. begin()用于启动MQTT的连接,需要输入服务器的地址、端口以及我们预设好的用户名和密码。完成后需要在程序的主循环中通过mqtt.loop()使其使用MQTT协议进行信息的发送。后面我们就仿照例程中发送传感器数据的方式,添加了两个每间隔1000ms进行一次更新采集数据和时间的代码,完整代码如下:
#include <WiFiS3.h>
#include <ArduinoHA.h>
#include "arduino_secrets.h"
#include "analogWave.h"
///////please enter your sensitive data in the Secret tab/arduino_secrets.h
char ssid[] = SECRET_SSID; // your network SSID (name)
char pass[] = SECRET_PASS; // your network password (use for WPA, or use as key for WEP)
int status = WL_IDLE_STATUS; // the WiFi radio's status
unsigned long lastUpdateAt = 0;
int freq = 1;
WiFiClient client;
HADevice device(MQTT_CLIENT_ID);
HAMqtt mqtt(client, device);
HASensorNumber analogSensor("myAnalogInput", HASensorNumber::PrecisionP1);
HASensorNumber uptimeSensor("myUptime");
HAButton buttonA("myButtonA");
HAButton buttonB("myButtonB");
analogWave wave(DAC);
void setup() {
//Initialize serial and wait for port to open:
Serial.begin(9600);
while (!Serial) {
; // wait for serial port to connect. Needed for native USB port only
}
// check for the WiFi module:
if (WiFi.status() == WL_NO_MODULE) {
Serial.println("Communication with WiFi module failed!");
// don't continue
while (true);
}
String fv = WiFi.firmwareVersion();
if (fv < WIFI_FIRMWARE_LATEST_VERSION) {
Serial.println("Please upgrade the firmware");
}
// attempt to connect to WiFi network:
while (status != WL_CONNECTED) {
Serial.print("Attempting to connect to WPA SSID: ");
Serial.println(ssid);
// Connect to WPA/WPA2 network:
status = WiFi.begin(ssid, pass);
// wait 10 seconds for connection:
delay(10000);
}
// you're connected now, so print out the data:
Serial.print("You're connected to the network");
printCurrentNet();
printWifiData();
Serial.println("\nStart connecting to MQTT server");
if (!mqtt.begin(MQTT_SERVER, MQTT_PORT, MQTT_USERNAME, MQTT_PASSWORD)){
Serial.print("Connection falied");
Serial.print(mqtt.getState());
Serial.println("Try again in 5 seconds");
delay(5000);
}
wave.sine(freq);
wave.amplitude(0.5);
analogReadResolution(14);
device.setName("Arduino");
device.setSoftwareVersion("1.0.0");
buttonA.setIcon("mdi:fire");
buttonA.setName("Click me A");
buttonB.setIcon("mdi:home");
buttonB.setName("Click me B");
}
void loop() {
// check the network connection once every 10 seconds:
mqtt.loop();
if ((millis() - lastUpdateAt) > 1000) { // 1000ms debounce time
uint16_t reading = analogRead(A0);
float voltage = reading * 5.f / 16383.f; // 0.0V - 5.0V
Serial.print("Volt:");
Serial.println(voltage);
analogSensor.setValue(voltage);
unsigned long uptimeValue = millis() / 1000;
Serial.print("Uptime:");
Serial.println(uptimeValue);
uptimeSensor.setValue(uptimeValue);
lastUpdateAt = millis();
}
}
全部代码完成以后,将其烧录,同时启动docker中的HA容器和EMQX容器,并通过对应的端口号打开其网页。烧录需要花费很多时间,电脑因为开了太多的东西已经卡顿不堪了,最终结果如下:
我们可以看到Arduino首先成功连接了wifi,并输出wifi的相关信息,然后尝试建立MQTT的连接并且成功连接,Arduino通过串口发送了采集到的传感器电压,以及更新时间。
在HA平台的服务器上,我们可以看到Arduino已经成功接入并开始更新采样到的电压和时间信息。
至此,我们成功完成了Arduino通过WiFi连接并通过MQTT协议接入HomeAssistant平台的任务,属实不易。具体信息的动态展示可以查看随附的视频。
[localvideo]4b8611fb8c5a29c98ec2adee00885522[/localvideo]
- 2024-10-06
-
发表了主题帖:
【Follow me第二期】基础任务 - DAC+OPAMP+ADC以及上位机波形显示
前阵子因为事务繁忙,没来得及更新Follow me的相关测评任务。趁着国庆假期有一些闲暇,把相关任务再完成完成。
本次活动的基础任务主要涉及Arduino内置的DAC,ADC还有运算放大器OPAMP的使用。这些组件对于电子项目的开发至关重要,它们能够处理模拟信号,为数字世界提供精确的数据。
1. 本地IDE的安装
由于本次测评需要直观地展示波形,而我们目前缺乏便携式示波器,因此我们将采用Arduino IDE内置的串口绘图工具(Serial Plotter)作为替代方案。这一工具能够将串口传输的数据实时转化为波形图,为缺乏专业设备的用户提供了极大的便利。
为了使用串口绘图工具,我们首先需要安装Arduino IDE。用户可以从Arduino官方网站的【Software | Arduino】页面下载安装包。安装完成后,您可以在IDE的右上角轻松找到串口绘图工具的入口。
在基础任务中,我们将通过串口发送采集到的数据至IDE,并利用串口绘图工具将这些数据绘制成波形图,以此模拟示波器的功能。
2. 基础任务(必做):用DAC生成正弦波
为了更好地理解和运用DAC模块,我们将参考官方文档【Arduino UNO R4 WiFi Digital-to-Analog Converter (DAC) | Arduino Documentation】,了解如何利用Arduino的内部DAC模块生成特定形状的波形。我们将使用analogWave库,并通过配置使其通过DAC输出波形。
我们计划将输出波形的频率设定为100Hz,并将输出幅度设置为量程的0.5倍(即1.65V)。尽管Arduino理论上能够输出这样的正弦波,但受限于没有示波器,我们无法直观地观察输出波形。
#include "analogWave.h"
analogWave wave(DAC);
int freq = 100;
void setup() {
// put your setup code here, to run once:
wave.sine(freq);
wave.amplitude(0.5);
}
为解决这一问题,我们将调用ADC模块,将DAC输出的模拟信号再转回数字信号,并通过串口发送。通过配置串口的传输波特率并开启串口,发送的数据将基于ADC采集得到。我们通过Serial来配置串口的传输波特率并打开串口,其发送的数据将由ADC采样得到。
Arduino的DAC默认输出口设为A0,我们将使用analogRead绑定A0来读取DAC输出的电平。ADC发送的数字信号为二进制数据,Arduino支持的数字信号最高14位精度,我们可以使用analogReadResolution函数进行调整。
相比于使用示波器,使用ADC采样并通过串口发送数据绘制波形时,可能面临波特率与采样频率不匹配的问题。例如,初始设置的串口波特率为2000000,采样位数为14位,代码示例如下:
#include "analogWave.h"
analogWave wave(DAC);
int freq = 100;
void setup() {
// put your setup code here, to run once:
Serial.begin(2000000);
analogReadResolution(14);
wave.sine(freq);
wave.amplitude(0.5);
}
void loop() {
// put your main code here, to run repeatedly:
int value = analogRead(A0);
Serial.println(value);
}
观察到波形中出现了许多台阶,这可能是由于串口发送速率过快,导致相同数据的重复发送。串口的print和println函数在绘制波形时,默认以数据采集的个数为分度,因此重复数据形成了台阶。
由于ADC采样频率有上限,我们尝试降低串口波特率以获得更直观的波形展示。将ADC采样精度降低至8位以便换算,最终选定波特率为250000,绘制结果如下:
尽管波形上的台阶仍然存在,但长度已显著缩短,验证了之前的假设。此时,波形曲线大约每隔二至三个点就会更新一次数据,表明需要更精细地调整波特率。由于可用的波特率并不是连续可调的,如果使用115200波特率,波形将仅在几个点上出现台阶,如下图所示,图形非常不美观;而使用更低的波特率必将导致某些采样点缺失。
为避免数据重复发送,我们最终选择在循环中加入了1ms的延时。当然,我们也可以通过计算波特率和采样频率来准确计算数据发送的间隔,并使用delayMicroseconds函数实现微秒级精确延时。最终的完整代码如下:
#include "analogWave.h"
analogWave wave(DAC);
int freq = 100;
void setup() {
// put your setup code here, to run once:
Serial.begin(250000);
analogReadResolution(8);
wave.sine(freq);
wave.amplitude(0.5);
}
void loop() {
// put your main code here, to run repeatedly:
int value = analogRead(A0);
Serial.println(value);
delay(1);
}
在IDE的串口绘制工具中,我们得到了相对完整的波形,虽然它不是完美的正弦波,但幅度与设置的一致(使用8位精度,最大值为255,0.5倍量程,最大幅值应在128左右,最高点可能未被采集或发送):
动态图像可以在后面的结果视频中查看。在后续的模块绘制任务中,我们将按照这个波特率和采样精度对串口和ADC进行配置。
3. 基础任务(必做):用OPAMP放大DAC信号
在进行OPAMP模块的任务之前,我们首先参考了官方文档【Arduino UNO R4 WiFi OPAMP | Arduino Documentation】,,以了解Arduino内置运算放大器的管脚配置。里面介绍了Arduino运放的管脚配置如下:
文档中是以一个电压跟随器为例的,其中输入信号连接至A1(运放的正输入端),A2作为负输入端,A3则作为输出端。为了实现电压跟随效果,我们将A2和A3直接连接,这样A3的输出电压将恒等于A1的输入电压。
若要实现电压放大功能,我们可以通过在A2和A3之间分配不同比例的电阻来实现预定比例的电压放大。常用的同相放大电路图如下:
根据理想运放的“虚短”和“虚断”特性,很容易推导得到如下表达式,即放大的倍数为(1+R2/R1)。以一个两倍的电压放大电路为例,我们只需要取R1与R2相等,就能够实现一个两倍的放大。
我们选择R1和R2均为10kΩ的电阻。电路连接方式如下:A0作为DAC的输出,与A1相连,作为运放的同相输入端;A2作为运放的反相输入端,通过R1和R2电阻分别连接至GND和A3。
在代码中,我们需要配置好OPAMP模块,调用OPAMP.h头文件,并设置运放的工作模式为高速模式。由于ADC读取的数值来源于运放的输出端口,我们将analogRead端口绑定为A3。完整的代码如下:
#include "analogWave.h"
#include <OPAMP.h>
analogWave wave(DAC);
int freq = 100;
void setup() {
// put your setup code here, to run once:
Serial.begin(250000);
analogReadResolution(8);
wave.sine(freq);
wave.amplitude(0.5);
OPAMP.begin(OPAMP_SPEED_HIGHSPEED);
}
void loop() {
// put your main code here, to run repeatedly:
int value = analogRead(A3);
Serial.println(value);
delay(1);
}
最终的结果如图所示:
由于我们设计的是一个二倍放大器,所以从ADC采样到的电压从原来的半量程(8位,约128)增加到了满量程(255左右),成功实现了输出电压的翻倍。动态的图像可以在后面的结果视频中查看。
4. 基础任务(必做):用ADC采集并上传到上位机显示
在本任务中,我们将通过ADC模块采集信号,并利用上位机进行显示,以直观地呈现信号的变化。这一过程在前两个任务中已有展示,本次任务将对之前的DAC输出信号和经OPAMP放大后的信号进行汇总显示,从而直观地展示放大器的效果。
由于DAC输出和OPAMP输出位于不同的管脚,我们需要分别对这两个信号进行ADC模拟读取。值得注意的是,当通过串口交替传输这两个信号时,不能连续使用println函数,因为这会导致串口绘图工具将它们视为同一个变量进行绘制,如下:
为了避免这一问题,我们需要在数据之间手动插入换行或空格来进行区分。最终的代码示例如下:
#include "analogWave.h"
#include <OPAMP.h>
analogWave wave(DAC);
int freq = 100;
void setup() {
// put your setup code here, to run once:
Serial.begin(250000);
analogReadResolution(8);
wave.sine(freq);
wave.amplitude(0.5);
OPAMP.begin(OPAMP_SPEED_HIGHSPEED);
}
void loop() {
// put your main code here, to run repeatedly:
int dacvalue = analogRead(A0);
int opampvalue = analogRead(A3);
Serial.print(dacvalue);
Serial.print(" ");
Serial.println(opampvalue);
delay(1);
}
最终结果如下:
在图中,蓝线代表DAC的输出信号,而红线代表经OPAMP放大后的信号,直观地展示了二倍电压放大的效果。动态的图像可以在后面的结果视频中查看。
有一个疑问就是,DAC的输出是否一定绑定在A0这个端口,是否可以通过配置信息来进行修改?之前因为疏忽忘记把A0和A1这条给连在一起了,导致输出的不正常,所以想到除了我们这样在外电路飞线连接的方式,是否有在Arduino内部直接进行配置连接的方式呢?
[localvideo]5d7e005a98463fe63789514f3ccfdce7[/localvideo]
- 2024-10-05
-
回复了主题帖:
共读颁奖:《人工智能实践教程——从Python入门到机器学习》
管理辛苦了!
- 2024-09-15
-
发表了主题帖:
【Follow me第二期】入门任务 - 开发环境配置+LED点亮+串口打印+LED点阵驱动
我们非常有幸参加本期的Follow me活动,能够亲自体验Arduino这一享誉盛名的开源硬件平台。Arduino以其易用性和灵活性,成为全球电子爱好者和教育者的首选工具,期待本次活动有所收获。
1. 开箱
9月6日,我们收到了DigiKey提供的包裹,内含Arduino UNO R4 WiFi开发板一套,以及SHT40温湿度传感器和LTR-329光照传感器的扩展板各一块。此外,还包括了一条用于连接的数据线。在开箱过程中,我们注意到两个传感器扩展板均附带了排针,这表明用户可以通过焊接操作,将传感器直接连接至开发板的排母上,从而实现快速部署。
我们对所收到的产品进行了仔细检查,确认所有配件均齐全且完好无损。开发板的制造工艺精良,细节处理得当,体现了制造商的高标准和对品质的严格要求。值得一提的是,开发板还附带了一个亚克力底座,这一设计不仅提升了产品的美观度,更重要的是,它能够有效防止开发板背面的管脚在操作过程中发生误触,确保了操作的安全性。
2. 入门任务(必做):搭建环境
在进行Arduino开发之前,搭建一个稳定且高效的开发环境是至关重要的。鉴于之前没有使用过Arduino,我们首先查阅了其数据手册【ABX00087-datasheet.pdf (digikey.com)】。该手册不仅详尽地列出了板卡的技术参数、管脚功能及其布局,还提供了基本的使用指南,为我们的后续开发工作奠定了基础。
Arduino提供了两种主要的开发工具:本地IDE和在线编辑器。用户可以通过访问Arduino的官方网站【Arduino - Home】,在顶部菜单中找到对应开发板的使用说明,并下载所需的编辑器。此外,左侧菜单中的Learn Arduino部分提供了不同编辑器的初步教程,为初学者提供了详尽的指导。
本地编辑器的下载地址可以在Software模块找到【Software | Arduino】,不过作为初步尝试,我们选择先Arduino的在线编辑器开发一些短平快的小demo。用户可以通过Software模块顶部的Go To Cloud Editor选项进入在线编辑器,同时,相关的使用教程也在【Using the Arduino Cloud Editor | Arduino Documentation】中提供。
在使用在线编辑器之前,需要注册一个Arduino账户,并安装Arduino Create Agent插件。安装并激活插件后,其图标将显示在电脑右下角的工具栏中。
至此,Arduino在线编辑器的环境配置已基本完成。与我们之前使用的一些单片机开发环境相比,Arduino的配置过程显得尤为简单直观。
完成环境配置后,将Arduino开发板通过USB线连接至电脑,即可开始编写代码并进行烧录。这一过程的便捷性,再次体现了Arduino平台的易用性。
3. 入门任务(必做):LED点亮以及串口打印
LED点亮和串口打印功能这两项基础功能是电子项目开发中不可或缺的部分,对于初学者而言,掌握它们是进入更高级应用的前提。
在Arduino的在线编辑器中,通过点击左侧菜单的Sketches选项,我们可以创建一个新的工程。对于我们这样初次接触Arduino编程的用户,可能会感到无从下手。此时,可以利用编辑器左侧的Examples菜单,其中包含了丰富的示例代码,为我们提供了宝贵的参考。
我们以“Blink”例程为例,该例程展示了如何控制LED的点亮与熄灭。Arduino程序主要由两个函数组成:setup()和loop()。setup()函数负责初始化设置,如配置管脚模式和初始化外设,它仅在程序开始时执行一次。随后,程序将进入loop()函数,这是程序的主循环,所有持续执行的任务都在这里进行。
在“Blink”例程中,我们首先将LED对应的管脚设置为输出模式。这样,我们就可以通过改变该管脚的电平状态(高电平或低电平),来控制LED的点亮与熄灭。在loop()函数中,我们通过编写代码,使LED以一定的频率闪烁。
将上述代码复制到我们的工程中,并烧录到Arduino开发板上。烧录过程中,开发板上的两盏LED指示灯会短暂点亮,以指示烧录过程正在进行。烧录完成后,只有我们指定的LED会按照预设的频率反复点亮和熄灭,实现闪烁效果。
串口打印功能的实现与LED点亮类似。在setup()函数中,我们需要初始化串口通信,并设置适当的波特率。在loop()函数中,我们编写代码,使开发板能够周期性地通过串口发送字符串信息。
我们将LED点亮和串口打印的代码合并到一个工程中,进行烧录,具体代码如下:
void setup() {
pinMode(LED_BUILTIN, OUTPUT);
Serial.begin(9600);
}
void loop() {
Serial.println("Hello EEWorld!");
digitalWrite(LED_BUILTIN, HIGH);
delay(1000);
digitalWrite(LED_BUILTIN, LOW);
delay(1000);
}
烧录完成后,开发板上的LED将按照预设频率闪烁,同时,通过打开编辑器的串口监视器,我们可以观察到开发板发送的字符信息。
4. 基础任务(必做):驱动12x8点阵LED
接下来尝试驱动一下LED的点阵。与单个LED的控制相似,点阵LED的控制原理也是通过设置对应位置的管脚电平来控制LED的点亮或熄灭。然而,由于涉及的LED数量较多,我们无法对每个LED进行单独配置,因此需要采用一种编码规则来实现有效的控制。
我们首先需要引入Arduino_LED_Matrix.h这个头文件,它包含了控制矩阵LED所需的函数。这个库可以通过Arduino IDE的Libraries菜单搜索并添加到项目中。
查看教程中有关LED Matrix的部分【Using the Arduino UNO R4 WiFi LED Matrix | Arduino Documentation】,我们学习了如何使用这个库来控制LED点阵。相关的例程也可以在LED Matrix库中找到。大致学习了一下教程和例程,对点阵进行编码的方式可以总结为如下两种。
第一种是直接构建一个8x12的数组,每个元素代表一个LED的状态,0代表熄灭,1代表点亮。这种方法直观且易于理解。例如,我们参考了“DisplaySingleFrame”例程,设计了一个笑脸图案,如下:
在循环主体中,我们使用matrix.renderBitmap()函数来驱动LED矩阵,将数组中的0和1映射到对应的LED位置,实现图案的显示。这个Bitmap非常生动形象的概括了点亮LED的过程,即将各1位的0或1比特映射到对应的LED矩阵位置。
第二种方法是对整个矩阵进行简化表达,使用三个32位的uint32_t变量来表示整个矩阵。我们将上述数组按照4位一组转换成16进制数,从而简化了代码的复杂性,如下:
在主体循环中,我们使用matrix.loadFrame()函数来驱动LED矩阵,这种方法使得代码更加简洁和优雅。Frame表示已经把整个图像框起来打包好了,是一个整体。
在浏览例程时,我们注意到“GameOfLife”例程提供了动态效果的实现。受此启发,我们制作了一个动态点亮笑脸图案的小动画。该动画的原理是随机选择矩阵中的行和列进行点亮或熄灭,然后不断刷新新的Bitmap以展示动态效果。
为了避免因随机性导致的长时间卡顿,我们引入了一个矩阵来记录已经随机过的位置。这样,下一次遇到重复位置时,我们可以跳过,从而避免了循环中的卡顿。通过统计与目标图案点亮位置相同的个数,一旦匹配完成,即结束随机点亮过程,不必再等待所有位置都被随机遍历一遍了。
即便如此,有时候最后几个位置还是比较慢才会随机到。主要是C语言很久没使用了,平时写Python和Matlab习惯了,一下子遇到C语言不太会写了,写的非常笨重,循环套循环的,希望大家批评指正。具体代码如下:
#include "ArduinoGraphics.h"
#include "Arduino_LED_Matrix.h"
ArduinoLEDMatrix matrix;
#define ROWS 8
#define COLUMNS 12
#define BITNUM 24
uint8_t Figure[8][12] = {
{ 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
{ 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0 },
{ 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1 },
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
{ 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1 },
{ 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0 },
{ 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0 }
};
void setup() {
Serial.begin(115200);
matrix.begin();
uint8_t initFigure[ROWS][COLUMNS] = {0};
uint8_t currentFrame[ROWS][COLUMNS] = {0};
int count = 0;
for (int i = 0; i < ROWS*COLUMNS; i++) {
bool placed = false;
while (!placed) {
int row = random(ROWS);
int col = random(COLUMNS);
if (initFigure[row][col] == 0) {
initFigure[row][col] = 1;
if (Figure[row][col] == 1) {
currentFrame[row][col] = 1;
count++;
}
placed = true;
}
if (count == 30){
break;
}
}
matrix.renderBitmap(currentFrame, ROWS, COLUMNS);
}
}
下图是其随机生成过程中以及快要画完时候的LED矩阵。
除了动态图案,我们还尝试了使用LED矩阵显示字符。这主要借助于ArduinoGraphics.h库。在“TextWithArduinoGraphics”例程中,我们学习了如何使用matrix.beginDraw()和matrix.endDraw()构建显示框架,并设置了字体大小和内容。通过matrix.textScrollSpeed()函数,我们可以控制字符的滚动速度和方向。具体代码如下:
void loop() {
matrix.beginDraw();
matrix.stroke(0xFFFFFFFF);
matrix.textScrollSpeed(200);
const char text[] = " Hello EEWorld! ";
matrix.textFont(Font_5x7);
matrix.beginText(0, 1, 0xFFFFFF);
matrix.println(text);
matrix.endText(SCROLL_LEFT);
matrix.endDraw();
delay(500);
}
显示的结果如下:
综上所述,我们实现了LED矩阵的多种驱动方式和功能。我们将随机生成笑脸图案的代码放在setup()函数中,作为一次性的开屏动画。动画结束后,开发板将循环播放字符串滚动效果。
具体的实验过程和效果,可参考随附的视频演示。
[localvideo]d98f5afefadc17e0b205f9fe24c2b00e[/localvideo]
- 2024-09-14
-
加入了学习《串口打印Hello EEWorld!》,观看 控制LED8*12矩阵
- 2024-09-06
-
发表了主题帖:
《人工智能实践教程——从Python入门到机器学习》阅读报告(3)
本帖最后由 Aclicee 于 2024-9-6 00:42 编辑
本书的第三部分,作者深入探讨了神经网络的构建与应用,这一部分是全书的精华所在。以下是对书中核心内容的简要介绍:
感知机:作者首先从感知机的原理出发,详细阐述了其从逻辑电路到多层神经网络的演变过程。书中不仅介绍了基础的激活函数,还深入探讨了如何通过这些函数构建更为复杂的神经网络结构;
反向传播算法:通过链式法则和计算图,作者清晰地推导了反向传播算法的理论基础。书中不仅涵盖了常见的激活函数,还特别介绍了Softmax等复杂函数的推导过程,为读者揭示了误差反向传播在神经网络学习中的核心作用;
训练方法:在优化器的选择上,作者提供了SGD、Momentum、AdaGrad和Adam等多种算法的比较。这一部分可以参考李宏毅老师的课程,其中有一些动态图像会提供更直观的理解,帮助读者把握这些方法之间的差异;
训练的优化:书中对网络参数初始化、批量归一化、正则化以及超参数选择等训练优化技术进行了深入分析。虽然这些内容在实际操作中可能需要大量的试错,但作者的理论分析为读者提供了坚实的基础;
卷积网络(CNN):作者介绍了卷积神经网络(CNN)的原理和优势,并以图像处理中的应用案例为例,展示了CNN的强大能力。尽管我的研究领域并非图像处理,但我仍然能够利用CNN的特性来丰富我的项目实践。
对这一部分的总体评价:
总体来说,这部分内容是全面的,尤其是在理论推导方面,作者的详细讲解极大地提升了阅读兴趣。书中对于训练过程的介绍,从优化器的选择到网络优化方法,都配有清晰的公式和示例代码,极大地方便了读者的理解和学习。然而,书中对于计算机视觉和自然语言处理两大应用领域的承诺,似乎在介绍完CNN后戛然而止,对于时间序列相关的网络如RNN、LSTM以及当前热门的Transformer模型并未涉及,这不免让人感到遗憾。
对全书的测评如下:
综合考虑,我认为这本书只能算是无功无过。它的内容全面,但与市场上的其他教材相比,缺乏一定的创新性和深度。虽然作为一本实践教程,它在理论深度上略显不足,但如果仅作为快速查找代码的操作手册,它的实用性又不如直接上网搜索或咨询人工智能助手来得高效。此外,书中未能涵盖自然语言处理的相关方法和网络,这在一定程度上影响了它的完整性。尽管如此,书中的代码开源是一个值得称赞的亮点。
回到测评的实践任务,按照测评计划,我将结合我目前正在进行的科研项目,运用神经网络技术对电池的衰退程度进行预测。专业概念和电池衰退的定义颇为复杂,我在这里就不深入讨论了,可以通过查阅相关论文来获得更多信息。
简单的背景就是电池在经历反复的充放电循环后,其可用容量会逐渐减少。为了量化这一衰退过程,我们引入了一个衡量指标——电池健康状态(State of Health, SOH)。SOH是通过计算电池每个周期最大可用容量与其出厂时标称容量的比值来定义的。与之前仅进行衰退与否的二元判断不同,我们现在的目标是预测一个连续变化的值,即将问题从分类任务转变为回归任务。上一期阅读报告的链接:【《人工智能实践教程——从Python入门到机器学习》阅读报告(2) - 编程基础 - 电子工程世界-论坛 (eeworld.com.cn)】。
而我们选择的特征,我选择了之前提到的马里兰大学的电池数据集(Battery Research Data | Center for Advanced Life Cycle Engineering (umd.edu))作为研究对象。由于实际使用的数据集暂时无法公开,我在此基础上胡诌了一个数据集,但我们采用的原理和方法是相同的。我们从电池每个充放电周期中提取了5个不同充电量时的电压值作为特征,这些特征能够刻画电池的充电过程,并作为评估其SOH的依据。
1. 数据的预处理和读取
在开始模型训练之前,我们已经对数据集进行了预处理,包括归一化等步骤。具体的预处理过程可以参考我上一次的测评报告。此外,我们根据电池的完整衰退数据,选择了两节电池的数据作为训练集,剩余一节作为测试集,训练集和测试集的样本数量比大约为2.5:1。
输出:Training Samples: 1442 Testing Samples: 615
在之前的分类问题研究中,我们已经展示了不同特征之间的相关性。在本次回归任务中,由于预测目标是一个连续变化的值,我们直接绘制了某个特征与输出结果之间的关联图。从图中可以看出,因为我们的选取的特征还是比较和预测目标相关的,所以线性度比较强。
2. 线性回归模型(Linear Regression,LR)
基于此,我们首先考虑实现一个基础的线性回归算法。这个算法将作为我们后续与神经网络方法进行比较的基准。通过线性回归,我们可以初步探索特征与电池衰退程度(SOH)之间的关系。
线性回归是一种预测连续目标值的统计方法,它假设输入特征与目标值之间存在线性关系。在模型构建过程中,我们的目标是找到一组权重,使得模型预测的输出尽可能接近实际的目标值。具体来说,模型的预测公式可以表示为:,其中是预测值,是输入特征,是权重,是偏置项。在之前的以及其他作者的阅读报告中已经提及,一般可以通过最小二乘等方法来求解线性回归问题。这里我就直接调用python的sklearn包了,代码如下:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
LRmodel = LinearRegression()
LRmodel.fit(X_train_LR, y_train)
y_train_pred = LRmodel.predict(X_train_LR)
y_test_pred = LRmodel.predict(X_test_LR)
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)
print(f'Training MSE: {train_mse:.6f}, R^2: {train_r2:.4f}')
print(f'Test MSE: {test_mse:.6f}, R^2: {test_r2:.4f}')
输出:Training MSE: 0.000126, R^2: 0.9753 Test MSE: 0.000125, R^2: 0.9641
为了评估模型的性能,我们使用了两个关键指标:均方误差(MSE)和拟合优度(R^2)。MSE衡量的是预测值与实际值之间差异的平方的平均值,而R^2则反映了模型预测值与实际值之间的相关程度。理想情况下,MSE应尽可能小,R^2应接近1。
下面的左右两图边是利用训练集的数据拟合的5个特征预测的输出结果和实际结果的对应关系。上面的每一个散点表示一个样本,纵坐标为其预测的目标值,横坐标为其实际的目标值。黑线代表预测结果和实际结果完全相同的情况,因此输出的散点图分布越接近黑线,准确度越高。
在我们的实验中,由于特征与SOH之间存在较强的线性关系,线性回归模型表现出了较好的预测效果。散点图显示,大多数数据点都紧密地围绕着代表完美预测的黑线分布,这表明模型的预测精度较高。
3. 引入非线性问题
既然如此,我们不妨对数据进行一些干扰,来模拟采集过程中的噪声,来降低两者之间的线性相关性。这种干扰降低了特征与预测目标之间的线性相关性,从而对模型的预测能力提出了挑战。具体代码如下:
import numpy as np
np.random.seed(8)
noise_level = 0.1
X_train_LR = X_train + np.random.randn(*X_train.shape) * noise_level
X_test_LR = X_test + np.random.randn(*X_test.shape) * noise_level
如下图所示:
在引入噪声后,线性回归模型的预测效果受到了影响。测试集上的误差增加,拟合优度下降,这表明模型在面对非理想数据时的预测能力有所降低。如下图所示:
输出如下:
*** Linear Regression ***
Training MSE: 0.000797, R^2: 0.8435
Test MSE: 0.000857, R^2: 0.7535
4. 多层感知机(Multi-Layer Perceptron,MLP)
此时因为问题的非线性特征,我们引入神经网络。神经网络相较于传统的线性回归模型,其最大的优势在于能够通过激活函数引入非线性,从而更准确地刻画数据中的复杂关系。我设计了一个包含三个层的网络。输入层接收当前周期的5个特征值,这与我们在线性回归模型中使用的特征相同。在隐层,我们将特征维度扩展到32和64,以增强模型的表达能力。最后,通过一个全连接层输出预测的SOH值。具体代码如下:
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, input_dim):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_dim, 32)
self.fc2 = nn.Linear(32, 64)
self.fc3 = nn.Linear(64, 1)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
为了确保训练过程的稳定性和可重复性,我们从训练集中划分出一部分数据作为验证集。这样做的目的是为了在训练过程中选择更合适的模型超参数。我们还控制了随机种子,以确保每次运行时数据的随机分配都是一致的。具体的代码如下:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
torch.manual_seed(8)
X_train_tensor = torch.tensor(X_train_LR, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_test_tensor = torch.tensor(X_test_LR, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)
train_size = int(0.8 * len(X_train_tensor))
val_size = len(X_train_tensor) - train_size
train_dataset, val_dataset = random_split(TensorDataset(X_train_tensor, y_train_tensor),
[train_size, val_size])
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
训练过程的代码如下:
input_dim = X_train_LR.shape[1]
model = MLP(input_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)
num_epochs = 200
train_losses = []
avg_val_losses = []
for epoch in range(num_epochs):
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
val_losses = []
for inputs, targets in val_loader:
outputs = model(inputs)
val_loss = criterion(outputs, targets)
val_losses.append(val_loss.item())
avg_val_loss = sum(val_losses) / len(val_losses)
train_losses.append(loss.item())
avg_val_losses.append(avg_val_loss)
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Traing Loss: {loss.item():.6f}, Validation Loss: {avg_val_loss:.6f}')
测试的代码如下:
model.eval()
with torch.no_grad():
y_train_pred = model(X_train_tensor)
y_test_pred = model(X_test_tensor)
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)
print(f'Training MSE: {train_mse:.6f}, R^2: {train_r2:.4f}')
print(f'Test MSE: {test_mse:.6f}, R^2: {test_r2:.4f}')
在MLP模型训练完成后,我们评估了其预测效果。与直接的线性回归模型相比,MLP在预测误差和R^2值上都有所提升。尽管我还没有对模型的超参数进行细致的调整,但我相信通过进一步的优化,模型的性能还有很大的提升空间。目前,模型的准确度虽然有所提高,但仍然有待提高。最终的预测效果如下图所示:
输出如下:
*** MLP ***
Training MSE: 0.000638, R^2: 0.8748
Test MSE: 0.000728, R^2: 0.7907
考虑到电池的衰退不仅与当前周期有关,而且与之前的周期也有密切联系,将某一周期割裂来看确实会因为测量的噪声的问题出现预测的误差。我们将单一周期的数据扩展到连续5个周期的充电曲线特征。这样的数据重构有助于捕捉电池衰退过程中的时间序列特性。
X_train_CNN = X_train.reshape(-1,5,5)
X_test_CNN = X_test.reshape(-1,5,5)
X_train_CNN = X_train_CNN + np.random.randn(*X_train_CNN.shape) * noise_level
X_test_CNN = X_test_CNN + np.random.randn(*X_test_CNN.shape) * noise_level
理论上,对于时间序列数据,长短期记忆网络(LSTM)可能是更合适的选择。然而,卷积神经网络(CNN)在处理此类数据时也展现出了其独特的优势。之前有研究利用CNN对多周期的充电曲线特征进行深入的相关性和特征提取。因此,我们也尝试复现这一过程,以探索CNN在电池衰退预测中的潜力。
5. 卷积神经网络(Convolutional Neural Networks,CNN)
我们的CNN模型采用了3x3的卷积核,对输入的特征矩阵进行两次卷积运算。每次卷积之后,我们使用最大值池化(Max Pooling)来降低特征维度,同时保留重要的信息。最终,所有通道的关键特征被展平,将它们从2D结构压缩到1D结构,并通过一系列线性层和非线性激活函数进行处理,以生成最终的预测输出。具体代码如下:
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 2 * 2, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 1)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 2 * 2)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
CNN的训练过程与MLP相似,我们采用了相同的训练策略和验证方法。在这里,我们不再重复训练过程的详细描述。最终的预测结果如下:
输出如下:
*** CNN ***
Training MSE: 0.000285, R^2: 0.9441
Test MSE: 0.000315, R^2: 0.9094
经过训练,CNN模型展现出了较好的预测效果。与MLP相比,训练集和测试集的预测结果更接近理想的黑线,这表明模型的预测精度有所提高。然而,我们也观察到预测结果普遍高于实际值,这可能与模型参数的调优不足有关。
通过绘制训练集和验证集的损失下降曲线,我们发现两者都有明显的波动。为了更清晰地观察这种波动,我们将纵坐标调整为对数尺度。这种波动至少揭示了两个问题:首先,我们的学习率可能设置得过高,导致模型在训练过程中难以稳定地收敛到损失的局部最小值;其次,模型可能没有在最佳时机停止训练,因为在预设的训练周期结束时,验证集的损失已经开始上升,这可能预示着过拟合的风险。
6. CNN的优化
对此,参考书中的介绍,一一给出解决方案:
对于学习率的设置问题,手动调整学习率并进行多次交叉验证虽然能找到最优解,但过程繁琐且耗时。书中提到了两种自适应学习率的优化器:Adagrad和Adam,我们这里选择了后者。同时引入学习率衰减策略控制具体的学习率衰退,即随着训练的进行逐渐减小学习率,这有助于模型在初期快速下降损失,在训练后期则能细致地逼近局部最小值。
import torch.optim.lr_scheduler as lr_scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.98)
对于合适的训练周期数问题,确定合适的训练周期数(epoch数)对于避免过拟合和欠拟合至关重要。我们通过监控验证集上的损失,并引入早停机制(Early Stopping)来解决这一问题。如果在设定的耐心值(patience)周期内,验证集损失没有显著下降,训练将自动停止。这样可以避免在损失尚未稳定时过早结束训练,同时防止过长时间的训练导致资源浪费和趋向过拟合。以下是实现早停机制的代码:
num_epochs = 1000
patience = 50
best_loss = float('inf')
patience_counter = 0
train_losses = []
avg_val_losses = []
for epoch in range(num_epochs):
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model_opt(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
scheduler.step()
model_opt.eval()
with torch.no_grad():
val_losses = []
for inputs, targets in val_loader:
outputs = model_opt(inputs)
val_loss = criterion(outputs, targets)
val_losses.append(val_loss.item())
avg_val_loss = sum(val_losses) / len(val_losses)
train_losses.append(loss.item())
avg_val_losses.append(avg_val_loss)
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Traing Loss: {loss.item():.6f}, Validation Loss: {avg_val_loss:.6f}')
if avg_val_loss < best_loss:
best_loss = avg_val_loss
patience_counter = 0
torch.save(model.state_dict(), 'best_model.pth')
else:
patience_counter += 1
if patience_counter >= patience:
print(f'Early stopping at epoch {epoch + 1}')
break
对于过拟合的问题,它会导致模型在训练集上表现良好,但在测试集上表现不佳。书中给出了两种比较好的解决思路。一是引入正则化,通过在损失函数中添加参数的L2范数,减少模型对特定样本的敏感度;二是使用dropout技术,在训练过程中随机“丢弃”一部分神经元,增加模型的泛化能力。在PyTorch中,正则化可以通过优化器的权重衰减参数来实现,它的值越大,正则化的强度就越高,对权重的惩罚也就越大。
class OptCNN(nn.Module):
def __init__(self):
super(OptCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 2 * 2, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 1)
self.dropout = nn.Dropout(p=0.3)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 2 * 2)
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model_opt = OptCNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model_opt.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.98)
最终结果如下:
输出如下:
*** Optimized CNN ***
Early stopping at epoch 237
Training MSE: 0.000274, R^2: 0.9462
Test MSE: 0.000252, R^2: 0.9276
我们还监控了训练集和验证集的损失下降过程。
通过引入正则化和dropout,验证集的损失抖动明显减少,这表明模型的稳定性得到了提升。同时,损失值稳定下降,且训练周期数减少,这得益于早停机制和学习率衰减策略的有效应用。最后所有的模型汇总如下:
LR
MLP
CNN
Opt. CNN
MSE
0.000857
0.000728
0.000315
0.000252
R^2
0.7535
0.7907
0.9094
0.9276
7. 最后总结
虽然书中的实例主要聚焦于计算机视觉领域,这与我的研究方向有所差异,但我依然能够将书中的理论知识与我的研究内容相结合,对具体的任务进行了深入的实践探索。通过本次测评,我不仅体验了神经网络的调用、搭建以及优化的全过程,而且对这些环节有了更深刻的理解。尽管我采取了一些措施在一定程度上提高了模型的性能,但我清楚地意识到,这些优化远非最优解。神经网络的参数调整,常被戏称为“炼丹”,这一比喻形象地说明了其复杂性和不确定性,后续的学习还是永无止境啊!
- 2024-08-22
-
加入了学习《【Follow me第二季第1期】使用Makecode图形化完成任务》,观看 【Follow me第二季第1期】使用Makecode图形化完成任务
-
加入了学习《【Follow me第二季第1期】全部任务演示》,观看 全部任务演示2.0
-
发表了主题帖:
《人工智能实践教程——从Python入门到机器学习》阅读报告(2)
本帖最后由 Aclicee 于 2024-8-22 21:55 编辑
本书的第二部分为机器学习模块,介绍了一些经典的机器学习算法。机器学习与深度学习,二者虽同根生,却各有千秋。个人的理解上,机器学习算法,如KNN、随机森林和支持向量机等,它们是依赖于严格的数学计算的。这些算法的精髓在于其算法本身的设计,而非依赖于人工介入的超参数调整。不像深度学习需要炼丹那样,精心调配各种网络结构和参数,以期达到最佳的训练效果。
本书在这一部分中,为我们详细介绍了以下几个关键的机器学习算法:
PCA(主成分分析):一种通过正交变换将数据转换到新的坐标系统中,使得数据的任何投影的第一大方差在第一个坐标(称为第一主成分)上,第二大方差在第二个坐标上,依此类推的算法。它不仅帮助我们降维,更让我们能够洞察数据的本质。
Kmeans:一种经典的聚类算法,通过迭代优化,将数据点划分为K个簇,使得每个簇内的点尽可能相似,而簇间的点尽可能不同。它简单而高效,是探索数据结构的有力工具。
KNN(K近邻算法):一种基于距离的分类算法,通过测量不同特征值之间的距离来进行分类。它的简洁和直观,使得KNN在多种应用场景中都表现出了不俗的效果。
线性回归与多项式回归:回归算法是预测数值型数据的重要工具。线性回归以其直观的模型和易于理解的特性,为我们提供了数据预测的基础。而多项式回归则在此基础上增加了模型的复杂度,以适应更复杂的数据关系。
对这一部分的总体评价:
有一说一,这本书的介绍逻辑比较奇怪,之前学习其他课程往往会从回归问题开始讲解,因为线性回归是后面很多算法的前置(比如支持向量机,甚至是推导梯度下降时最简单情形的案例),不明白为什么编者要放在最后,反而在最前面花了大量的篇幅讲PCA。
一方面,对初学者而言,初次接触机器学习便直接学习PCA可能会感到困惑,因为它属于是数据降维和特征工程相关的内容,有点像是锦上添花的作用,并不直接关联到具体的分类或回归任务。如果读者在没有充分理解机器学习基本概念的情况下学习PCA,可能会对其作用和重要性缺乏直观的认识。
另一方面,在PCA推导过程中利用到了梯度的概念以及梯度上升算法,此时读者根本不知道什么是梯度,也不知道梯度下降算法,强行插入对梯度上升算法的介绍,完全不合逻辑(并且推导也是没头没尾的,又要去参考十万八千里以外的梯度下降法)。
一些个人浅薄之见,欢迎各位批评指正,阅读下来总觉得编者这样剑走偏锋安排的讲解方式对读者尤其是初学者十分不友好,建议编者考虑是否重新安排一下逻辑和数学上的连贯性。
回到评测计划上来,我之前在计划中也提到过,我目前在做一个预测电池容量衰退情况的工作。简单介绍一下任务的背景,在当今快速发展的新能源领域,锂离子电池作为核心能量存储单元,其性能稳定性和寿命预测显得尤为重要。电池的循环充放电过程中,不可避免地会出现电解液浓度降低、内阻增大等现象,这些因素共同作用导致电池容量逐渐衰退。因此我们通过采集锂离子电池在不同充放电周期的关键特征数据,建立一个模型,以评估电池的衰退情况。
为了构建一个具有代表性的数据集,我们从马里兰大学提供的公开电池数据(Battery Research Data | Center for Advanced Life Cycle Engineering (umd.edu))中挑选了几组电池数据。在每节电池的充放电周期中,我们选取了电池的一些充放电表现作为特征。因为我实际项目的数据集暂时不能公开,所以我胡诌了一个数据集,仅作为演示使用,每个样本为某一节电池某一周期随意选择充电曲线上的五个位置(作为特征),标签为是否出现衰退的二元标签(0-Negative表示未衰退,1-Positive表示已衰退)。然后生成了训练集和测试集,本质上就是一个二元分类的问题。
1.数据处理(使用了PCA进行降维)
对数据集进行初步探索性分析时,绘制了一下其中两个特征的关系图。由于特征我是随意选择的,可以看到其中第一个和第二个特征(如图1)之间存在非常强烈的线性关系,且有大量数据点彼此重叠。这种重叠的情况暗示了这两个特征可能在区分样本方面并不具有足够的能力。进一步观察其他特征,比如第一个特征和第三个特征(如图2),也发现了一定程度的线性相关性,这表明数据中存在冗余。
为了解决这一问题,可以使用PCA方法对数据进行降维处理。书中虽然提供了使用梯度上升法手动实现PCA的方法和详细的推导过程,也可以参考其他作者发表的阅读报告。但为了简化流程并提高效率,我选择直接使用Python内置的PCA工具。考虑到可视化的便利性,我决定仅保留前两个主成分。具体代码如下:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train_norm = scaler.transform(X_train)
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca.fit(X_train_norm)
X_Rtrain = pca.fit_transform(X_train_norm)
通过PCA转换后的特征分布图,我们可以看到大部分数据点现在分布在了水平方向上,原先的线性关系得到了有效的消解。这表明主成分分析成功地捕捉了数据中分布方差最大的方向。此外,原本重叠的数据区域也有所减少,这进一步证实了所选主成分在表示样本差异性方面的有效性。
在完成训练集的PCA降维处理后,我们面临着对测试集进行相同操作的任务。至关重要的一点是,我们必须确保测试集的处理方式与训练集保持完全一致。这样做不仅有助于避免数据泄露,确保模型评估的公正性。我们首先使用训练集的最大值-最小值归一化测试集,然后将训练集上学习到的主成分分析模型直接应用于测试集的归一化数据。具体的代码如下:
X_test_norm = scaler.transform(X_test)
X_Rtest = pca.transform(X_test_norm)
在对测试集进行PCA降维处理后,我们观察到正负样本的分布形态与训练集保持了高度的一致性。这种一致性对于确保模型评估的准确性至关重要,因为它意味着我们的数据预处理步骤没有引入偏差,模型在训练集和测试集上的表现将是可比的。
2.分类问题(使用了逻辑回归和KNN两种方法)
接下来就是使用书中介绍的机器学习算法,来完成这个二元分类任务。
从样本分布的形状可以看出,存在一定的交叠,最好的方法一定是支持向量机或者神经网络了,但是书中既然没有提到支持向量机,我们就试试别的方法。
首先是回归的方法,逻辑回归是一种特殊的线性回归,它通过引入sigmoid函数将线性回归的输出映射到0和1之间,从而实现二元分类。这一变换不仅保留了线性模型的直观解释性,而且通过引入非线性,增强了模型对复杂数据分布的适应能力。具体来说,当预测值小于0.5时,我们将其分类为0;当预测值大于0.5时,我们将其分类为1。这种方法为我们提供了一种既简单又有效的二元分类界面。逻辑回归的内容书中没有讲,实际只是在线性回归的基础上套一个sigmoid函数,应该是很简单的。具体推导的过程也可以参考其他作者的阅读报告,我这里就直接调用Python的工具了,具体代码如下:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
log_reg = LogisticRegression(max_iter=1000)
log_reg.fit(X_Rtrain, y_train)
y_pred = log_reg.predict(X_Rtest)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
输出:Accuracy: 0.8097643097643098
通过逻辑回归模型的初步分类尝试,我们观察到分类界面(如上图)呈现出明显的线性特征,其实际表现比较一般,正确率大约是80%左右。
为了进一步提升分类效果,书中推荐了KNN(K近邻算法)作为另一种经典的分类算法。KNN的核心思想在于,通过计算测试集中每个样本与训练集中样本的距离,找出最近的K个邻居。这些邻居的多数投票结果将决定测试样本的类别。这种方法能够提供更加灵活的非线性决策边界,从而有望提高分类的准确性。书中给出了手动计算的过程(也非常简单),但我依旧选择是直接调用了Python中的KNN工具,设置用于投票的邻居数量k=3,如下:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_Rtrain, y_train)
y_pred = knn.predict(X_Rtest)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
输出:Accuracy: 0.9124579124579124
可以看到这一结果(如上图)不仅在数值上更加接近实际的分类情况,而且在视觉上也更加符合我们的预期,正确率大约在91%左右。
为了进一步探究特征选择和降维对模型性能的影响,我们还尝试了不进行PCA降维,直接使用原始的五个特征进行KNN分类。由于高维数据的可视化难度较大,我们无法直观地展示分类界面,但通过模型评估,我们发现正确率大约为87%左右,略低于经过PCA降维后的结果。
这一对比实验强调了PCA在特征工程中的重要性。通过PCA降维,我们不仅减少了特征间的冗余,还去除了可能影响模型性能的噪声。更重要的是,PCA帮助我们捕捉了特征间的本质差异性,这在高维数据中尤为重要。降维后的数据更加精炼,使得KNN算法能够更加聚焦于数据的关键信息,从而提高了分类的准确性。
- 2024-08-07
-
发表了主题帖:
《人工智能实践教程——从Python入门到机器学习》阅读报告(1)
7月24日收到的书,外观比较完整,书脊上略有一些磨损和褶皱,不过不影响阅读
按照之前的测评计划,今天分享一下对本书第一部分“Python编程”的相关内容,包含一、二、三章。
第一章的内容涵盖了Python的基础语法,包括变量定义、运算操作以及函数的使用。这些概念与其他编程语言有着共通之处,个人体验来说Python相对来说是比较好上手的。之前在使用Python时,常常也是通过官方文档可以找到详尽的语法说明和应用方法。
最近重点学习一下第二章面向对象的相关内容。在之前的项目实践中(大多数情况下是参考或者套用开源的Python代码),频繁地遇到类的定义和实例化,但对其中的深层原理并不十分了解。
通过对第二章内容的学习,理解了类的概念,尤其是建立了面向对象的编程思路。在之前的认知里,“类”无非也是一个类似函数的封装,把具有相同特性的一组变量同统一定义,来方便调用。现在有了更加系统的认识,它定义了一组具有相同属性和行为的对象的结构。每个类都可以看作是一个模板,用来创建具有相同特征的多个实例,同时将数据(属性)和行为(方法)封装在一起,保护了数据的完整性和安全性。
对于类,在之前的接触中其实就有几个不太理解的问题(以之前最常使用的基于torch.nn来定义神经网络的类为例,假定定义的类为myClass,其中需要传入几个参数如para_a, para_b, ...):
1. 为什么定义类的时候括号里写的是def myClass(nn.Module),但是调用的时候同样的myClass(para_a,para_b,...),括号里写的是需要传入的参数变量?
这是因为定义的类继承自python构建神经网络模块的基础类nn.Module,自定义的新类myClass继承了nn.Module中所有属性和方法,也可以根据自己的需求,增加定义新的属性和方法,来获得所需的新网络;而在后续的调用,实际上是基于定义好的类(模板)实例化一个具体的对象,需要给这个把参数传递给具体对象的构造器__init__。
2. 所有class在定义过程中都会有def __init__(self, para_a, para_b, ...)这个过程,是什么作用?
__init__为构造器,用于在创建类的新实例时,初始化对象的属性。其需要传入的这些参数用于具体实例对象的创建,参数列表中的第一个参数总是self,它代表当前的实例对象本身。
3. 使用def myClass(nn.Module)定义网络的类的时候,def __init__(self, para_a, para_b, ...)初始化参数总是需要在下边第一行写super(myClass, self).init()?
super()函数在用于调用父类的方法,因为自定义的类继承自nn.Module,需要确保父类被正确地初始化。
4. 初始化有一些参数会在def __init__ (self, para_a, para_b)里面定义,但是在同一类后续的新方法def func()中,不能通过直接的para_a来传递参数,需要使用self.para_a=para_a 在__init__步骤中先赋值,这是为什么?
在类的__init__方法中定义参数时,实际上是在为类的对象实例设置属性。当你写self.para_a=para_a时,你不仅仅是在执行一个赋值操作,你还在定义了一个属性para_a,这个属性与创建的对象相关联,并且存储在对象的实例字典中。如果在__init__方法之外直接使用para_a,Python解释器会认为它是一个局部变量,只在__init__的作用域内是可见的。如果尝试在其他方法中直接使用,会因为Python在当前方法的作用域找不到这个变量而报错。
5. 为什么init前后需要加__,而后面定义的那些方法不需要?
__开头和结尾的方法或属性被为特殊方法。这些方法通常不需要直接调用,而是由Python解释器在特定的情况下自动调用。如果类名为myClass,那么__init__方法实际上在类内部被改写为_myClass__init__。这样做同时也能防止子类不小心覆盖父类的特殊方法。
另外有一些之前没有接触过的新知识,帮助更好的定义和使用类:
1. 可以通过__str__(self, 'myString')的方法,控制对象实例在被print()函数调用时所展示的内容。这个方法允许输出特定的字符串,而不是默认的对象内存地址。例如,之前曾遇到过想试试通过print查看对象的结果,打印输出默认的内存地址<object at 0x10e9e1be0>之类的,可以定义__str__方法来提供更有意义的信息,如对象的状态或属性值。
2. 使用@property装饰器,可以为类的属性定义只读访问器,或者在设置属性值之前添加逻辑限制。这不仅增强了数据的安全性,还提供了一种优雅的方式来实现属性的验证和封装。
3. 使用@total_ordering装饰器,可以更高效地实现类实例之间的比较运算。这个装饰器允许只定义一部分比较方法,Python会自动提供其他的比较方法。
第三章高级编程另外还涉及比如Python闭包、其他的装饰器、迭代器和生成器。比较重要的是迭代器和生成器,它们是Python中实现惰性求值的强大工具。通过使用迭代器,可以遍历大量数据而无需一次性将它们全部加载到内存中。生成器则是一种特殊的迭代器,它在每次迭代时计算下一个值,从而有效避免了不必要的内存占用。这种按需计算的方式,不仅优化了性能,也提高了程序的可扩展性。
从之前的经验来看,这部分暂时会被使用到的不多。等后续涉及具体的项目时,再回头来查看作为进阶的内容。
有一些小错误,还希望作者和编辑再仔细校对。另外一些概念的讲解并不是很容易理解,前后逻辑顺序读着略有些混乱,阅读体验并不是很好。