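Designing a Complete OTA Upgrade System
Project Overview
OTA (Over-The-Air) upgrade is a core capability of modern IoT devices: it lets a device update its firmware remotely over the network, with no physical access required. This project walks you through building an enterprise-grade OTA upgrade system from scratch, covering the full path from cloud-side management to the device-side implementation.
Project Goals
Build a full-featured, secure, and reliable OTA upgrade system, consisting of:
Cloud management platform:
- Firmware version management and release
- Device management and grouping
- Upgrade task scheduling and monitoring
- Upgrade statistics and reports
- Differential package generation and management
Device-side features:
- Firmware download and verification
- Resumable downloads
- Atomic dual-bank upgrade
- Secure boot verification
- Upgrade status reporting
Core features:
- MQTT and HTTP transport
- Differential upgrades to save bandwidth
- Resumable downloads for reliability
- Digital signatures for security
- Gradual (canary) rollout to reduce risk
- Complete logging and monitoring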
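Technology Stack
Cloud side:
- Backend: Python/Flask or Node.js/Express
- Database: PostgreSQL/MySQL
- Message queue: MQTT broker (Mosquitto/EMQX)
- Object storage: MinIO/AWS S3
- Frontend: Vue.js/React
Device side:
- MCU: STM32F4 series (portable to other platforms)
- Communication: MQTT/HTTP
- Cryptography: mbedTLS
- Differential algorithm: bsdiff/Courgette
- Storage: dual-bank flash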
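Project Deliverables
After completing this project you will have:
- Cloud management system
  - Firmware management web UI
  - RESTful API service
  - MQTT message service
  - Database design and implementation
- Device-side code
  - Bootloader
  - OTA client library
  - Sample application
  - Test tools
- Documentation and tools
  - System architecture document
  - API reference
  - Deployment guide
  - Test cases
- Complete examples
  - End-to-end upgrade demo
  - Error-handling examples
  - Performance test report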
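Prerequisites
Before starting this project you need the following.
Required skills:
- Proficiency in C programming
- An understanding of bootloaders and IAP (in-application programming)
- Familiarity with network protocols (TCP/IP, HTTP, MQTT)
- Experience with flash memory operations
- Basic Linux usage
Recommended skills:
- Backend development experience with Python or Node.js
- Frontend basics (HTML/CSS/JavaScript)
- Database design and SQL
- Fundamentals of cryptography and security
- Docker container technology
Hardware requirements:
- STM32F4 development board (or a similar MCU)
- Network module (WiFi/4G/Ethernet)
- Debug probe (ST-Link/J-Link)
- PC development environment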
System Architecture Design
Overall Architecture
graph TB
    subgraph "Cloud Platform"
        A[Web Admin UI] --> B[API Server]
        B --> C[Database]
        B --> D[Object Storage]
        B --> E[MQTT Broker]
    end
    subgraph "Device"
        F[Application] --> G[OTA Client]
        G --> H[Bootloader]
        H --> I[Flash Storage]
    end
    E <--> |MQTT| G
    D <--> |HTTP| G
    subgraph "Development Tools"
        J[Firmware Build Tool]
        K[Differential Package Generator]
        L[Signing Tool]
    end
    J --> D
    K --> D
    L --> D
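Core Modules
1. Cloud management platform
Cloud platform
├── Firmware management
│   ├── Firmware upload and storage
│   ├── Version management
│   ├── Differential package generation
│   └── Firmware signing
├── Device management
│   ├── Device registration and authentication
│   ├── Device grouping
│   ├── Device status monitoring
│   └── Device information queries
├── Upgrade tasks
│   ├── Task creation and scheduling
│   ├── Gradual rollout
│   ├── Task monitoring
│   └── Task statistics
└── Communication
    ├── MQTT message service
    ├── HTTP file service
    └── WebSocket real-time communication
2. Device-side OTA client
OTA client
├── Communication layer
│   ├── MQTT client
│   ├── HTTP client
│   └── Resumable transfer
├── Download management
│   ├── Firmware download
│   ├── Progress tracking
│   ├── Cache management
│   └── Error retry
├── Verification layer
│   ├── Digital signature verification
│   ├── CRC check
│   ├── Version check
│   └── Compatibility check
└── Upgrade execution
    ├── Flash operations
    ├── Bank switching
    ├── Status reporting
    └── Rollback handling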
Data Flow
Upgrade flow:
sequenceDiagram
    participant Dev as Developer
    participant Cloud as Cloud Platform
    participant Device as Device
    participant Boot as Bootloader
    Dev->>Cloud: 1. Upload new firmware
    Cloud->>Cloud: 2. Generate differential package
    Cloud->>Cloud: 3. Sign firmware
    Cloud->>Cloud: 4. Create upgrade task
    Cloud->>Device: 5. Push upgrade notification
    Device->>Cloud: 6. Query firmware information
    Cloud-->>Device: 7. Return download URL
    Device->>Cloud: 8. Download firmware (resumable)
    Device->>Device: 9. Verify signature and CRC
    Device->>Device: 10. Write to backup bank
    Device->>Cloud: 11. Report download complete
    Device->>Device: 12. Set upgrade flag
    Device->>Device: 13. Reboot
    Boot->>Boot: 14. Verify new firmware
    Boot->>Boot: 15. Switch to new bank
    Boot->>Device: 16. Boot new firmware
    Device->>Cloud: 17. Report upgrade success
Core Feature Implementation
1. Cloud Management Platform
1.1 Database Design
-- Device groups (created first so that devices.group_id can reference it)
CREATE TABLE device_groups (
    id SERIAL PRIMARY KEY,
    group_name VARCHAR(200) NOT NULL,
    description TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Firmware versions
CREATE TABLE firmware_versions (
    id SERIAL PRIMARY KEY,
    version VARCHAR(50) NOT NULL UNIQUE,
    hardware_version VARCHAR(50) NOT NULL,
    file_path VARCHAR(255) NOT NULL,
    file_size BIGINT NOT NULL,
    file_md5 VARCHAR(32) NOT NULL,
    signature TEXT,
    release_notes TEXT,
    is_differential BOOLEAN DEFAULT FALSE,
    base_version VARCHAR(50),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    created_by VARCHAR(100),
    status VARCHAR(20) DEFAULT 'draft'
);
-- Devices
CREATE TABLE devices (
    id SERIAL PRIMARY KEY,
    device_id VARCHAR(100) NOT NULL UNIQUE,
    device_name VARCHAR(200),
    hardware_version VARCHAR(50),
    current_version VARCHAR(50),
    group_id INTEGER REFERENCES device_groups(id),
    last_online TIMESTAMP,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status VARCHAR(20) DEFAULT 'online'
);
-- Upgrade tasks
CREATE TABLE upgrade_tasks (
    id SERIAL PRIMARY KEY,
    task_name VARCHAR(200) NOT NULL,
    firmware_version_id INTEGER REFERENCES firmware_versions(id),
    target_devices TEXT, -- JSON array of device IDs
    strategy VARCHAR(50) DEFAULT 'immediate', -- immediate, scheduled, gradual
    schedule_time TIMESTAMP,
    gradual_percentage INTEGER DEFAULT 100,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status VARCHAR(20) DEFAULT 'pending'
);
-- Upgrade records
CREATE TABLE upgrade_records (
    id SERIAL PRIMARY KEY,
    task_id INTEGER REFERENCES upgrade_tasks(id),
    device_id VARCHAR(100) REFERENCES devices(device_id),
    from_version VARCHAR(50),
    to_version VARCHAR(50),
    start_time TIMESTAMP,
    end_time TIMESTAMP,
    status VARCHAR(20), -- downloading, installing, success, failed
    error_message TEXT,
    download_progress INTEGER DEFAULT 0
);
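The `strategy` and `gradual_percentage` columns of `upgrade_tasks` suggest that a gradual rollout targets only a fraction of the fleet at first. A minimal sketch of one way the server could pick that fraction is shown here. It hashes the device ID into a stable bucket so that raising the percentage only adds devices and never drops ones already selected. The function names and the bucketing scheme are assumptions of this sketch, not part of the schema above.

```python
import hashlib


def rollout_bucket(device_id: str, task_id: int) -> int:
    """Map a device to a stable bucket in [0, 100) for a given task."""
    digest = hashlib.sha256(f"{task_id}:{device_id}".encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big") % 100


def select_devices(device_ids, task_id: int, gradual_percentage: int):
    """Return the devices whose bucket falls below the rollout percentage."""
    pct = max(0, min(100, gradual_percentage))
    return [d for d in device_ids if rollout_bucket(d, task_id) < pct]


if __name__ == "__main__":
    devices = [f"dev-{i:04d}" for i in range(1000)]
    first_wave = select_devices(devices, task_id=1, gradual_percentage=10)
    second_wave = select_devices(devices, task_id=1, gradual_percentage=50)
    print(len(first_wave), len(second_wave))
```

Because the bucket depends only on the task and device IDs, the selection can be recomputed at any time without storing per-device rollout state.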
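1.2 API Service Implementation (Python/Flask)
import hashlib
import json
import os
from datetime import datetime

from flask import Flask, request, jsonify, send_file
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql://user:pass@localhost/ota_db'
db = SQLAlchemy(app)
# FirmwareVersion, Device, UpgradeTask and UpgradeRecord are assumed to be
# SQLAlchemy models mapped to the tables in section 1.1; they are not shown here.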
# Firmware upload endpoint
def is_safe_name(value):
    """Accept only names that cannot escape the firmware directory."""
    return bool(value) and '..' not in value and all(
        c.isalnum() or c in '._-' for c in value)


@app.route('/api/firmware/upload', methods=['POST'])
def upload_firmware():
    """Upload a firmware file."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    version = request.form.get('version')
    hardware_version = request.form.get('hardware_version')
    release_notes = request.form.get('release_notes', '')
    # Validate parameters. Both values become part of the storage path below,
    # so anything containing path separators or ".." is rejected.
    if not is_safe_name(version) or not is_safe_name(hardware_version):
        return jsonify({'error': 'Missing or invalid parameters'}), 400
    # Compute the MD5 of the file
    file_content = file.read()
    file_md5 = hashlib.md5(file_content).hexdigest()
    file_size = len(file_content)
    # Save the file
    file_path = f'firmware/{hardware_version}/{version}.bin'
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    # Save the record to the database
    firmware = FirmwareVersion(
        version=version,
        hardware_version=hardware_version,
        file_path=file_path,
        file_size=file_size,
        file_md5=file_md5,
        release_notes=release_notes,
        status='draft'
    )
    db.session.add(firmware)
    db.session.commit()
    return jsonify({
        'success': True,
        'firmware_id': firmware.id,
        'version': version,
        'file_size': file_size,
        'md5': file_md5
    })
# Create an upgrade task
@app.route('/api/tasks/create', methods=['POST'])
def create_upgrade_task():
    """Create an upgrade task."""
    data = request.get_json(silent=True) or {}
    task_name = data.get('task_name')
    firmware_version_id = data.get('firmware_version_id')
    target_devices = data.get('target_devices', [])
    strategy = data.get('strategy', 'immediate')
    if not task_name:
        return jsonify({'error': 'Missing task_name'}), 400
    # Make sure the firmware version exists
    firmware = FirmwareVersion.query.get(firmware_version_id)
    if not firmware:
        return jsonify({'error': 'Firmware version not found'}), 404
    # Create the task
    task = UpgradeTask(
        task_name=task_name,
        firmware_version_id=firmware_version_id,
        target_devices=json.dumps(target_devices),
        strategy=strategy,
        status='pending'
    )
    db.session.add(task)
    db.session.commit()
    # Notify each device over MQTT. mqtt_service is assumed to be a shared
    # OTAMQTTService instance (section 1.3) created at application startup.
    for device_id in target_devices:
        mqtt_service.send_upgrade_notification(device_id, firmware.version)
    return jsonify({
        'success': True,
        'task_id': task.id,
        'task_name': task_name
    })
# Query firmware information
@app.route('/api/firmware/info/<version>', methods=['GET'])
def get_firmware_info(version):
    """Return the metadata of a firmware version."""
    firmware = FirmwareVersion.query.filter_by(version=version).first()
    if not firmware:
        return jsonify({'error': 'Firmware not found'}), 404
    return jsonify({
        'version': firmware.version,
        'hardware_version': firmware.hardware_version,
        'file_size': firmware.file_size,
        'md5': firmware.file_md5,
        'download_url': f'/api/firmware/download/{firmware.id}',
        'release_notes': firmware.release_notes,
        'is_differential': firmware.is_differential,
        'base_version': firmware.base_version
    })
# Firmware download endpoint (supports resumable downloads)
@app.route('/api/firmware/download/<int:firmware_id>', methods=['GET'])
def download_firmware(firmware_id):
    """Download a firmware file, with HTTP Range support for resuming."""
    firmware = FirmwareVersion.query.get(firmware_id)
    if not firmware:
        return jsonify({'error': 'Firmware not found'}), 404
    file_path = firmware.file_path
    if not os.path.exists(file_path):
        return jsonify({'error': 'File not found'}), 404
    file_size = os.path.getsize(file_path)
    # Handle a single range of the form "bytes=start-" or "bytes=start-end"
    range_header = request.headers.get('Range')
    if range_header:
        try:
            first, last = range_header.replace('bytes=', '', 1).split('-', 1)
            start = int(first)
            end = min(int(last), file_size - 1) if last else file_size - 1
        except ValueError:
            start, end = -1, -1
        if start < 0 or start > end:
            # Malformed or unsatisfiable range (suffix ranges are not handled here)
            response = app.response_class(status=416)
            response.headers['Content-Range'] = f'bytes */{file_size}'
            return response
        # Read only the requested range
        with open(file_path, 'rb') as f:
            f.seek(start)
            data = f.read(end - start + 1)
        response = app.response_class(
            data,
            status=206,  # Partial Content
            mimetype='application/octet-stream'
        )
        response.headers['Content-Range'] = f'bytes {start}-{end}/{file_size}'
        response.headers['Accept-Ranges'] = 'bytes'
        return response
    # Full download
    return send_file(file_path, as_attachment=True)
# Device status report
@app.route('/api/device/status', methods=['POST'])
def report_device_status():
    """Receive a status report from a device."""
    data = request.get_json(silent=True) or {}
    device_id = data.get('device_id')
    current_version = data.get('current_version')
    status = data.get('status')
    # Update the device record
    device = Device.query.filter_by(device_id=device_id).first()
    if device:
        device.current_version = current_version
        device.status = status
        device.last_online = datetime.now()
        db.session.commit()
    return jsonify({'success': True})


# Upgrade progress report
@app.route('/api/upgrade/progress', methods=['POST'])
def report_upgrade_progress():
    """Receive an upgrade progress report."""
    data = request.get_json(silent=True) or {}
    device_id = data.get('device_id')
    task_id = data.get('task_id')
    progress = data.get('progress')
    status = data.get('status')
    error_message = data.get('error_message')
    # Update the upgrade record
    record = UpgradeRecord.query.filter_by(
        task_id=task_id,
        device_id=device_id
    ).first()
    if record:
        record.download_progress = progress
        record.status = status
        record.error_message = error_message
        if status == 'success':
            record.end_time = datetime.now()
        db.session.commit()
    return jsonify({'success': True})


if __name__ == '__main__':
    # Debug mode exposes an interactive debugger, so keep it off whenever
    # the server listens on all interfaces.
    app.run(host='0.0.0.0', port=5000, debug=False)
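The device client in section 2.6 calls `/api/firmware/check` and reads the fields `update_available`, `version`, `file_size`, `crc32`, `download_url`, `is_differential` and `base_version` from the reply, but the service above does not define that route. The sketch below shows, as a plain function, how such a reply could be assembled. The input record layout and the `'released'` status value are assumptions of this sketch; versions are compared numerically so that `1.10.0` ranks above `1.9.0`.

```python
import zlib


def parse_version(text: str):
    """Turn '1.2.3' into (1, 2, 3) so versions compare numerically."""
    return tuple(int(part) for part in text.split("."))


def build_check_response(current_version, hardware_version, firmwares):
    """Pick the newest released firmware for this hardware, if any.

    `firmwares` is a list of dicts with the keys id, version,
    hardware_version, status and data (raw image bytes). This layout
    is hypothetical and only serves the example.
    """
    candidates = [
        fw for fw in firmwares
        if fw["hardware_version"] == hardware_version
        and fw["status"] == "released"
        and parse_version(fw["version"]) > parse_version(current_version)
    ]
    if not candidates:
        return {"update_available": False}
    newest = max(candidates, key=lambda fw: parse_version(fw["version"]))
    return {
        "update_available": True,
        "version": newest["version"],
        "file_size": len(newest["data"]),
        # Hex string, because the device parses this field as hexadecimal
        "crc32": f"{zlib.crc32(newest['data']) & 0xFFFFFFFF:08x}",
        "download_url": f"/api/firmware/download/{newest['id']}",
        "is_differential": False,
        "base_version": None,
    }
```

A Flask route could call this function with the query parameters sent by the device and return the result through `jsonify`.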
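1.3 MQTT Message Service
import json
from datetime import datetime

import paho.mqtt.client as mqtt

# Assumption: the Flask module from section 1.2 is importable as `app` and
# exposes the application object, the SQLAlchemy handle and the Device model.
from app import app, db, Device


class OTAMQTTService:
    def __init__(self, broker_host='localhost', broker_port=1883):
        # paho-mqtt 1.x constructor. With paho-mqtt 2.x, pass
        # mqtt.CallbackAPIVersion.VERSION1 as the first argument to keep
        # the callback signatures used below.
        self.client = mqtt.Client()
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message
        self.client.connect(broker_host, broker_port, 60)

    def on_connect(self, client, userdata, flags, rc):
        """Called when the connection to the broker is established."""
        print(f"Connected to MQTT broker with result code {rc}")
        # Subscribe to the device status and progress topics
        client.subscribe("ota/device/+/status")
        client.subscribe("ota/device/+/progress")

    def on_message(self, client, userdata, msg):
        """Called for every received message."""
        topic = msg.topic
        try:
            payload = json.loads(msg.payload.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            print(f"Ignoring malformed payload on {topic}")
            return
        if topic.endswith('/status'):
            # Device status message
            self.handle_device_status(payload)
        elif topic.endswith('/progress'):
            # Upgrade progress message
            self.handle_upgrade_progress(payload)

    def handle_device_status(self, data):
        """Handle a device status message."""
        device_id = data.get('device_id')
        version = data.get('version')
        # Update the database; Flask-SQLAlchemy needs an application context
        with app.app_context():
            device = Device.query.filter_by(device_id=device_id).first()
            if device:
                device.current_version = version
                device.last_online = datetime.now()
                db.session.commit()

    def handle_upgrade_progress(self, data):
        """Handle an upgrade progress message."""
        device_id = data.get('device_id')
        progress = data.get('progress')
        status = data.get('status')
        # Update the upgrade record
        # ... database operations

    def send_upgrade_notification(self, device_id, firmware_version):
        """Send an upgrade notification to a device."""
        topic = f"ota/device/{device_id}/upgrade"
        message = {
            'action': 'upgrade',
            'firmware_version': firmware_version,
            'timestamp': datetime.now().isoformat()
        }
        self.client.publish(topic, json.dumps(message), qos=1)
        print(f"Sent upgrade notification to {device_id}")

    def run(self):
        """Run the MQTT network loop (blocks forever)."""
        self.client.loop_forever()


# Start the MQTT service
if __name__ == '__main__':
    mqtt_service = OTAMQTTService()
    mqtt_service.run()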
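The handlers above take `device_id` from the JSON payload. Since every topic already has the form `ota/device/<device_id>/<kind>`, a service could also derive the device ID and message kind from the topic itself, which is harder for a device to spoof when the broker enforces per-device topic permissions. A minimal sketch, with a hypothetical helper name:

```python
def parse_device_topic(topic: str):
    """Split 'ota/device/<device_id>/<kind>' into (device_id, kind).

    Returns None for topics that do not follow this layout.
    """
    parts = topic.split("/")
    if len(parts) != 4 or parts[0] != "ota" or parts[1] != "device":
        return None
    device_id, kind = parts[2], parts[3]
    if not device_id or kind not in ("status", "progress", "upgrade"):
        return None
    return device_id, kind
```

An `on_message` callback could call this first and dispatch on the returned `kind`, ignoring messages for which it returns `None`.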
1.4 Differential Package Generator
import hashlib

import bsdiff4


class DifferentialPackageGenerator:
    """Generates and applies differential (delta) firmware packages."""

    def generate_diff(self, old_firmware_path, new_firmware_path, output_path):
        """Generate a differential package."""
        print("Generating differential package...")
        print(f"Old firmware: {old_firmware_path}")
        print(f"New firmware: {new_firmware_path}")
        # Read both firmware images
        with open(old_firmware_path, 'rb') as f:
            old_data = f.read()
        with open(new_firmware_path, 'rb') as f:
            new_data = f.read()
        # Generate the patch
        diff_data = bsdiff4.diff(old_data, new_data)
        # Save the patch
        with open(output_path, 'wb') as f:
            f.write(diff_data)
        # Size reduction relative to shipping the full new image
        old_size = len(old_data)
        new_size = len(new_data)
        diff_size = len(diff_data)
        compression_ratio = (1 - diff_size / new_size) * 100 if new_size else 0.0
        print(f"Old firmware size: {old_size} bytes")
        print(f"New firmware size: {new_size} bytes")
        print(f"Differential package size: {diff_size} bytes")
        print(f"Size reduction: {compression_ratio:.2f}%")
        return {
            'old_size': old_size,
            'new_size': new_size,
            'diff_size': diff_size,
            'compression_ratio': compression_ratio,
            'diff_md5': hashlib.md5(diff_data).hexdigest()
        }

    def apply_diff(self, old_firmware_path, diff_path, output_path):
        """Apply a differential package to an old image."""
        print("Applying differential package...")
        # Read the old firmware and the patch
        with open(old_firmware_path, 'rb') as f:
            old_data = f.read()
        with open(diff_path, 'rb') as f:
            diff_data = f.read()
        # Apply the patch
        new_data = bsdiff4.patch(old_data, diff_data)
        # Save the reconstructed firmware
        with open(output_path, 'wb') as f:
            f.write(new_data)
        print(f"New firmware generated: {output_path}")
        print(f"Size: {len(new_data)} bytes")
        return len(new_data)


# Usage example
if __name__ == '__main__':
    generator = DifferentialPackageGenerator()
    # Generate a differential package
    result = generator.generate_diff(
        'firmware_v1.0.0.bin',
        'firmware_v1.1.0.bin',
        'diff_v1.0.0_to_v1.1.0.patch'
    )
    print("\nDifferential package generated successfully!")
    print(f"MD5: {result['diff_md5']}")
2. Device-Side OTA Client
2.1 Core Structures of the OTA Client
// ota_client.h
#ifndef OTA_CLIENT_H
#define OTA_CLIENT_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
// OTA states
typedef enum {
    OTA_STATE_IDLE = 0,
    OTA_STATE_CHECKING,
    OTA_STATE_DOWNLOADING,
    OTA_STATE_VERIFYING,
    OTA_STATE_INSTALLING,
    OTA_STATE_COMPLETE,
    OTA_STATE_ERROR
} OTA_State_t;
// OTA configuration
typedef struct {
    char device_id[64];
    char current_version[32];
    char hardware_version[32];
    char server_url[256];
    char mqtt_broker[128];
    uint16_t mqtt_port;
    uint32_t check_interval;    // update check interval, in seconds
} OTA_Config_t;
// OTA runtime information
typedef struct {
    OTA_State_t state;
    char target_version[32];
    char hardware_version[32];  // hardware version the image must match (read in ota_verify.c)
    uint32_t firmware_size;
    uint32_t downloaded_size;
    uint32_t firmware_crc32;
    char download_url[256];
    bool is_differential;
    char base_version[32];
    uint8_t retry_count;
    bool has_signature;         // true when a signature accompanies the image (read in ota_verify.c)
    uint8_t signature[256];     // RSA-2048 signature, 256 bytes
    size_t signature_len;
} OTA_Info_t;
// OTA client API
void OTA_Init(OTA_Config_t *config);
void OTA_Task(void);
bool OTA_CheckUpdate(void);
bool OTA_StartDownload(void);
bool OTA_VerifyFirmware(void);
bool OTA_InstallFirmware(void);
void OTA_ReportProgress(uint8_t progress);
void OTA_ReportStatus(const char *status, const char *message);
#endif // OTA_CLIENT_H
2.2 MQTT Communication
// ota_mqtt.c
#include "ota_mqtt.h"
#include "MQTTClient.h"
#include "cJSON.h"
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
static MQTTClient mqtt_client;
static OTA_Config_t *ota_config;
static OTA_Info_t *ota_info;
// MQTT message callback
void mqtt_message_callback(MessageData *data) {
    char topic[128];
    char payload[512];
    size_t topic_len = (size_t)data->topicName->lenstring.len;
    size_t payload_len = data->message->payloadlen;
    // Drop messages that do not fit the local buffers instead of overflowing them
    if (topic_len >= sizeof(topic) || payload_len >= sizeof(payload)) {
        printf("MQTT message too large, ignored\r\n");
        return;
    }
    // Extract topic and payload as NUL-terminated strings
    memcpy(topic, data->topicName->lenstring.data, topic_len);
    topic[topic_len] = '\0';
    memcpy(payload, data->message->payload, payload_len);
    payload[payload_len] = '\0';
    printf("MQTT message received on topic: %s\r\n", topic);
    printf("Payload: %s\r\n", payload);
    // Check whether this is an upgrade notification
    char upgrade_topic[128];
    snprintf(upgrade_topic, sizeof(upgrade_topic),
             "ota/device/%s/upgrade", ota_config->device_id);
    if (strcmp(topic, upgrade_topic) == 0) {
        // Parse the upgrade notification
        cJSON *json = cJSON_Parse(payload);
        if (json) {
            cJSON *action = cJSON_GetObjectItem(json, "action");
            cJSON *version = cJSON_GetObjectItem(json, "firmware_version");
            // cJSON_IsString() also rejects missing (NULL) items
            if (cJSON_IsString(action) && cJSON_IsString(version) &&
                strcmp(action->valuestring, "upgrade") == 0) {
                printf("Upgrade notification received: %s\r\n",
                       version->valuestring);
                // Trigger an OTA update check
                OTA_CheckUpdate();
            }
            cJSON_Delete(json);
        }
    }
}
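// MQTT initialization
// The client keeps pointers to the network object, both packet buffers and the
// subscribed topic string, so all of them need static storage; locals on this
// function's stack would dangle after it returns.
static Network mqtt_network;
static unsigned char mqtt_sendbuf[512];
static unsigned char mqtt_readbuf[512];
static char mqtt_upgrade_topic[128];
bool OTA_MQTT_Init(OTA_Config_t *config, OTA_Info_t *info) {
    ota_config = config;
    ota_info = info;
    NetworkInit(&mqtt_network);
    MQTTClientInit(&mqtt_client, &mqtt_network, 1000,
                   mqtt_sendbuf, sizeof(mqtt_sendbuf),
                   mqtt_readbuf, sizeof(mqtt_readbuf));
    // Connect to the MQTT broker
    MQTTPacket_connectData connect_data = MQTTPacket_connectData_initializer;
    connect_data.MQTTVersion = 3;
    connect_data.clientID.cstring = config->device_id;
    if (NetworkConnect(&mqtt_network, config->mqtt_broker, config->mqtt_port) != 0) {
        printf("Failed to connect to MQTT broker\r\n");
        return false;
    }
    if (MQTTConnect(&mqtt_client, &connect_data) != 0) {
        printf("Failed to connect to MQTT\r\n");
        return false;
    }
    printf("Connected to MQTT broker\r\n");
    // Subscribe to the upgrade topic
    snprintf(mqtt_upgrade_topic, sizeof(mqtt_upgrade_topic),
             "ota/device/%s/upgrade", config->device_id);
    if (MQTTSubscribe(&mqtt_client, mqtt_upgrade_topic, QOS1, mqtt_message_callback) != 0) {
        printf("Failed to subscribe to topic: %s\r\n", mqtt_upgrade_topic);
        return false;
    }
    printf("Subscribed to topic: %s\r\n", mqtt_upgrade_topic);
    return true;
}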
// Publish device status
void OTA_MQTT_PublishStatus(const char *status, const char *message) {
    char topic[128];
    snprintf(topic, sizeof(topic), "ota/device/%s/status",
             ota_config->device_id);
    cJSON *json = cJSON_CreateObject();
    if (!json) {
        return;
    }
    cJSON_AddStringToObject(json, "device_id", ota_config->device_id);
    cJSON_AddStringToObject(json, "version", ota_config->current_version);
    cJSON_AddStringToObject(json, "status", status);
    cJSON_AddStringToObject(json, "message", message);
    char *json_str = cJSON_PrintUnformatted(json);
    if (json_str) {
        MQTTMessage message_data;
        memset(&message_data, 0, sizeof(message_data));
        message_data.qos = QOS1;
        message_data.retained = 0;
        message_data.payload = json_str;
        message_data.payloadlen = strlen(json_str);
        MQTTPublish(&mqtt_client, topic, &message_data);
        free(json_str);
    }
    cJSON_Delete(json);
}
// Publish upgrade progress
void OTA_MQTT_PublishProgress(uint8_t progress) {
    char topic[128];
    snprintf(topic, sizeof(topic), "ota/device/%s/progress",
             ota_config->device_id);
    cJSON *json = cJSON_CreateObject();
    if (!json) {
        return;
    }
    cJSON_AddStringToObject(json, "device_id", ota_config->device_id);
    cJSON_AddNumberToObject(json, "progress", progress);
    cJSON_AddStringToObject(json, "status",
                            ota_info->state == OTA_STATE_DOWNLOADING ?
                            "downloading" : "installing");
    char *json_str = cJSON_PrintUnformatted(json);
    if (json_str) {
        MQTTMessage message;
        memset(&message, 0, sizeof(message));
        message.qos = QOS0;
        message.retained = 0;
        message.payload = json_str;
        message.payloadlen = strlen(json_str);
        MQTTPublish(&mqtt_client, topic, &message);
        free(json_str);
    }
    cJSON_Delete(json);
}
// MQTT task: let the client process incoming packets and keep-alives
void OTA_MQTT_Task(void) {
    MQTTYield(&mqtt_client, 100);
}
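2.3 HTTP Download (with Resume Support)
// ota_download.c
#include "ota_download.h"
#include "http_client.h"
#include <string.h>
#include <stdio.h>
#define MAX_RETRY_COUNT 3
static OTA_Info_t *ota_info;
static uint8_t last_progress;   // last progress value reported to the server
// HTTP download callback
static int http_download_callback(void *data, size_t size, void *userdata) {
    uint32_t *downloaded = (uint32_t*)userdata;
    // Write to flash. Flash_Write() takes a count of 32-bit words, so a
    // trailing partial word is padded with 0xFF (the erased-flash value).
    // This assumes only the final chunk has a size that is not a multiple of 4.
    uint32_t write_addr = FIRMWARE_BACKUP_ADDRESS + *downloaded;
    size_t words = size / 4;
    size_t tail = size % 4;
    if (words > 0 && !Flash_Write(write_addr, (uint32_t*)data, words)) {
        printf("Flash write failed at 0x%08lX\r\n", (unsigned long)write_addr);
        return -1;
    }
    if (tail > 0) {
        uint32_t last_word = 0xFFFFFFFF;
        memcpy(&last_word, (uint8_t*)data + words * 4, tail);
        if (!Flash_Write(write_addr + words * 4, &last_word, 1)) {
            printf("Flash write failed at 0x%08lX\r\n",
                   (unsigned long)(write_addr + words * 4));
            return -1;
        }
    }
    *downloaded += size;
    // Compute progress (64-bit intermediate avoids overflow on large images)
    if (ota_info->firmware_size == 0) {
        return 0;
    }
    uint8_t progress = (uint8_t)(((uint64_t)*downloaded * 100) / ota_info->firmware_size);
    // Report progress in 5% steps
    if (progress >= last_progress + 5) {
        OTA_ReportProgress(progress);
        last_progress = progress;
    }
    return 0;
}
// Download the firmware (with resume support)
bool OTA_DownloadFirmware(OTA_Info_t *info) {
    ota_info = info;
    last_progress = 0;
    printf("Starting firmware download...\r\n");
    printf("URL: %s\r\n", info->download_url);
    printf("Size: %lu bytes\r\n", (unsigned long)info->firmware_size);
    // Check how much has already been downloaded
    uint32_t downloaded = info->downloaded_size;
    if (downloaded > 0) {
        printf("Resuming download from %lu bytes\r\n", (unsigned long)downloaded);
    }
    // Prepare the HTTP request
    http_client_t client;
    http_client_init(&client);
    // Register the download callback
    http_client_set_callback(&client, http_download_callback, &downloaded);
    // Start downloading
    int retry_count = 0;
    bool success = false;
    while (retry_count < MAX_RETRY_COUNT && !success) {
        // Rebuild the Range header from the current offset on every attempt,
        // so that a retry continues where the previous attempt stopped
        if (downloaded > 0) {
            char range_header[64];
            snprintf(range_header, sizeof(range_header), "bytes=%lu-%lu",
                     (unsigned long)downloaded,
                     (unsigned long)(info->firmware_size - 1));
            http_client_set_header(&client, "Range", range_header);
        }
        int result = http_client_get(&client, info->download_url);
        // Save the current download progress
        info->downloaded_size = downloaded;
        if (result == 0 && downloaded == info->firmware_size) {
            success = true;
            printf("Download complete!\r\n");
        } else {
            retry_count++;
            printf("Download failed, retry %d/%d\r\n",
                   retry_count, MAX_RETRY_COUNT);
            // Wait before retrying
            HAL_Delay(5000);
        }
    }
    http_client_cleanup(&client);
    return success;
}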
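The key point of resumable downloading is that every retry requests the range starting at the number of bytes already stored, not the range computed before the first attempt. The following host-side simulation sketches that loop against an in-memory server that drops the first connection. `download_with_resume`, `FlakyServer` and the `fetch(start, end)` callback are hypothetical names used only for this illustration; they stand in for the HTTP client and server.

```python
def download_with_resume(fetch, total_size, max_retries=3):
    """Download `total_size` bytes, resuming from the last received offset.

    `fetch(start, end)` yields chunks for the inclusive byte range and
    may raise ConnectionError. Returns None after `max_retries` failures.
    """
    received = bytearray()
    failures = 0
    while len(received) < total_size:
        start = len(received)  # recomputed on every attempt
        try:
            for chunk in fetch(start, total_size - 1):
                received.extend(chunk)
        except ConnectionError:
            failures += 1
            if failures >= max_retries:
                return None
            continue
        if len(received) == start:
            # The transfer ended without delivering data: count it as a failure
            failures += 1
            if failures >= max_retries:
                return None
    return bytes(received)


class FlakyServer:
    """In-memory stand-in for the HTTP server that drops the first connection."""

    def __init__(self, image, chunk_size=4, fail_after=10):
        self.image = image
        self.chunk_size = chunk_size
        self.fail_after = fail_after  # bytes served before the simulated drop
        self.requests = []            # (start, end) of every requested range

    def fetch(self, start, end):
        self.requests.append((start, end))
        sent = 0
        for pos in range(start, end + 1, self.chunk_size):
            if self.fail_after is not None and sent >= self.fail_after:
                self.fail_after = None  # fail only once
                raise ConnectionError("simulated link drop")
            chunk = self.image[pos:min(pos + self.chunk_size, end + 1)]
            sent += len(chunk)
            yield chunk


if __name__ == "__main__":
    server = FlakyServer(bytes(range(32)))
    image = download_with_resume(server.fetch, 32)
    print(image == bytes(range(32)), server.requests)
```

In the simulation the second request starts at the offset reached before the drop, which is the behaviour the `Range` header is meant to provide on the device.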
2.4 Firmware Verification
// ota_verify.c
#include "ota_verify.h"
#include "mbedtls/md.h"
#include "mbedtls/pk.h"
#include "mbedtls/sha256.h"
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
// Public key (used to verify the signature)
static const char *public_key_pem =
    "-----BEGIN PUBLIC KEY-----\n"
    "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA...\n"
    "-----END PUBLIC KEY-----\n";
// Compute the CRC32 of the firmware image
uint32_t OTA_CalculateCRC32(uint32_t address, uint32_t size) {
uint32_t crc = 0xFFFFFFFF;
uint8_t *data = (uint8_t*)address;
for (uint32_t i = 0; i < size; i++) {
crc ^= data[i];
for (uint8_t j = 0; j < 8; j++) {
if (crc & 1) {
crc = (crc >> 1) ^ 0xEDB88320;
} else {
crc = crc >> 1;
}
}
}
return ~crc;
}
// Compute the SHA-256 of the firmware image
bool OTA_CalculateSHA256(uint32_t address, uint32_t size, uint8_t *hash) {
    mbedtls_sha256_context ctx;
    mbedtls_sha256_init(&ctx);
    mbedtls_sha256_starts(&ctx, 0);  // 0 = SHA-256 (1 would select SHA-224)
    // Hash the image in chunks
    uint32_t remaining = size;
    uint32_t offset = 0;
    uint8_t buffer[1024];
    while (remaining > 0) {
        uint32_t chunk_size = (remaining > sizeof(buffer)) ?
                              sizeof(buffer) : remaining;
        memcpy(buffer, (void*)(address + offset), chunk_size);
        mbedtls_sha256_update(&ctx, buffer, chunk_size);
        offset += chunk_size;
        remaining -= chunk_size;
    }
    mbedtls_sha256_finish(&ctx, hash);
    mbedtls_sha256_free(&ctx);
    return true;
}
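// Verify the digital signature
bool OTA_VerifySignature(uint32_t address, uint32_t size,
                         const uint8_t *signature, size_t sig_len) {
    mbedtls_pk_context pk;
    uint8_t hash[32];
    int ret;
    // Load the public key
    mbedtls_pk_init(&pk);
    ret = mbedtls_pk_parse_public_key(&pk,
                                      (const unsigned char*)public_key_pem,
                                      strlen(public_key_pem) + 1);
    if (ret != 0) {
        printf("Failed to parse public key: -0x%04x\r\n", (unsigned int)-ret);
        mbedtls_pk_free(&pk);
        return false;
    }
    // Hash the firmware image
    if (!OTA_CalculateSHA256(address, size, hash)) {
        printf("Failed to calculate SHA256\r\n");
        mbedtls_pk_free(&pk);
        return false;
    }
    // Verify the signature. The signing tool in section 3.1 uses RSA-PSS
    // (MGF1 with SHA-256), so the PSS variant of the verify call is needed;
    // plain mbedtls_pk_verify() would check PKCS#1 v1.5 padding for an RSA key.
    mbedtls_pk_rsassa_pss_options pss_opts;
    pss_opts.mgf1_hash_id = MBEDTLS_MD_SHA256;
    pss_opts.expected_salt_len = MBEDTLS_RSA_SALT_LEN_ANY;
    ret = mbedtls_pk_verify_ext(MBEDTLS_PK_RSASSA_PSS, &pss_opts,
                                &pk, MBEDTLS_MD_SHA256,
                                hash, sizeof(hash),
                                signature, sig_len);
    mbedtls_pk_free(&pk);
    if (ret != 0) {
        printf("Signature verification failed: -0x%04x\r\n", (unsigned int)-ret);
        return false;
    }
    printf("Signature verification passed\r\n");
    return true;
}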
// Verify the integrity of the downloaded image.
// Named OTA_VerifyFirmwareImage() because ota_client.h already declares
// OTA_VerifyFirmware(void) and C does not allow two functions with one name;
// callers in ota_client.c must use this name.
bool OTA_VerifyFirmwareImage(OTA_Info_t *info) {
    uint32_t address = FIRMWARE_BACKUP_ADDRESS;
    uint32_t size = info->firmware_size;
    printf("Verifying firmware...\r\n");
    // 1. CRC32 check
    printf("Calculating CRC32...\r\n");
    uint32_t calculated_crc = OTA_CalculateCRC32(address, size);
    if (calculated_crc != info->firmware_crc32) {
        printf("CRC32 mismatch: expected 0x%08lX, got 0x%08lX\r\n",
               (unsigned long)info->firmware_crc32, (unsigned long)calculated_crc);
        return false;
    }
    printf("CRC32 verification passed\r\n");
    // 2. Digital signature check (if a signature is present)
    if (info->has_signature) {
        printf("Verifying digital signature...\r\n");
        if (!OTA_VerifySignature(address, size,
                                 info->signature, info->signature_len)) {
            printf("Digital signature verification failed\r\n");
            return false;
        }
    }
    // 3. Version check
    FirmwareInfo_t *fw_info = (FirmwareInfo_t*)(address + 0x100);
    if (fw_info->magic != FIRMWARE_MAGIC) {
        printf("Invalid firmware magic number\r\n");
        return false;
    }
    printf("Firmware version: %d.%d.%d\r\n",
           fw_info->major, fw_info->minor, fw_info->build);
    // 4. Hardware compatibility check
    if (strcmp(fw_info->hardware_version, info->hardware_version) != 0) {
        printf("Hardware version mismatch\r\n");
        return false;
    }
    printf("Firmware verification complete\r\n");
    return true;
}
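The device only accepts an image when its own CRC32 equals the value published by the server, so both sides must implement the same CRC variant. The loop in `OTA_CalculateCRC32` (initial value `0xFFFFFFFF`, reflected polynomial `0xEDB88320`, final inversion) is the standard CRC-32 that `zlib.crc32` computes. The sketch below mirrors the C loop in Python and compares it against `zlib`; the function name is chosen for this example only.

```python
import zlib


def crc32_bitwise(data: bytes) -> int:
    """Bit-by-bit CRC-32 (reflected polynomial 0xEDB88320), mirroring the C loop."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF


if __name__ == "__main__":
    sample = b"123456789"
    print(f"{crc32_bitwise(sample):08X}", f"{zlib.crc32(sample):08X}")
```

This means a server-side tool can compute the expected checksum with `zlib.crc32(image) & 0xFFFFFFFF` and the device will arrive at the same value.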
2.5 Differential Upgrade
// ota_differential.c
#include "ota_differential.h"
#include "bspatch.h"
#include <stdio.h>
#include <string.h>
#define PATCH_BUFFER_SIZE 4096
// Apply a differential patch
bool OTA_ApplyDifferentialPatch(const char *base_version,
                                uint32_t patch_address,
                                uint32_t patch_size) {
    printf("Applying differential patch...\r\n");
    printf("Base version: %s\r\n", base_version);
    printf("Patch size: %lu bytes\r\n", (unsigned long)patch_size);
    // 1. Check the base version
    FirmwareInfo_t *current_fw = (FirmwareInfo_t*)(FIRMWARE_ACTIVE_ADDRESS + 0x100);
    char current_version[32];
    snprintf(current_version, sizeof(current_version), "%d.%d.%d",
             current_fw->major, current_fw->minor, current_fw->build);
    if (strcmp(current_version, base_version) != 0) {
        printf("Base version mismatch: expected %s, got %s\r\n",
               base_version, current_version);
        return false;
    }
    // 2. Read the current firmware into RAM (if RAM is large enough),
    //    or process it as a stream
    // 3. Apply the bsdiff patch
    // Note: new_data points into flash, which cannot be written through a
    // plain pointer. The bspatch() port used here has to write its output
    // through the flash driver, and the patch must be stored outside the
    // output region, otherwise applying it would overwrite the patch itself.
    uint8_t *old_data = (uint8_t*)FIRMWARE_ACTIVE_ADDRESS;
    uint8_t *patch_data = (uint8_t*)patch_address;
    uint8_t *new_data = (uint8_t*)FIRMWARE_BACKUP_ADDRESS;
    // Apply the patch with the bspatch algorithm
    int result = bspatch(old_data, current_fw->firmware_size,
                         patch_data, patch_size,
                         new_data);
    if (result != 0) {
        printf("Failed to apply patch: error %d\r\n", result);
        return false;
    }
    printf("Differential patch applied successfully\r\n");
    return true;
}
// Check whether a differential upgrade from base_version is possible
bool OTA_IsDifferentialSupported(const char *base_version) {
    FirmwareInfo_t *current_fw = (FirmwareInfo_t*)(FIRMWARE_ACTIVE_ADDRESS + 0x100);
    char current_version[32];
    snprintf(current_version, sizeof(current_version), "%d.%d.%d",
             current_fw->major, current_fw->minor, current_fw->build);
    return (strcmp(current_version, base_version) == 0);
}
2.6 OTA Main Control Logic
// ota_client.c
#include "ota_client.h"
#include "ota_mqtt.h"
#include "ota_download.h"
#include "ota_verify.h"
#include "ota_differential.h"
#include "http_client.h"
#include "cJSON.h"
#include <string.h>
#include <stdio.h>
static OTA_Config_t ota_config;
static OTA_Info_t ota_info;
static uint32_t last_check_time = 0;
// Initialize the OTA client
void OTA_Init(OTA_Config_t *config) {
    memcpy(&ota_config, config, sizeof(OTA_Config_t));
    memset(&ota_info, 0, sizeof(OTA_Info_t));
    ota_info.state = OTA_STATE_IDLE;
    // Initialize MQTT
    if (!OTA_MQTT_Init(&ota_config, &ota_info)) {
        printf("Failed to initialize MQTT\r\n");
        return;
    }
    printf("OTA client initialized\r\n");
    printf("Device ID: %s\r\n", ota_config.device_id);
    printf("Current version: %s\r\n", ota_config.current_version);
    // Report device status
    OTA_ReportStatus("online", "Device started");
}
// Check for updates
bool OTA_CheckUpdate(void) {
    printf("Checking for updates...\r\n");
    ota_info.state = OTA_STATE_CHECKING;
    // Build the API URL
    char url[512];
    snprintf(url, sizeof(url),
             "%s/api/firmware/check?device_id=%s&version=%s&hardware=%s",
             ota_config.server_url,
             ota_config.device_id,
             ota_config.current_version,
             ota_config.hardware_version);
    // Send the HTTP request
    http_client_t client;
    http_client_init(&client);
    char response[1024];
    int result = http_client_get_string(&client, url, response, sizeof(response));
    http_client_cleanup(&client);
    if (result != 0) {
        printf("Failed to check for updates\r\n");
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }
    // Parse the response
    cJSON *json = cJSON_Parse(response);
    if (!json) {
        printf("Failed to parse response\r\n");
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }
    cJSON *update_available = cJSON_GetObjectItem(json, "update_available");
    if (!cJSON_IsTrue(update_available)) {
        printf("No update available\r\n");
        cJSON_Delete(json);
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }
    // Extract the firmware information; every field is type-checked because
    // valuestring is NULL for items that are not JSON strings
    cJSON *version = cJSON_GetObjectItem(json, "version");
    cJSON *size = cJSON_GetObjectItem(json, "file_size");
    cJSON *crc32 = cJSON_GetObjectItem(json, "crc32");
    cJSON *download_url = cJSON_GetObjectItem(json, "download_url");
    cJSON *is_diff = cJSON_GetObjectItem(json, "is_differential");
    cJSON *base_ver = cJSON_GetObjectItem(json, "base_version");
    if (cJSON_IsString(version)) {
        strncpy(ota_info.target_version, version->valuestring,
                sizeof(ota_info.target_version) - 1);
    }
    if (cJSON_IsNumber(size)) {
        ota_info.firmware_size = (uint32_t)size->valuedouble;
    }
    if (cJSON_IsString(crc32)) {
        // "%x" expects unsigned int*, which is not necessarily uint32_t*
        unsigned int crc_value = 0;
        if (sscanf(crc32->valuestring, "%x", &crc_value) == 1) {
            ota_info.firmware_crc32 = crc_value;
        }
    }
    if (cJSON_IsString(download_url)) {
        strncpy(ota_info.download_url, download_url->valuestring,
                sizeof(ota_info.download_url) - 1);
    }
    ota_info.is_differential = false;
    if (cJSON_IsTrue(is_diff)) {
        ota_info.is_differential = true;
        if (cJSON_IsString(base_ver)) {
            strncpy(ota_info.base_version, base_ver->valuestring,
                    sizeof(ota_info.base_version) - 1);
        }
    }
    cJSON_Delete(json);
    printf("Update available: %s\r\n", ota_info.target_version);
    printf("Size: %lu bytes\r\n", (unsigned long)ota_info.firmware_size);
    printf("Differential: %s\r\n", ota_info.is_differential ? "Yes" : "No");
    ota_info.state = OTA_STATE_IDLE;
    return true;
}
// Start the download
bool OTA_StartDownload(void) {
    if (ota_info.state != OTA_STATE_IDLE) {
        printf("OTA is busy\r\n");
        return false;
    }
    printf("Starting OTA download...\r\n");
    ota_info.state = OTA_STATE_DOWNLOADING;
    ota_info.downloaded_size = 0;
    ota_info.retry_count = 0;
    // Erase the backup bank
    printf("Erasing backup partition...\r\n");
    Flash_EraseSectors(FIRMWARE_BACKUP_ADDRESS, ota_info.firmware_size);
    // Download the firmware
    bool success = OTA_DownloadFirmware(&ota_info);
    if (!success) {
        printf("Download failed\r\n");
        ota_info.state = OTA_STATE_ERROR;
        OTA_ReportStatus("error", "Download failed");
        return false;
    }
    printf("Download complete\r\n");
    ota_info.state = OTA_STATE_VERIFYING;
    return true;
}
// Verify the firmware
bool OTA_VerifyFirmware(void) {
    if (ota_info.state != OTA_STATE_VERIFYING) {
        printf("Invalid state for verification\r\n");
        return false;
    }
    printf("Verifying firmware...\r\n");
    // For a differential package, apply the patch first
    if (ota_info.is_differential) {
        printf("Applying differential patch...\r\n");
        if (!OTA_ApplyDifferentialPatch(ota_info.base_version,
                                        FIRMWARE_BACKUP_ADDRESS,
                                        ota_info.firmware_size)) {
            printf("Failed to apply differential patch\r\n");
            ota_info.state = OTA_STATE_ERROR;
            OTA_ReportStatus("error", "Patch application failed");
            return false;
        }
    }
    // Verify the image. OTA_VerifyFirmwareImage() is the image check from
    // ota_verify.c (section 2.4); it needs a name different from this wrapper,
    // since calling OTA_VerifyFirmware() here would recurse into itself.
    bool success = OTA_VerifyFirmwareImage(&ota_info);
    if (!success) {
        printf("Firmware verification failed\r\n");
        ota_info.state = OTA_STATE_ERROR;
        OTA_ReportStatus("error", "Verification failed");
        return false;
    }
    printf("Firmware verification passed\r\n");
    ota_info.state = OTA_STATE_INSTALLING;
    return true;
}
// Install the firmware
bool OTA_InstallFirmware(void) {
    if (ota_info.state != OTA_STATE_INSTALLING) {
        printf("Invalid state for installation\r\n");
        return false;
    }
    printf("Installing firmware...\r\n");
    // Set the upgrade flag
    BankInfo_t bank_info;
    memset(&bank_info, 0, sizeof(BankInfo_t));
    bank_info.magic = BANK_MAGIC;
    bank_info.state = BANK_STATE_READY;
    bank_info.firmware_size = ota_info.firmware_size;
    bank_info.firmware_crc32 = ota_info.firmware_crc32;
    // Write the bank information; abort if the flag cannot be stored
    if (!Flash_Write(BANK_INFO_ADDRESS, (uint32_t*)&bank_info,
                     sizeof(BankInfo_t) / 4)) {
        printf("Failed to write bank info\r\n");
        ota_info.state = OTA_STATE_ERROR;
        OTA_ReportStatus("error", "Failed to write bank info");
        return false;
    }
    printf("Firmware installed, rebooting...\r\n");
    // Report that the device is about to reboot into the new firmware
    OTA_ReportStatus("installing", "Rebooting to new firmware");
    HAL_Delay(1000);
    // Reset the system
    NVIC_SystemReset();
    return true;
}
// Report progress
void OTA_ReportProgress(uint8_t progress) {
    OTA_MQTT_PublishProgress(progress);
}
// Report status
void OTA_ReportStatus(const char *status, const char *message) {
    OTA_MQTT_PublishStatus(status, message);
}
// OTA task (call periodically)
void OTA_Task(void) {
    // Process MQTT traffic
    OTA_MQTT_Task();
    // Poll for updates at the configured interval
    uint32_t current_time = HAL_GetTick();
    if (ota_info.state == OTA_STATE_IDLE) {
        if (current_time - last_check_time > ota_config.check_interval * 1000) {
            last_check_time = current_time;
            // Start the download when the server reports a newer firmware;
            // without this call the state machine never leaves IDLE
            if (OTA_CheckUpdate()) {
                OTA_StartDownload();
            }
        }
    }
    // State machine
    switch (ota_info.state) {
    case OTA_STATE_DOWNLOADING:
        // The download is driven by the HTTP callback
        break;
    case OTA_STATE_VERIFYING:
        OTA_VerifyFirmware();
        break;
    case OTA_STATE_INSTALLING:
        OTA_InstallFirmware();
        break;
    case OTA_STATE_ERROR:
        // Error handling
        printf("OTA error, resetting state\r\n");
        ota_info.state = OTA_STATE_IDLE;
        break;
    default:
        break;
    }
}
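The control flow of the client is easier to reason about when the state transitions are written down explicitly. The table below is read off the C functions in this section (`OTA_CheckUpdate`, `OTA_StartDownload`, `OTA_VerifyFirmware`, `OTA_InstallFirmware`, `OTA_Task`) and is only a host-side model of them; the class and table names are assumptions of this sketch. `OTA_STATE_COMPLETE` is left out because none of the functions shown here ever sets it.

```python
from enum import Enum, auto


class OtaState(Enum):
    IDLE = auto()
    CHECKING = auto()
    DOWNLOADING = auto()
    VERIFYING = auto()
    INSTALLING = auto()
    ERROR = auto()


# Allowed transitions, as implemented by the C code in this section
ALLOWED = {
    OtaState.IDLE: {OtaState.CHECKING, OtaState.DOWNLOADING},
    OtaState.CHECKING: {OtaState.IDLE},
    OtaState.DOWNLOADING: {OtaState.VERIFYING, OtaState.ERROR},
    OtaState.VERIFYING: {OtaState.INSTALLING, OtaState.ERROR},
    OtaState.INSTALLING: set(),  # the device resets from this state
    OtaState.ERROR: {OtaState.IDLE},
}


class OtaStateMachine:
    """Tracks the current state and rejects transitions not in ALLOWED."""

    def __init__(self):
        self.state = OtaState.IDLE
        self.history = [self.state]

    def go(self, new_state):
        if new_state not in ALLOWED[self.state]:
            raise ValueError(
                f"illegal transition {self.state.name} -> {new_state.name}")
        self.state = new_state
        self.history.append(new_state)


if __name__ == "__main__":
    sm = OtaStateMachine()
    for step in (OtaState.CHECKING, OtaState.IDLE, OtaState.DOWNLOADING,
                 OtaState.VERIFYING, OtaState.INSTALLING):
        sm.go(step)
    print(" -> ".join(s.name for s in sm.history))
```

A table like this can double as a checklist when adding states such as rollback, since every new state needs both incoming and outgoing edges.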
3. Security Mechanisms
3.1 Firmware Signing Tool
# firmware_signer.py
import os

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding


class FirmwareSigner:
    """Firmware signing tool."""

    def __init__(self, private_key_path=None):
        if private_key_path and os.path.exists(private_key_path):
            # Load an existing private key
            with open(private_key_path, 'rb') as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(),
                    password=None,
                    backend=default_backend()
                )
            self.public_key = self.private_key.public_key()
        else:
            # Generate a new key pair
            self.private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend()
            )
            self.public_key = self.private_key.public_key()

    def save_keys(self, private_key_path, public_key_path):
        """Save the key pair."""
        # Save the private key (unencrypted; protect the file accordingly)
        private_pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        with open(private_key_path, 'wb') as f:
            f.write(private_pem)
        # Save the public key
        public_pem = self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        with open(public_key_path, 'wb') as f:
            f.write(public_pem)
        print("Keys saved:")
        print(f"  Private key: {private_key_path}")
        print(f"  Public key: {public_key_path}")

    def sign_firmware(self, firmware_path, signature_path):
        """Sign a firmware image."""
        # Read the firmware
        with open(firmware_path, 'rb') as f:
            firmware_data = f.read()
        # Sign the image. sign() hashes its input with SHA-256 itself, so the
        # raw firmware bytes are passed in; hashing them beforehand would yield
        # a signature over the hash of the hash, which a verifier that hashes
        # the image once would reject.
        signature = self.private_key.sign(
            firmware_data,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        # Save the signature
        with open(signature_path, 'wb') as f:
            f.write(signature)
        print("Firmware signed:")
        print(f"  Firmware: {firmware_path}")
        print(f"  Signature: {signature_path}")
        print(f"  Signature size: {len(signature)} bytes")
        return signature

    def verify_signature(self, firmware_path, signature_path):
        """Verify a firmware signature."""
        # Read the firmware and the signature
        with open(firmware_path, 'rb') as f:
            firmware_data = f.read()
        with open(signature_path, 'rb') as f:
            signature = f.read()
        # Verify the signature over the raw firmware bytes
        try:
            self.public_key.verify(
                signature,
                firmware_data,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            print("Signature verification: PASSED")
            return True
        except InvalidSignature:
            print("Signature verification: FAILED")
            return False


# Usage example
if __name__ == '__main__':
    signer = FirmwareSigner()
    # Save the key pair
    signer.save_keys('private_key.pem', 'public_key.pem')
    # Sign the firmware
    signer.sign_firmware('firmware_v1.0.0.bin', 'firmware_v1.0.0.sig')
    # Verify the signature
    signer.verify_signature('firmware_v1.0.0.bin', 'firmware_v1.0.0.sig')
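3.2 Secure Communication (TLS/SSL)¶
// ota_secure_comm.c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "mbedtls/ssl.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/x509_crt.h"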
typedef struct {
mbedtls_net_context server_fd;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_x509_crt cacert;
} SecureConnection_t;
// 初始化安全连接
bool SecureConnection_Init(SecureConnection_t *conn,
const char *server_addr,
const char *server_port,
const char *ca_cert) {
int ret;
// 初始化结构
mbedtls_net_init(&conn->server_fd);
mbedtls_ssl_init(&conn->ssl);
mbedtls_ssl_config_init(&conn->conf);
mbedtls_entropy_init(&conn->entropy);
mbedtls_ctr_drbg_init(&conn->ctr_drbg);
mbedtls_x509_crt_init(&conn->cacert);
// 初始化随机数生成器
ret = mbedtls_ctr_drbg_seed(&conn->ctr_drbg, mbedtls_entropy_func,
&conn->entropy, NULL, 0);
if (ret != 0) {
printf("mbedtls_ctr_drbg_seed failed: -0x%04x\r\n", -ret);
return false;
}
// 加载CA证书
ret = mbedtls_x509_crt_parse(&conn->cacert,
(const unsigned char*)ca_cert,
strlen(ca_cert) + 1);
if (ret != 0) {
printf("mbedtls_x509_crt_parse failed: -0x%04x\r\n", -ret);
return false;
}
// 连接到服务器
ret = mbedtls_net_connect(&conn->server_fd, server_addr, server_port,
MBEDTLS_NET_PROTO_TCP);
if (ret != 0) {
printf("mbedtls_net_connect failed: -0x%04x\r\n", -ret);
return false;
}
// 配置SSL
ret = mbedtls_ssl_config_defaults(&conn->conf,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT);
if (ret != 0) {
printf("mbedtls_ssl_config_defaults failed: -0x%04x\r\n", -ret);
return false;
}
mbedtls_ssl_conf_authmode(&conn->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
mbedtls_ssl_conf_ca_chain(&conn->conf, &conn->cacert, NULL);
mbedtls_ssl_conf_rng(&conn->conf, mbedtls_ctr_drbg_random, &conn->ctr_drbg);
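ret = mbedtls_ssl_setup(&conn->ssl, &conn->conf);
if (ret != 0) {
printf("mbedtls_ssl_setup failed: -0x%04x\r\n", -ret);
return false;
}
// Bind the expected server name so the certificate is checked against it;
// without this, any certificate issued by the trusted CA would be accepted
ret = mbedtls_ssl_set_hostname(&conn->ssl, server_addr);
if (ret != 0) {
printf("mbedtls_ssl_set_hostname failed: -0x%04x\r\n", -ret);
return false;
}
mbedtls_ssl_set_bio(&conn->ssl, &conn->server_fd,
mbedtls_net_send, mbedtls_net_recv, NULL);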
// SSL握手
while ((ret = mbedtls_ssl_handshake(&conn->ssl)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
printf("mbedtls_ssl_handshake failed: -0x%04x\r\n", -ret);
return false;
}
}
// 验证证书
uint32_t flags = mbedtls_ssl_get_verify_result(&conn->ssl);
if (flags != 0) {
char vrfy_buf[512];
mbedtls_x509_crt_verify_info(vrfy_buf, sizeof(vrfy_buf),
" ! ", flags);
printf("Certificate verification failed:\r\n%s\r\n", vrfy_buf);
return false;
}
printf("SSL connection established\r\n");
return true;
}
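// Send data securely; returns len on success, -1 on error
int SecureConnection_Send(SecureConnection_t *conn,
const uint8_t *data, size_t len) {
size_t sent = 0;
// mbedtls_ssl_write may write fewer bytes than requested, so loop until done
while (sent < len) {
int ret = mbedtls_ssl_write(&conn->ssl, data + sent, len - sent);
if (ret > 0) {
sent += (size_t)ret;
continue;
}
if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
printf("mbedtls_ssl_write failed: -0x%04x\r\n", -ret);
return -1;
}
}
return (int)sent;
}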
// 安全接收数据
int SecureConnection_Receive(SecureConnection_t *conn,
uint8_t *buffer, size_t len) {
int ret;
ret = mbedtls_ssl_read(&conn->ssl, buffer, len);
if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
return 0;
}
if (ret < 0) {
printf("mbedtls_ssl_read failed: -0x%04x\r\n", -ret);
return -1;
}
return ret;
}
// 关闭安全连接
void SecureConnection_Close(SecureConnection_t *conn) {
mbedtls_ssl_close_notify(&conn->ssl);
mbedtls_net_free(&conn->server_fd);
mbedtls_x509_crt_free(&conn->cacert);
mbedtls_ssl_free(&conn->ssl);
mbedtls_ssl_config_free(&conn->conf);
mbedtls_ctr_drbg_free(&conn->ctr_drbg);
mbedtls_entropy_free(&conn->entropy);
}
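For host-side tools that talk to the same server (test clients, upload scripts), an equivalent verified connection can be set up with Python's standard `ssl` module. This is a minimal sketch under assumptions: the function name and the optional `ca_file` argument are illustrative, and the device keeps using the mbedTLS code above.

```python
import ssl


def make_verified_context(ca_file=None):
    """Build a client-side TLS context that requires a valid server certificate.

    ca_file: optional path to a PEM bundle with the CA that issued the server
    certificate; when omitted, the system trust store is used.
    """
    # create_default_context turns on certificate and hostname verification
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_file)
    # Refuse protocol versions older than TLS 1.2
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    return ctx


if __name__ == '__main__':
    ctx = make_verified_context()
    print(ctx.verify_mode == ssl.CERT_REQUIRED, ctx.check_hostname)
```

The two properties printed at the end correspond to `MBEDTLS_SSL_VERIFY_REQUIRED` and to checking the server name on the device side.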
4. Gradual Rollout Implementation¶
4.1 Gradual Rollout Strategy¶
# gradual_rollout.py
import random
import zlib
from datetime import datetime


def send_upgrade_notification(device_id, task_id):
    """Placeholder: in a real deployment, publish the upgrade notice over MQTT."""
    print(f"Notify {device_id}: upgrade task {task_id}")


class GradualRollout:
    """Gradual rollout manager."""

    def __init__(self, task_id, total_devices, strategy='percentage', whitelist=None):
        self.task_id = task_id
        self.total_devices = total_devices
        self.strategy = strategy
        self.whitelist = set(whitelist or [])
        self.current_percentage = 0
        self.deployed_devices = set()

    def should_deploy(self, device_id):
        """Return True if the device belongs to the current rollout stage."""
        if device_id in self.deployed_devices:
            return True
        if self.strategy == 'percentage':
            # Stable bucket in [0, 100). The built-in hash() is salted per
            # process for strings, so it would reshuffle devices on restart.
            bucket = zlib.crc32(str(device_id).encode('utf-8')) % 100
            selected = bucket < self.current_percentage
        elif self.strategy == 'whitelist':
            # Whitelist-based rollout
            selected = device_id in self.whitelist
        elif self.strategy == 'random':
            # Random rollout
            selected = random.random() < (self.current_percentage / 100.0)
        else:
            selected = False
        if selected:
            self.deployed_devices.add(device_id)
        return selected

    def increase_percentage(self, increment):
        """Increase the rollout percentage (capped at 100)."""
        self.current_percentage = min(100, self.current_percentage + increment)
        print(f"Gradual rollout percentage increased to {self.current_percentage}%")

    def get_statistics(self):
        """Return rollout statistics."""
        deployed = len(self.deployed_devices)
        rate = deployed / self.total_devices * 100 if self.total_devices else 0
        return {
            'task_id': self.task_id,
            'total_devices': self.total_devices,
            'deployed_devices': deployed,
            'current_percentage': self.current_percentage,
            'deployment_rate': rate,
        }


class RolloutScheduler:
    """Gradual rollout scheduler."""

    def __init__(self):
        self.active_rollouts = {}

    def create_rollout(self, task_id, device_list, schedule):
        """Create a gradual rollout task."""
        rollout = GradualRollout(task_id, len(device_list))
        # schedule is a list of (hours_after_start, percentage),
        # for example [(1, 10), (24, 50), (48, 100)]
        rollout.schedule = schedule
        rollout.start_time = datetime.now()
        rollout.device_list = device_list
        self.active_rollouts[task_id] = rollout
        print(f"Created gradual rollout for task {task_id}")
        print(f"Total devices: {len(device_list)}")
        print(f"Schedule: {schedule}")

    def update_rollouts(self):
        """Advance every active rollout according to its schedule."""
        current_time = datetime.now()
        for task_id, rollout in self.active_rollouts.items():
            elapsed_hours = (current_time - rollout.start_time).total_seconds() / 3600
            # Highest percentage among the stages that are already due
            due = [pct for hours, pct in rollout.schedule if elapsed_hours >= hours]
            target = max(due, default=0)
            if target > rollout.current_percentage:
                rollout.current_percentage = target
                print(f"Task {task_id}: Updated to {target}%")
                # Notify the devices added by this stage
                self.notify_devices(rollout)

    def notify_devices(self, rollout):
        """Notify newly selected devices to upgrade."""
        for device_id in rollout.device_list:
            if device_id in rollout.deployed_devices:
                continue  # already notified in an earlier stage
            if rollout.should_deploy(device_id):
                send_upgrade_notification(device_id, rollout.task_id)
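The schedule format used above, a list of `(hours_after_start, percentage)` pairs such as `[(1, 10), (24, 50), (48, 100)]`, can be evaluated with a small pure function. This is a sketch for illustration; the name `target_percentage` is an assumption and not part of the scheduler class.

```python
def target_percentage(schedule, elapsed_hours):
    """Return the rollout percentage that should be active after elapsed_hours.

    schedule: list of (hours_after_start, percentage) pairs.
    """
    # Collect the percentages of every stage whose start time has passed
    due = [pct for hours, pct in schedule if elapsed_hours >= hours]
    # Before the first stage starts nothing is deployed
    return max(due, default=0)


if __name__ == '__main__':
    schedule = [(1, 10), (24, 50), (48, 100)]
    for hours in (0.5, 1, 30, 72):
        print(hours, target_percentage(schedule, hours))
```

Because the function has no side effects, the schedule logic can be tested with fixed inputs instead of waiting for real time to pass.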
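Practical Examples¶
Example 1: Complete OTA Upgrade Flow¶
# ota_demo.py
"""Complete OTA upgrade demo."""
import requests

from firmware_signer import FirmwareSigner
from gradual_rollout import RolloutScheduler
# DifferentialPackageGenerator is the differential package tool from this
# project; the module name used here is an assumption.
from differential_package import DifferentialPackageGenerator


def get_all_devices():
    """Placeholder: replace with a query against the device database."""
    return ['DEVICE_001', 'DEVICE_002', 'DEVICE_003']


# 1. The developer uploads new firmware
def upload_new_firmware():
    signer = FirmwareSigner('private_key.pem')
    # Sign the firmware
    signer.sign_firmware('firmware_v2.0.0.bin', 'firmware_v2.0.0.sig')
    # Read the firmware and its signature
    with open('firmware_v2.0.0.bin', 'rb') as f:
        firmware_data = f.read()
    with open('firmware_v2.0.0.sig', 'rb') as f:
        signature_data = f.read()
    # Upload through the API
    response = requests.post(
        'http://ota-server.com/api/firmware/upload',
        files={
            'file': firmware_data,
            'signature': signature_data,
        },
        data={
            'version': '2.0.0',
            'hardware_version': 'STM32F407',
            'release_notes': 'Bug fixes and new features',
        },
        timeout=30,
    )
    response.raise_for_status()
    print(f"Firmware uploaded: {response.json()}")


# 2. Generate the differential package
def generate_differential_package():
    generator = DifferentialPackageGenerator()
    result = generator.generate_diff(
        'firmware_v1.0.0.bin',
        'firmware_v2.0.0.bin',
        'diff_v1.0.0_to_v2.0.0.patch',
    )
    print("Differential package generated")
    print(f"Size reduction: {result['compression_ratio']:.2f}%")


# 3. Create the gradual rollout task
def create_gradual_rollout_task():
    # Fetch all devices
    devices = get_all_devices()
    # Create the rollout
    scheduler = RolloutScheduler()
    scheduler.create_rollout(
        task_id=123,
        device_list=devices,
        schedule=[
            (1, 10),    # 10% after 1 hour
            (24, 50),   # 50% after 24 hours
            (48, 100),  # full deployment after 48 hours
        ],
    )


# 4. Device-side upgrade. This step runs on the MCU and is written in C,
#    so it is shown as a comment here (see main.c in the device integration
#    section):
#
#    OTA_Config_t config = {
#        .device_id = "DEVICE_001",
#        .current_version = "1.0.0",
#        .hardware_version = "STM32F407",
#        .server_url = "http://ota-server.com",
#        .mqtt_broker = "mqtt.ota-server.com",
#        .mqtt_port = 1883,
#        .check_interval = 3600   /* check once per hour */
#    };
#    OTA_Init(&config);
#    while (1) {
#        OTA_Task();
#        HAL_Delay(100);
#    }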
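Example 2: Error Handling and Recovery¶
// ota_error_handling.c
/* OTA error handling example.
 * OTA_Info_t, BankInfo_t, the BANK_* constants and the IWDG_* helpers are
 * assumed to be declared by the project headers. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "ota_client.h"

#define OTA_MAX_BACKOFF_MS 60000u  /* illustrative upper bound for the delay */

extern volatile bool ota_in_progress;

// Retry a failed download with exponential backoff
bool OTA_DownloadWithRetry(OTA_Info_t *info, uint8_t max_retry) {
    for (uint8_t attempt = 1; attempt <= max_retry; attempt++) {
        printf("Download attempt %u/%u\r\n", (unsigned)attempt, (unsigned)max_retry);
        if (OTA_DownloadFirmware(info)) {
            return true;
        }
        if (attempt == max_retry) {
            break;  // no point in waiting after the last attempt
        }
        // Exponential backoff: 2 s, 4 s, 8 s, ... capped so the shift cannot overflow
        uint32_t delay = (attempt < 6) ? (1000u << attempt) : OTA_MAX_BACKOFF_MS;
        printf("Retrying in %lu ms\r\n", (unsigned long)delay);
        HAL_Delay(delay);
    }
    return false;
}

// Roll back after a failed upgrade
void OTA_RollbackOnFailure(void) {
    printf("Upgrade failed, rolling back...\r\n");
    // NOTE: if the bank info records live in internal Flash, a plain store
    // does not change them; update them through the project's Flash write
    // routine (unlock, erase/program, lock) instead.
    // 1. Mark the new firmware as invalid
    BankInfo_t *bank_b = (BankInfo_t*)BANK_B_INFO_ADDRESS;
    bank_b->state = BANK_STATE_INVALID;
    // 2. Make sure the old firmware is marked active
    BankInfo_t *bank_a = (BankInfo_t*)BANK_A_INFO_ADDRESS;
    bank_a->state = BANK_STATE_ACTIVE;
    // 3. Reboot into the old firmware
    NVIC_SystemReset();
}

// Watchdog protection
void OTA_WatchdogProtection(void) {
    // Start the watchdog with a 10-second timeout
    IWDG_Init(10000);
    // Feed the watchdog periodically during the upgrade. In a real system,
    // feed it from the code that makes progress, so that a stalled upgrade
    // still triggers a reset.
    while (ota_in_progress) {
        IWDG_Feed();
        HAL_Delay(1000);
    }
}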
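The backoff timing used by the retry routine is easy to check on the host. The sketch below reproduces the 2 s, 4 s, 8 s sequence with an upper bound and adds optional jitter; the function names, the 30-second cap and the 20% jitter are assumptions chosen for illustration.

```python
import random


def backoff_delays_ms(max_retry, base_ms=1000, cap_ms=30000):
    """Waits (in ms) between attempts: base * 2^n for n = 1, 2, ..., capped.

    With max_retry attempts there are max_retry - 1 waits, because nothing
    needs to be awaited after the final attempt.
    """
    return [min(cap_ms, base_ms * 2 ** n) for n in range(1, max_retry)]


def with_jitter(delay_ms, fraction=0.2, rng=random.random):
    """Spread a delay by up to +/- fraction so devices do not retry in lockstep."""
    return int(delay_ms * (1 + fraction * (2 * rng() - 1)))


if __name__ == '__main__':
    for delay in backoff_delays_ms(5):
        print(delay, with_jitter(delay))
```

Jitter matters mostly for large fleets: if thousands of devices lose the connection at the same moment, identical delays would make them all hit the server again at the same moment.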
Deployment Guide¶
1. Cloud Platform Deployment¶
1.1 Deploying with Docker¶
# docker-compose.yml
# The credentials below are development defaults; replace them before any
# real deployment.
version: '3.8'
services:
  # PostgreSQL database
  postgres:
    image: postgres:13
    environment:
      POSTGRES_DB: ota_db
      POSTGRES_USER: ota_user
      POSTGRES_PASSWORD: ota_password
    volumes:
      - postgres_data:/var/lib/postgresql/data
    ports:
      - "5432:5432"
  # MQTT Broker
  mosquitto:
    image: eclipse-mosquitto:2
    volumes:
      - ./mosquitto.conf:/mosquitto/config/mosquitto.conf
      - mosquitto_data:/mosquitto/data
      - mosquitto_log:/mosquitto/log
    ports:
      - "1883:1883"
      - "9001:9001"
  # MinIO object storage
  minio:
    image: minio/minio
    command: server /data --console-address ":9001"
    environment:
      MINIO_ROOT_USER: minioadmin
      MINIO_ROOT_PASSWORD: minioadmin
    volumes:
      - minio_data:/data
    ports:
      - "9000:9000"
      # Console on host port 9002; host port 9001 is already taken by Mosquitto
      - "9002:9001"
  # OTA API service
  ota-api:
    build: ./api
    environment:
      DATABASE_URL: postgresql://ota_user:ota_password@postgres:5432/ota_db
      MQTT_BROKER: mosquitto
      MINIO_ENDPOINT: minio:9000
    depends_on:
      - postgres
      - mosquitto
      - minio
    ports:
      - "5000:5000"
  # Web management UI
  ota-web:
    build: ./web
    environment:
      API_URL: http://ota-api:5000
    ports:
      - "8080:80"
    depends_on:
      - ota-api
volumes:
  postgres_data:
  mosquitto_data:
  mosquitto_log:
  minio_data:
1.2 Starting the Services¶
# Build and start all services
docker-compose up -d
# Check service status
docker-compose ps
# Follow the API logs
docker-compose logs -f ota-api
# Stop the services
docker-compose down
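2. Device-Side Integration¶
2.1 Adding the OTA Library to the Project¶
# Makefile
# OTA source files
OTA_SOURCES = \
    ota/ota_client.c \
    ota/ota_mqtt.c \
    ota/ota_download.c \
    ota/ota_verify.c \
    ota/ota_differential.c
# Include path
INCLUDES += -Iota/include
# Link mbedTLS (with static libraries the order matters:
# mbedtls depends on mbedx509, which depends on mbedcrypto)
LIBS += -lmbedtls -lmbedx509 -lmbedcrypto
# Compile the OTA library (the recipe line must start with a tab)
$(BUILD_DIR)/ota/%.o: ota/%.c
	$(CC) -c $(CFLAGS) $(INCLUDES) $< -o $@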
2.2 Application Integration¶
// main.c
#include <stdio.h>
#include "ota_client.h"
int main(void) {
// 系统初始化
HAL_Init();
SystemClock_Config();
// 初始化外设
UART_Init();
Network_Init();
// 配置OTA
OTA_Config_t ota_config = {
.device_id = "DEVICE_001",
.current_version = "1.0.0",
.hardware_version = "STM32F407",
.server_url = "https://ota.example.com",
.mqtt_broker = "mqtt.example.com",
.mqtt_port = 1883,
.check_interval = 3600
};
// 初始化OTA客户端
OTA_Init(&ota_config);
printf("Application started\r\n");
printf("Version: %s\r\n", ota_config.current_version);
// 主循环
while (1) {
// 应用程序逻辑
Application_Task();
// OTA任务
OTA_Task();
HAL_Delay(100);
}
}
Testing and Verification¶
1. Unit Tests¶
# test_ota_api.py
import unittest

import requests


class TestOTAAPI(unittest.TestCase):
    """OTA API tests (they expect the API service to run on localhost:5000)."""

    def setUp(self):
        self.base_url = "http://localhost:5000/api"

    def test_firmware_upload(self):
        """Test firmware upload."""
        with open('test_firmware.bin', 'rb') as f:
            response = requests.post(
                f"{self.base_url}/firmware/upload",
                files={'file': f},
                data={
                    'version': '1.0.0',
                    'hardware_version': 'TEST',
                },
                timeout=10,
            )
        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.json()['success'])

    def test_check_update(self):
        """Test the update check."""
        response = requests.get(
            f"{self.base_url}/firmware/check",
            params={
                'device_id': 'TEST_001',
                'version': '1.0.0',
                'hardware': 'TEST',
            },
            timeout=10,
        )
        self.assertEqual(response.status_code, 200)

    def test_download_firmware(self):
        """Test firmware download with a Range request."""
        response = requests.get(
            f"{self.base_url}/firmware/download/1",
            headers={'Range': 'bytes=0-1023'},
            timeout=10,
        )
        self.assertEqual(response.status_code, 206)  # Partial Content
        self.assertLessEqual(len(response.content), 1024)


if __name__ == '__main__':
    unittest.main()
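2. Integration Tests¶
// test_ota_integration.c
/* OTA integration tests */
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include "ota_client.h"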
void Test_OTA_FullUpgrade(void) {
printf("=== OTA Full Upgrade Test ===\r\n");
// 1. 初始化
OTA_Config_t config = {
.device_id = "TEST_001",
.current_version = "1.0.0",
.hardware_version = "TEST",
.server_url = "http://localhost:5000",
.mqtt_broker = "localhost",
.mqtt_port = 1883,
.check_interval = 60
};
OTA_Init(&config);
// 2. 检查更新
printf("Checking for updates...\r\n");
assert(OTA_CheckUpdate() == true);
// 3. 下载固件
printf("Downloading firmware...\r\n");
assert(OTA_StartDownload() == true);
// 4. 验证固件
printf("Verifying firmware...\r\n");
assert(OTA_VerifyFirmware() == true);
// 5. 安装固件
printf("Installing firmware...\r\n");
// OTA_InstallFirmware(); // 会重启系统
printf("=== Test Passed ===\r\n");
}
void Test_OTA_DifferentialUpgrade(void) {
printf("=== OTA Differential Upgrade Test ===\r\n");
// 测试差分升级
// ...
printf("=== Test Passed ===\r\n");
}
void Test_OTA_ResumeDownload(void) {
printf("=== OTA Resume Download Test ===\r\n");
// 测试断点续传
// ...
printf("=== Test Passed ===\r\n");
}
3. Performance Tests¶
# test_ota_performance.py
"""OTA performance tests."""
import statistics
import time

# download_firmware, check_update, verify_firmware, install_firmware and
# DifferentialPackageGenerator are assumed to be provided by the project's
# host-side test client and tooling; import them from wherever they live.


def test_download_speed():
    """Measure download speed."""
    firmware_size = 256 * 1024  # size of test_firmware.bin: 256 KB
    speeds = []
    for i in range(10):
        start_time = time.perf_counter()
        # Download the firmware
        download_firmware('test_firmware.bin')
        elapsed = time.perf_counter() - start_time
        speed = firmware_size / elapsed / 1024  # KB/s
        speeds.append(speed)
        print(f"Test {i+1}: {speed:.2f} KB/s")
    print(f"\nAverage speed: {statistics.mean(speeds):.2f} KB/s")
    print(f"Min speed: {min(speeds):.2f} KB/s")
    print(f"Max speed: {max(speeds):.2f} KB/s")


def test_upgrade_time():
    """Measure the duration of a complete upgrade."""
    start_time = time.perf_counter()
    # Run the complete upgrade flow
    check_update()
    download_firmware('test_firmware.bin')
    verify_firmware()
    install_firmware()
    total_time = time.perf_counter() - start_time
    print(f"Total upgrade time: {total_time:.2f} seconds")


def test_differential_compression():
    """Measure the differential compression ratio."""
    old_firmware = 'firmware_v1.0.0.bin'
    new_firmware = 'firmware_v1.1.0.bin'
    generator = DifferentialPackageGenerator()
    result = generator.generate_diff(old_firmware, new_firmware, 'diff.patch')
    print(f"Old firmware: {result['old_size']} bytes")
    print(f"New firmware: {result['new_size']} bytes")
    print(f"Differential package: {result['diff_size']} bytes")
    print(f"Compression ratio: {result['compression_ratio']:.2f}%")
Monitoring and Operations¶
1. Monitoring Metrics¶
# ota_monitoring.py
"""OTA system monitoring."""
import logging

logger = logging.getLogger('ota')

# Device and UpgradeRecord are the cloud platform's ORM models. send_alert()
# and the get_* / calculate_* helpers used below are assumed to be implemented
# elsewhere in the project.


class OTAMonitoring:
    """OTA monitoring system."""

    def get_system_metrics(self):
        """Return system-wide metrics."""
        return {
            'total_devices': self.get_total_devices(),
            'online_devices': self.get_online_devices(),
            'upgrading_devices': self.get_upgrading_devices(),
            'failed_upgrades': self.get_failed_upgrades(),
            'success_rate': self.calculate_success_rate(),
            'average_upgrade_time': self.get_average_upgrade_time(),
        }

    def get_device_status(self, device_id):
        """Return the status of one device, or None if it is unknown."""
        device = Device.query.filter_by(device_id=device_id).first()
        if not device:
            return None
        return {
            'device_id': device.device_id,
            'current_version': device.current_version,
            'status': device.status,
            'last_online': device.last_online,
            'upgrade_history': self.get_upgrade_history(device_id),
        }

    def get_upgrade_statistics(self, task_id):
        """Return statistics for one upgrade task."""
        records = UpgradeRecord.query.filter_by(task_id=task_id).all()
        total = len(records)
        success = sum(1 for r in records if r.status == 'success')
        failed = sum(1 for r in records if r.status == 'failed')
        in_progress = sum(
            1 for r in records if r.status in ('downloading', 'installing')
        )
        return {
            'task_id': task_id,
            'total_devices': total,
            'success': success,
            'failed': failed,
            'in_progress': in_progress,
            'success_rate': success / total * 100 if total > 0 else 0,
        }

    def alert_on_failure(self, device_id, error_message):
        """Raise an alert when an upgrade fails."""
        # Send the alert notification
        send_alert(
            title=f"OTA Upgrade Failed: {device_id}",
            message=error_message,
            severity='high',
        )
        # Write a log entry
        logger.error("Device %s upgrade failed: %s", device_id, error_message)
2. Log Management¶
# ota_logging.py
"""OTA log management."""
import logging
import os
from logging.handlers import RotatingFileHandler


def setup_logging(log_dir='logs'):
    """Configure the OTA logger."""
    # Create the logger
    logger = logging.getLogger('ota')
    logger.setLevel(logging.DEBUG)
    if logger.handlers:
        return logger  # already configured; avoid duplicate log lines
    # RotatingFileHandler does not create missing directories
    os.makedirs(log_dir, exist_ok=True)
    # File handler with automatic rotation
    file_handler = RotatingFileHandler(
        os.path.join(log_dir, 'ota.log'),
        maxBytes=10 * 1024 * 1024,  # 10 MB
        backupCount=10,
    )
    file_handler.setLevel(logging.DEBUG)
    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    # Formatting
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)
    # Attach the handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger


# Using the logger
logger = setup_logging()
logger.info("OTA system started")
logger.debug("Device connected: DEVICE_001")
logger.warning("Download retry: attempt 2/3")
logger.error("Firmware verification failed")
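Frequently Asked Questions¶
Q1: How do I handle download failures caused by an unstable network?¶
A: Implement resumable downloads and a retry mechanism:
- Resumable download: use HTTP Range requests to continue from the position where the last transfer stopped
- Retry strategy: retry with exponential backoff to avoid hammering the server
- Chunked download: split large files into small chunks to limit the impact of a single failure
- Verification: check every chunk to ensure data integrity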
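The resume and per-chunk verification steps above can be sketched as follows. This is an illustration under assumptions: the 4 KB chunk size, the function names, and the idea that the server publishes a CRC-32 per chunk are not taken from the project code.

```python
import zlib

CHUNK_SIZE = 4096  # assumed chunk size in bytes


def next_range_header(bytes_received, total_size, chunk_size=CHUNK_SIZE):
    """Build the HTTP Range header for the next chunk, or None when finished."""
    if bytes_received >= total_size:
        return None
    # Range offsets are inclusive, hence the "- 1"
    last = min(bytes_received + chunk_size, total_size) - 1
    return {'Range': f'bytes={bytes_received}-{last}'}


def chunk_is_valid(chunk, expected_crc32):
    """Check a downloaded chunk against the CRC-32 announced for it."""
    return (zlib.crc32(chunk) & 0xFFFFFFFF) == expected_crc32


if __name__ == '__main__':
    # Resume a 10000-byte download after 8192 bytes were already stored
    print(next_range_header(8192, 10000))
```

The device only needs to persist `bytes_received` (for example next to the download buffer in Flash) to be able to resume after a reset.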
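Q2: How much traffic can differential upgrades save?¶
A: It depends on how much the firmware has changed. As rough guidelines:
- Small changes (bug fixes): about 80-95% less traffic
- Medium changes (feature updates): about 50-80% less traffic
- Large changes (refactoring): about 20-50% less traffic
Recommendations:
- Use differential upgrades for frequent, small updates
- Use the full firmware image for major version updates
- Offer both and let the device choose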
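The saving can be computed directly from the two package sizes, and the same number can drive the choice between the patch and the full image. The sketch below is illustrative; the function names and the 30% threshold are assumptions.

```python
def traffic_saved_percent(full_size, diff_size):
    """Percentage of download traffic saved by sending a patch instead of the full image."""
    if full_size <= 0:
        raise ValueError('full_size must be positive')
    return (1 - diff_size / full_size) * 100


def choose_package(full_size, diff_size, min_saving_percent=30.0):
    """Pick 'diff' only when the patch saves at least min_saving_percent."""
    saved = traffic_saved_percent(full_size, diff_size)
    return 'diff' if saved >= min_saving_percent else 'full'


if __name__ == '__main__':
    # A 256 KB image with a 32 KB patch
    print(traffic_saved_percent(256 * 1024, 32 * 1024))
    print(choose_package(256 * 1024, 32 * 1024))
```

A threshold is useful because applying a patch costs the device extra RAM and CPU time, which is only worthwhile when the download is noticeably smaller.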
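Q3: How do I keep OTA upgrades secure?¶
A: Use several layers of protection:
- Transport security: encrypt communication with TLS/SSL
- Firmware signing: sign firmware with RSA/ECDSA digital signatures
- Certificate validation: verify the server certificate
- Firmware encryption: optionally encrypt the firmware image
- Version control: prevent downgrade attacks
- Secure boot: have the Bootloader verify the firmware signature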
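The downgrade protection mentioned above comes down to comparing version numbers numerically before accepting an image. A minimal sketch, assuming versions use the `MAJOR.MINOR.PATCH` form seen elsewhere in this project (for example `1.0.0`); the function names are illustrative.

```python
def parse_version(version):
    """Parse 'MAJOR.MINOR.PATCH' into a tuple of integers."""
    parts = version.strip().split('.')
    if len(parts) != 3:
        raise ValueError(f'unexpected version format: {version!r}')
    return tuple(int(part) for part in parts)


def is_upgrade_allowed(current, candidate):
    """Accept only firmware that is strictly newer than the running version."""
    return parse_version(candidate) > parse_version(current)


if __name__ == '__main__':
    print(is_upgrade_allowed('1.9.0', '1.10.0'))  # numeric, not string, comparison
    print(is_upgrade_allowed('2.0.0', '1.9.9'))
```

For the check to be meaningful, the version must be covered by the firmware signature; a version taken from an unsigned header could simply be rewritten by an attacker.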
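Q4: What should I do when the upgrade failure rate is high?¶
A: Analyze the causes of failure and optimize accordingly:
- Network issues:
  - Increase the retry count
  - Tune the timeout settings
  - Use a CDN for acceleration
- Firmware issues:
  - Strengthen testing
  - Use gradual rollout
  - Roll back quickly
- Device issues:
  - Check the available storage space
  - Verify hardware compatibility
  - Optimize memory usage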
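Q5: How do I run OTA upgrades for a large fleet of devices?¶
A: Upgrade in batches:
- Gradual rollout: expand the upgrade scope step by step
- Group management: group devices by region and version
- Load balancing: use a CDN and load balancers
- Rate limiting: cap the number of devices upgrading at the same time
- Monitoring and alerting: track upgrade progress and failure rate in real time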
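Rate limiting and failure monitoring can be combined into a simple wave plan: upgrade a bounded batch, look at the results, and continue only while the failure rate stays acceptable. The sketch below is illustrative; the function names and the 5% threshold are assumptions.

```python
def plan_waves(device_ids, max_concurrent):
    """Split devices into consecutive waves of at most max_concurrent devices."""
    if max_concurrent <= 0:
        raise ValueError('max_concurrent must be positive')
    return [
        device_ids[i:i + max_concurrent]
        for i in range(0, len(device_ids), max_concurrent)
    ]


def should_continue(success, failed, max_failure_rate=0.05):
    """Decide whether the next wave may start, based on results so far."""
    total = success + failed
    if total == 0:
        return True  # nothing has been reported yet
    return failed / total <= max_failure_rate


if __name__ == '__main__':
    waves = plan_waves([f'DEVICE_{n:03d}' for n in range(1, 8)], 3)
    for number, wave in enumerate(waves, start=1):
        print(number, wave)
    print(should_continue(success=90, failed=10))
```

Pausing automatically when the threshold is exceeded limits the damage of a bad release to the devices in the current wave.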
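Project Summary¶
In this project you built a complete, enterprise-grade OTA upgrade system, including:
Core results:
- ✅ Cloud management platform (firmware management, device management, task scheduling)
- ✅ Device-side OTA client (download, verification, installation)
- ✅ Differential upgrades (less traffic)
- ✅ Resumable downloads (better reliability)
- ✅ Security mechanisms (signing, encryption, TLS)
- ✅ Gradual rollout strategy (lower risk)
- ✅ Monitoring and operations tools
Skills gained:
- The complete architecture of an OTA system
- How the cloud and the device work together
- Differential algorithms and resumable downloads
- Secure communication and firmware signing
- Gradual rollout, monitoring and operations
Practical value:
- Can be applied directly to commercial projects
- Supports managing devices at scale
- Provides end-to-end security
- Offers enterprise-grade reliability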
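Further Learning¶
Topics recommended for further study:
Going deeper:
- Secure Boot in detail
- Firmware encryption and protection in practice
- Dual-bank upgrades and A/B partition strategies
Related areas:
- Cloud-native application development (Kubernetes, microservices)
- IoT platform architecture (AWS IoT, Azure IoT)
- Edge computing and fog computing
- Device management protocols (LwM2M, TR-069)
Open-source projects:
- Eclipse hawkBit: open-source OTA update server
- Mender: open-source OTA update solution
- SWUpdate: update framework for embedded Linux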
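Congratulations on completing this complex project! 🎉
You now have the skills needed to build an enterprise-grade OTA upgrade system. The system can be applied to real projects and gives your product a reliable remote upgrade capability.
Suggested next steps:
1. Deploy the system to a real environment and test it
2. Customize and optimize it for your actual requirements
3. Add more advanced features (such as A/B testing and automatic rollback)
4. Share your experience to help other developers
Hands-on exercises:
1. Implement an OTA system that supports several transports (WiFi, 4G, LoRa)
2. Add dependency management between firmware versions
3. Implement device grouping and batch upgrades
4. Develop a mobile management app
5. Integrate the system into an existing IoT platform
Best of luck on your embedded development journey! 💪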