
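Complete OTA Upgrade System Design

Project Overview

OTA (Over-The-Air) upgrades are a core feature of modern IoT devices. They let a device update its firmware remotely over the network, with no physical access. This project walks you through building an enterprise-grade OTA upgrade system from scratch. It covers the complete solution, from cloud-side management to the device-side implementation.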

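Project Goals

Build a complete, secure, and reliable OTA upgrade system, consisting of:

Cloud management platform:
- Firmware version management and release
- Device management and grouping
- Upgrade task scheduling and monitoring
- Upgrade statistics and reports
- Differential package generation and management

Device-side features:
- Firmware download and verification
- Resumable downloads
- Dual-bank atomic upgrades
- Secure boot verification
- Upgrade status reporting

Core features:
- MQTT/HTTP protocol support
- Differential upgrades to save bandwidth
- Resumable downloads for reliability
- Digital signatures for security
- Staged (gray) rollouts to reduce risk
- Complete logging and monitoring

Tech Stack

Cloud:
- Backend: Python/Flask or Node.js/Express
- Database: PostgreSQL/MySQL
- Message queue: MQTT broker (Mosquitto/EMQ X)
- Object storage: MinIO/AWS S3
- Front end: Vue.js/React

Device:
- MCU: STM32F4 series (portable to other platforms)
- Communication: MQTT/HTTP
- Cryptography: mbedTLS
- Differential algorithm: bsdiff/courgette
- Storage: dual-bank Flash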

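Project Deliverables

After completing this project, you will have:

  1. Cloud management system
     - Firmware management web UI
     - RESTful API service
     - MQTT messaging service
     - Database design and implementation
  2. Device-side code
     - Bootloader
     - OTA client library
     - Sample application
     - Test tools
  3. Documentation and tools
     - System architecture document
     - API reference
     - Deployment guide
     - Test cases
  4. Complete examples
     - End-to-end upgrade demo
     - Error-handling examples
     - Performance test report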

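Prerequisites

Before starting this project, you should have:

Required skills:
- Solid C programming
- Understanding of bootloaders and IAP (In-Application Programming)
- Familiarity with network protocols (TCP/IP, HTTP, MQTT)
- Experience with Flash memory operations
- Basic Linux usage

Recommended skills:
- Backend development with Python or Node.js
- Front-end basics (HTML/CSS/JavaScript)
- Database design and SQL
- Cryptography and security fundamentals
- Docker containers

Hardware:
- STM32F4 development board (or a similar MCU)
- Network module (WiFi/4G/Ethernet)
- Debugger (ST-Link/J-Link)
- PC development environment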

System Architecture Design

Overall Architecture

graph TB
    subgraph "Cloud Platform"
        A[Web Admin UI] --> B[API Server]
        B --> C[Database]
        B --> D[Object Storage]
        B --> E[MQTT Broker]
    end

    subgraph "Device"
        F[Application] --> G[OTA Client]
        G --> H[Bootloader]
        H --> I[Flash Storage]
    end

    E <--> |MQTT| G
    D <--> |HTTP| G

    subgraph "Development Tools"
        J[Firmware Build Tool]
        K[Diff Package Generator]
        L[Signing Tool]
    end

    J --> D
    K --> D
    L --> D

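Core Modules

1. Cloud Management Platform

Cloud Platform
├── Firmware management module
│   ├── Firmware upload and storage
│   ├── Version management
│   ├── Diff package generation
│   └── Firmware signing
├── Device management module
│   ├── Device registration and authentication
│   ├── Device grouping
│   ├── Device status monitoring
│   └── Device information queries
├── Upgrade task module
│   ├── Task creation and scheduling
│   ├── Staged (gray) rollout
│   ├── Task monitoring
│   └── Task statistics
└── Communication module
    ├── MQTT messaging service
    ├── HTTP file service
    └── WebSocket real-time communication

2. Device-Side OTA Client

OTA Client
├── Communication layer
│   ├── MQTT client
│   ├── HTTP client
│   └── Resumable download
├── Download management
│   ├── Firmware download
│   ├── Progress tracking
│   ├── Cache management
│   └── Error retry
├── Verification layer
│   ├── Digital signature verification
│   ├── CRC check
│   ├── Version check
│   └── Compatibility check
└── Upgrade execution
    ├── Flash operations
    ├── Dual-bank switching
    ├── Status reporting
    └── Rollback handling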

Data Flow

Upgrade Flow

sequenceDiagram
    participant Dev as Developer
    participant Cloud as Cloud Platform
    participant Device as Device
    participant Boot as Bootloader

    Dev->>Cloud: 1. Upload new firmware
    Cloud->>Cloud: 2. Generate diff package
    Cloud->>Cloud: 3. Sign firmware
    Cloud->>Cloud: 4. Create upgrade task

    Cloud->>Device: 5. Push upgrade notification
    Device->>Cloud: 6. Query firmware info
    Cloud-->>Device: 7. Return download URL

    Device->>Cloud: 8. Download firmware (resumable)
    Device->>Device: 9. Write to backup partition
    Device->>Device: 10. Verify signature and CRC
    Device->>Cloud: 11. Report download complete

    Device->>Device: 12. Set upgrade flag
    Device->>Device: 13. Reboot

    Boot->>Boot: 14. Verify new firmware
    Boot->>Boot: 15. Switch to new partition
    Boot->>Device: 16. Boot new firmware

    Device->>Cloud: 17. Report upgrade success
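The device-side half of this sequence maps onto the `OTA_State_t` states declared in 2.1. The sketch below runs on the host, not in the firmware, and writes the transitions down as a table that can be checked. CHECKING -> IDLE and DOWNLOADING -> VERIFYING/ERROR follow the C code in 2.6. The other edges are assumptions about steps the document does not show, and `OTAStateMachine` is a hypothetical helper.

```python
from enum import Enum, auto


class OTAState(Enum):
    """Host-side mirror of OTA_State_t in ota_client.h."""
    IDLE = auto()
    CHECKING = auto()
    DOWNLOADING = auto()
    VERIFYING = auto()
    INSTALLING = auto()
    COMPLETE = auto()
    ERROR = auto()


# Allowed transitions. CHECKING -> IDLE and DOWNLOADING -> VERIFYING/ERROR follow
# the C code in 2.6; the remaining edges are assumptions about the unshown steps.
TRANSITIONS = {
    OTAState.IDLE: {OTAState.CHECKING, OTAState.DOWNLOADING},
    OTAState.CHECKING: {OTAState.IDLE},
    OTAState.DOWNLOADING: {OTAState.VERIFYING, OTAState.ERROR},
    OTAState.VERIFYING: {OTAState.INSTALLING, OTAState.ERROR},
    OTAState.INSTALLING: {OTAState.COMPLETE, OTAState.ERROR},
    OTAState.COMPLETE: {OTAState.IDLE},
    OTAState.ERROR: {OTAState.IDLE},
}


class OTAStateMachine:
    def __init__(self):
        self.state = OTAState.IDLE
        self.history = [self.state]

    def can_go(self, new_state: OTAState) -> bool:
        return new_state in TRANSITIONS[self.state]

    def go(self, new_state: OTAState) -> "OTAStateMachine":
        """Move to new_state, rejecting transitions the table does not allow."""
        if not self.can_go(new_state):
            raise ValueError(f"illegal transition {self.state.name} -> {new_state.name}")
        self.state = new_state
        self.history.append(new_state)
        return self


# Happy path of the sequence diagram: check, download, verify, install.
happy = (OTAStateMachine()
         .go(OTAState.CHECKING).go(OTAState.IDLE)
         .go(OTAState.DOWNLOADING).go(OTAState.VERIFYING)
         .go(OTAState.INSTALLING).go(OTAState.COMPLETE))
print(happy.state.name)
```

A table like this makes illegal jumps visible early. For example, a state machine that starts in IDLE cannot go directly to INSTALLING.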

Core Feature Implementation

1. Cloud Management Platform

1.1 Database Design

-- Firmware versions
CREATE TABLE firmware_versions (
    id SERIAL PRIMARY KEY,
    version VARCHAR(50) NOT NULL UNIQUE,
    hardware_version VARCHAR(50) NOT NULL,
    file_path VARCHAR(255) NOT NULL,
    file_size BIGINT NOT NULL,
    file_md5 VARCHAR(32) NOT NULL,
    signature TEXT,
    release_notes TEXT,
    is_differential BOOLEAN DEFAULT FALSE,
    base_version VARCHAR(50),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    created_by VARCHAR(100),
    status VARCHAR(20) DEFAULT 'draft'
);

-- Devices
CREATE TABLE devices (
    id SERIAL PRIMARY KEY,
    device_id VARCHAR(100) NOT NULL UNIQUE,
    device_name VARCHAR(200),
    hardware_version VARCHAR(50),
    current_version VARCHAR(50),
    group_id INTEGER,
    last_online TIMESTAMP,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status VARCHAR(20) DEFAULT 'online'
);

-- Upgrade tasks
CREATE TABLE upgrade_tasks (
    id SERIAL PRIMARY KEY,
    task_name VARCHAR(200) NOT NULL,
    firmware_version_id INTEGER REFERENCES firmware_versions(id),
    target_devices TEXT,  -- JSON array of device IDs
    strategy VARCHAR(50) DEFAULT 'immediate',  -- immediate, scheduled, gradual
    schedule_time TIMESTAMP,
    gradual_percentage INTEGER DEFAULT 100,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status VARCHAR(20) DEFAULT 'pending'
);

-- Upgrade records
CREATE TABLE upgrade_records (
    id SERIAL PRIMARY KEY,
    task_id INTEGER REFERENCES upgrade_tasks(id),
    device_id VARCHAR(100) REFERENCES devices(device_id),
    from_version VARCHAR(50),
    to_version VARCHAR(50),
    start_time TIMESTAMP,
    end_time TIMESTAMP,
    status VARCHAR(20),  -- downloading, installing, success, failed
    error_message TEXT,
    download_progress INTEGER DEFAULT 0
);

-- Device groups
CREATE TABLE device_groups (
    id SERIAL PRIMARY KEY,
    group_name VARCHAR(200) NOT NULL,
    description TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
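The `gradual_percentage` column in `upgrade_tasks` implies that a gradual task reaches only part of its target devices. The schema does not say how to choose them. The sketch below is one common approach, not the project's own method: hash each device ID into a stable bucket from 0 to 99. Raising the percentage then only ever adds devices, and no device flips in and out between waves. The helper names are hypothetical.

```python
import hashlib


def rollout_bucket(device_id: str, task_id: int) -> int:
    """Stable bucket in [0, 100) for a device within one upgrade task."""
    key = f"{task_id}:{device_id}".encode("utf-8")
    digest = hashlib.sha256(key).digest()
    return int.from_bytes(digest[:8], "big") % 100


def in_rollout(device_id: str, task_id: int, gradual_percentage: int) -> bool:
    """True if the device falls inside the first gradual_percentage buckets."""
    pct = max(0, min(100, gradual_percentage))
    return rollout_bucket(device_id, task_id) < pct


def select_targets(device_ids, task_id, gradual_percentage):
    """Devices of one task that should be notified at this rollout percentage."""
    return [d for d in device_ids if in_rollout(d, task_id, gradual_percentage)]


devices = [f"dev-{i:04d}" for i in range(1000)]
wave1 = select_targets(devices, task_id=7, gradual_percentage=10)
wave2 = select_targets(devices, task_id=7, gradual_percentage=50)
print(len(wave1), len(wave2))  # roughly 100 and 500
```

Mixing `task_id` into the hash means each task samples a different subset, so the same devices are not always first in line.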

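1.2 API Service Implementation (Python/Flask)

import hashlib
import json
import os
from datetime import datetime

import paho.mqtt.publish as mqtt_publish
from flask import Flask, request, jsonify, send_file
from flask_sqlalchemy import SQLAlchemy
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql://user:pass@localhost/ota_db'
db = SQLAlchemy(app)

MQTT_BROKER_HOST = 'localhost'  # assumption: the broker used in 1.3


# ORM models mirroring the tables in 1.1
class FirmwareVersion(db.Model):
    __tablename__ = 'firmware_versions'
    id = db.Column(db.Integer, primary_key=True)
    version = db.Column(db.String(50), nullable=False, unique=True)
    hardware_version = db.Column(db.String(50), nullable=False)
    file_path = db.Column(db.String(255), nullable=False)
    file_size = db.Column(db.BigInteger, nullable=False)
    file_md5 = db.Column(db.String(32), nullable=False)
    signature = db.Column(db.Text)
    release_notes = db.Column(db.Text)
    is_differential = db.Column(db.Boolean, default=False)
    base_version = db.Column(db.String(50))
    created_at = db.Column(db.DateTime, default=datetime.now)
    created_by = db.Column(db.String(100))
    status = db.Column(db.String(20), default='draft')


class Device(db.Model):
    __tablename__ = 'devices'
    id = db.Column(db.Integer, primary_key=True)
    device_id = db.Column(db.String(100), nullable=False, unique=True)
    device_name = db.Column(db.String(200))
    hardware_version = db.Column(db.String(50))
    current_version = db.Column(db.String(50))
    group_id = db.Column(db.Integer)
    last_online = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, default=datetime.now)
    status = db.Column(db.String(20), default='online')


class UpgradeTask(db.Model):
    __tablename__ = 'upgrade_tasks'
    id = db.Column(db.Integer, primary_key=True)
    task_name = db.Column(db.String(200), nullable=False)
    firmware_version_id = db.Column(db.Integer, db.ForeignKey('firmware_versions.id'))
    target_devices = db.Column(db.Text)  # JSON array of device IDs
    strategy = db.Column(db.String(50), default='immediate')  # immediate, scheduled, gradual
    schedule_time = db.Column(db.DateTime)
    gradual_percentage = db.Column(db.Integer, default=100)
    created_at = db.Column(db.DateTime, default=datetime.now)
    status = db.Column(db.String(20), default='pending')


class UpgradeRecord(db.Model):
    __tablename__ = 'upgrade_records'
    id = db.Column(db.Integer, primary_key=True)
    task_id = db.Column(db.Integer, db.ForeignKey('upgrade_tasks.id'))
    device_id = db.Column(db.String(100), db.ForeignKey('devices.device_id'))
    from_version = db.Column(db.String(50))
    to_version = db.Column(db.String(50))
    start_time = db.Column(db.DateTime)
    end_time = db.Column(db.DateTime)
    status = db.Column(db.String(20))  # downloading, installing, success, failed
    error_message = db.Column(db.Text)
    download_progress = db.Column(db.Integer, default=0)


def send_upgrade_notification(device_id, firmware_version):
    """Publish an upgrade notification (same topic and fields as in 1.3)"""
    message = {
        'action': 'upgrade',
        'firmware_version': firmware_version,
        'timestamp': datetime.now().isoformat(),
    }
    mqtt_publish.single(f'ota/device/{device_id}/upgrade', json.dumps(message),
                        qos=1, hostname=MQTT_BROKER_HOST)


# Firmware upload endpoint
@app.route('/api/firmware/upload', methods=['POST'])
def upload_firmware():
    """Upload a firmware file"""
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    version = request.form.get('version')
    hardware_version = request.form.get('hardware_version')
    release_notes = request.form.get('release_notes', '')

    # Validate parameters
    if not version or not hardware_version:
        return jsonify({'error': 'Missing required parameters'}), 400

    # Compute the file's MD5
    file_content = file.read()
    file_md5 = hashlib.md5(file_content).hexdigest()
    file_size = len(file_content)

    # Save the file. Sanitize the user-supplied names so they cannot escape the
    # firmware directory, and store an absolute path: Flask's send_file()
    # resolves relative paths against the application root, not the CWD
    safe_hw = secure_filename(hardware_version)
    safe_version = secure_filename(version)
    if not safe_hw or not safe_version:
        return jsonify({'error': 'Invalid version string'}), 400

    file_path = os.path.abspath(os.path.join('firmware', safe_hw, f'{safe_version}.bin'))
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, 'wb') as f:
        f.write(file_content)

    # Save to the database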
    firmware = FirmwareVersion(
        version=version,
        hardware_version=hardware_version,
        file_path=file_path,
        file_size=file_size,
        file_md5=file_md5,
        release_notes=release_notes,
        status='draft'
    )

    db.session.add(firmware)
    db.session.commit()

    return jsonify({
        'success': True,
        'firmware_id': firmware.id,
        'version': version,
        'file_size': file_size,
        'md5': file_md5
    })

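# Create an upgrade task
@app.route('/api/tasks/create', methods=['POST'])
def create_upgrade_task():
    """Create an upgrade task"""
    data = request.get_json(silent=True) or {}

    task_name = data.get('task_name')
    firmware_version_id = data.get('firmware_version_id')
    target_devices = data.get('target_devices', [])
    strategy = data.get('strategy', 'immediate')

    if not task_name or firmware_version_id is None:
        return jsonify({'error': 'Missing required parameters'}), 400

    # Make sure the firmware version exists
    firmware = db.session.get(FirmwareVersion, firmware_version_id)
    if not firmware:
        return jsonify({'error': 'Firmware version not found'}), 404

    # Create the task
    task = UpgradeTask(
        task_name=task_name,
        firmware_version_id=firmware_version_id,
        target_devices=json.dumps(target_devices),
        strategy=strategy,
        status='pending'
    )
    db.session.add(task)
    db.session.flush()  # assigns task.id before the records below reference it

    # One upgrade record per known device, so progress reports have a row to update
    known_devices = []
    for device_id in target_devices:
        device = Device.query.filter_by(device_id=device_id).first()
        if device is None:
            continue  # upgrade_records.device_id is a foreign key to devices
        known_devices.append(device_id)
        db.session.add(UpgradeRecord(
            task_id=task.id,
            device_id=device_id,
            from_version=device.current_version,
            to_version=firmware.version,
        ))

    db.session.commit()

    # Notify the devices over MQTT. Only 'immediate' tasks are pushed here;
    # 'scheduled' and 'gradual' tasks need a scheduler (not shown)
    if strategy == 'immediate':
        for device_id in known_devices:
            send_upgrade_notification(device_id, firmware.version)

    return jsonify({
        'success': True,
        'task_id': task.id,
        'task_name': task_name
    })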

# Query firmware information
@app.route('/api/firmware/info/<version>', methods=['GET'])
def get_firmware_info(version):
    """Get firmware information"""
    firmware = FirmwareVersion.query.filter_by(version=version).first()

    if not firmware:
        return jsonify({'error': 'Firmware not found'}), 404

    return jsonify({
        'version': firmware.version,
        'hardware_version': firmware.hardware_version,
        'file_size': firmware.file_size,
        'md5': firmware.file_md5,
        'download_url': f'/api/firmware/download/{firmware.id}',
        'release_notes': firmware.release_notes,
        'is_differential': firmware.is_differential,
        'base_version': firmware.base_version
    })

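# Firmware download endpoint (supports resuming via HTTP Range requests)
@app.route('/api/firmware/download/<int:firmware_id>', methods=['GET'])
def download_firmware(firmware_id):
    """Download a firmware file; a single byte range enables resuming"""
    firmware = db.session.get(FirmwareVersion, firmware_id)

    if not firmware:
        return jsonify({'error': 'Firmware not found'}), 404

    file_path = firmware.file_path

    if not os.path.exists(file_path):
        return jsonify({'error': 'File not found'}), 404

    file_size = os.path.getsize(file_path)

    # Range requests are handled by hand to show the mechanism. (Recent Flask
    # versions' send_file() can also answer them itself via conditional=True.)
    range_header = request.headers.get('Range')

    if range_header:
        # Accept "bytes=start-end", "bytes=start-" and the suffix form "bytes=-N"
        unit, _, spec = range_header.partition('=')
        start_s, sep, end_s = spec.strip().partition('-')
        valid = (unit.strip() == 'bytes' and sep == '-'
                 and (start_s.isdecimal() or end_s.isdecimal())
                 and (start_s == '' or start_s.isdecimal())
                 and (end_s == '' or end_s.isdecimal()))
        if not valid:
            return jsonify({'error': 'Unsupported Range header'}), 416

        if start_s:
            start = int(start_s)
            end = min(int(end_s), file_size - 1) if end_s else file_size - 1
        else:
            # Suffix range: the last N bytes of the file
            start = max(file_size - int(end_s), 0)
            end = file_size - 1

        if start >= file_size or start > end:
            response = jsonify({'error': 'Range not satisfiable'})
            response.status_code = 416
            response.headers['Content-Range'] = f'bytes */{file_size}'
            return response

        # Read only the requested bytes
        with open(file_path, 'rb') as f:
            f.seek(start)
            data = f.read(end - start + 1)

        response = app.response_class(
            data,
            status=206,  # Partial Content
            mimetype='application/octet-stream'
        )
        response.headers['Content-Range'] = f'bytes {start}-{end}/{file_size}'
        response.headers['Accept-Ranges'] = 'bytes'
        return response

    # Full download
    return send_file(file_path, mimetype='application/octet-stream', as_attachment=True)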

# Device status report
@app.route('/api/device/status', methods=['POST'])
def report_device_status():
    """Device reports its status"""
    data = request.get_json(silent=True) or {}

    device_id = data.get('device_id')
    current_version = data.get('current_version')
    status = data.get('status')

    # Update the device record
    device = Device.query.filter_by(device_id=device_id).first()

    if device:
        device.current_version = current_version
        device.status = status
        device.last_online = datetime.now()
        db.session.commit()

    return jsonify({'success': True})

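# Upgrade progress report
@app.route('/api/upgrade/progress', methods=['POST'])
def report_upgrade_progress():
    """Report upgrade progress"""
    data = request.get_json(silent=True) or {}

    device_id = data.get('device_id')
    task_id = data.get('task_id')
    progress = data.get('progress')
    status = data.get('status')
    error_message = data.get('error_message')

    # Update the upgrade record
    record = UpgradeRecord.query.filter_by(
        task_id=task_id,
        device_id=device_id
    ).first()

    if not record:
        return jsonify({'error': 'Upgrade record not found'}), 404

    if record.start_time is None:
        record.start_time = datetime.now()  # first report for this upgrade
    if progress is not None:
        record.download_progress = progress
    record.status = status
    record.error_message = error_message

    # Success and failure both end the upgrade
    if status in ('success', 'failed'):
        record.end_time = datetime.now()

    db.session.commit()

    return jsonify({'success': True})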

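if __name__ == '__main__':
    # debug=True would expose the interactive Werkzeug debugger (remote code execution) on every interface
    app.run(host='0.0.0.0', port=5000, debug=False)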
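The device client in 2.6 calls `GET /api/firmware/check?device_id=<id>&version=<ver>&hardware=<hw>`. It expects `update_available`, `version`, `file_size`, `crc32`, `download_url`, `is_differential` and `base_version` in the response, but the service above has no such route. The decision logic could look like the framework-free sketch below, which a Flask view would wrap. It rests on these assumptions:

- A `released` status value.
- An extra integer `crc32` field (the schema in 1.1 has none) that describes the final image, so a full image and a patch for the same version share it.
- A relaxed UNIQUE constraint on `firmware_versions.version`, because the sample data holds both a full image and a patch for 1.1.0.

```python
def parse_version(version: str) -> tuple:
    """'1.2.3' -> (1, 2, 3), matching the "%d.%d.%d" format used on the device."""
    return tuple(int(part) for part in version.split("."))


def build_check_response(current_version, hardware_version, releases, base_url=""):
    """Decide what GET /api/firmware/check should return (hypothetical helper)."""
    current = parse_version(current_version)
    candidates = [
        r for r in releases
        if r["hardware_version"] == hardware_version
        and r["status"] == "released"          # assumed status value
        and parse_version(r["version"]) > current
    ]
    if not candidates:
        return {"update_available": False}

    newest = max(parse_version(r["version"]) for r in candidates)
    latest = [r for r in candidates if parse_version(r["version"]) == newest]

    # Prefer a patch built against exactly the version the device runs now.
    patches = [r for r in latest
               if r["is_differential"] and r["base_version"] == current_version]
    full_images = [r for r in latest if not r["is_differential"]]
    chosen = (patches or full_images or [None])[0]
    if chosen is None:
        return {"update_available": False}

    return {
        "update_available": True,
        "version": chosen["version"],
        "file_size": chosen["file_size"],
        "crc32": f"{chosen['crc32']:08x}",     # parsed as hex on the device
        "download_url": f"{base_url}/api/firmware/download/{chosen['id']}",
        "is_differential": chosen["is_differential"],
        "base_version": chosen["base_version"],
    }


releases = [
    {"id": 1, "version": "1.0.0", "hardware_version": "HW1", "status": "released",
     "is_differential": False, "base_version": None, "file_size": 200000, "crc32": 0x1234ABCD},
    {"id": 2, "version": "1.1.0", "hardware_version": "HW1", "status": "released",
     "is_differential": False, "base_version": None, "file_size": 210000, "crc32": 0x0000BEEF},
    {"id": 3, "version": "1.1.0", "hardware_version": "HW1", "status": "released",
     "is_differential": True, "base_version": "1.0.0", "file_size": 12000, "crc32": 0x0000BEEF},
]
print(build_check_response("1.0.0", "HW1", releases, "http://ota.example.com"))
```

Versions are compared numerically rather than as strings, so "1.10.0" correctly sorts above "1.9.9".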

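1.3 MQTT Messaging Service

import json
from datetime import datetime

import paho.mqtt.client as mqtt

# Assumption: the Flask module from 1.2 is saved as ota_api.py (hypothetical
# module name); the handlers below reuse its app, database and models
from ota_api import app, db, Device


class OTAMQTTService:
    def __init__(self, broker_host='localhost', broker_port=1883):
        # paho-mqtt 2.x requires an explicit callback API version
        self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message
        self.client.connect(broker_host, broker_port, 60)

    def on_connect(self, client, userdata, flags, reason_code, properties):
        """Called when the broker answers the connection request"""
        print(f"Connected to MQTT broker with result code {reason_code}")

        # Subscribing here restores the subscriptions after every reconnect
        client.subscribe("ota/device/+/status")
        client.subscribe("ota/device/+/progress")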
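    def on_message(self, client, userdata, msg):
        """Called for every message on a subscribed topic"""
        topic = msg.topic
        try:
            payload = json.loads(msg.payload.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            print(f"Ignoring malformed payload on {topic}")
            return

        if topic.endswith('/status'):
            # Device status message
            self.handle_device_status(payload)
        elif topic.endswith('/progress'):
            # Upgrade progress message
            self.handle_upgrade_progress(payload)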

    def handle_device_status(self, data):
        """Update the device's version and last-seen time"""
        device_id = data.get('device_id')
        version = data.get('version')

        # Flask-SQLAlchemy needs an application context outside of a request
        with app.app_context():
            device = Device.query.filter_by(device_id=device_id).first()
            if device:
                device.current_version = version
                device.last_online = datetime.now()
                db.session.commit()

    def handle_upgrade_progress(self, data):
        """Handle an upgrade progress report"""
        device_id = data.get('device_id')
        progress = data.get('progress')
        status = data.get('status')

        # Update the upgrade record
        # ... database operations

    def send_upgrade_notification(self, device_id, firmware_version):
        """Send an upgrade notification to a device"""
        topic = f"ota/device/{device_id}/upgrade"

        message = {
            'action': 'upgrade',
            'firmware_version': firmware_version,
            'timestamp': datetime.now().isoformat()
        }

        self.client.publish(topic, json.dumps(message), qos=1)
        print(f"Sent upgrade notification to {device_id}")

    def run(self):
        """Run the MQTT network loop (blocks forever)"""
        self.client.loop_forever()

# Start the MQTT service
if __name__ == '__main__':
    mqtt_service = OTAMQTTService()
    mqtt_service.run()
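The subscription `ota/device/+/status` uses the single-level wildcard `+`. A device ID therefore must not contain `/`, nor the wildcard characters `+` and `#`. The service above also trusts the `device_id` inside the payload, so any client could report on behalf of another device. Below is a sketch of stricter topic handling with hypothetical helper names. Broker-side ACLs would normally enforce the same rule as well.

```python
import json
from datetime import datetime


def device_topic(device_id: str, kind: str) -> str:
    """Build 'ota/device/<device_id>/<kind>' (kind: upgrade, status, progress)."""
    if not device_id or any(ch in device_id for ch in "/+#"):
        raise ValueError(f"device_id not usable as a topic level: {device_id!r}")
    return f"ota/device/{device_id}/{kind}"


def parse_device_topic(topic: str):
    """Split 'ota/device/<id>/<kind>' into (id, kind); None for anything else."""
    parts = topic.split("/")
    if len(parts) != 4 or parts[:2] != ["ota", "device"] or not parts[2]:
        return None
    return parts[2], parts[3]


def payload_matches_topic(topic: str, payload: dict) -> bool:
    """Reject reports whose device_id differs from the device named in the topic."""
    parsed = parse_device_topic(topic)
    return parsed is not None and payload.get("device_id") == parsed[0]


def upgrade_message(firmware_version: str) -> str:
    """Same fields as OTAMQTTService.send_upgrade_notification()."""
    return json.dumps({
        "action": "upgrade",
        "firmware_version": firmware_version,
        "timestamp": datetime.now().isoformat(),
    })


print(parse_device_topic("ota/device/dev-001/status"))
```

In `on_message`, a check like `payload_matches_topic(msg.topic, payload)` would run before either handler touches the database.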

1.4 Differential Package Generator

import hashlib

import bsdiff4


class DifferentialPackageGenerator:
    """Differential package generator.

    bsdiff4 writes patches in the BSDIFF4 format (the classic bsdiff layout with
    bzip2-compressed blocks), so the device-side bspatch in 2.5 must accept it.
    """

    def generate_diff(self, old_firmware_path, new_firmware_path, output_path):
        """Generate a differential package"""
        print("Generating differential package...")
        print(f"Old firmware: {old_firmware_path}")
        print(f"New firmware: {new_firmware_path}")

        # Read the firmware files
        with open(old_firmware_path, 'rb') as f:
            old_data = f.read()

        with open(new_firmware_path, 'rb') as f:
            new_data = f.read()

        # Generate the patch
        diff_data = bsdiff4.diff(old_data, new_data)

        # Round-trip check: the patch must rebuild the new image exactly
        if bsdiff4.patch(old_data, diff_data) != new_data:
            raise RuntimeError("Patch does not reproduce the new firmware")

        # Save the patch
        with open(output_path, 'wb') as f:
            f.write(diff_data)

        # Size saving relative to shipping the full new image
        old_size = len(old_data)
        new_size = len(new_data)
        diff_size = len(diff_data)

        compression_ratio = (1 - diff_size / new_size) * 100

        print(f"Old firmware size: {old_size} bytes")
        print(f"New firmware size: {new_size} bytes")
        print(f"Differential package size: {diff_size} bytes")
        print(f"Saving vs. full image: {compression_ratio:.2f}%")

        return {
            'old_size': old_size,
            'new_size': new_size,
            'diff_size': diff_size,
            'compression_ratio': compression_ratio,
            'diff_md5': hashlib.md5(diff_data).hexdigest()
        }

    def apply_diff(self, old_firmware_path, diff_path, output_path):
        """Apply a differential package"""
        print("Applying differential package...")

        # Read the old firmware and the patch
        with open(old_firmware_path, 'rb') as f:
            old_data = f.read()

        with open(diff_path, 'rb') as f:
            diff_data = f.read()

        # Apply the patch
        new_data = bsdiff4.patch(old_data, diff_data)

        # Save the new firmware
        with open(output_path, 'wb') as f:
            f.write(new_data)

        print(f"New firmware generated: {output_path}")
        print(f"Size: {len(new_data)} bytes")

        return len(new_data)

# Usage example
if __name__ == '__main__':
    generator = DifferentialPackageGenerator()

    # Generate a patch
    result = generator.generate_diff(
        'firmware_v1.0.0.bin',
        'firmware_v1.1.0.bin',
        'diff_v1.0.0_to_v1.1.0.patch'
    )

    print("\nDifferential package generated successfully!")
    print(f"MD5: {result['diff_md5']}")
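The device needs more than an MD5. In 2.6 it reads `file_size` and a hex `crc32` from the check response. After a differential download, though, the bytes on the wire (the patch) differ from the bytes that get verified (the rebuilt image). Below is a sketch of a release manifest that records both, using only the standard library. Which field feeds which API parameter is an assumption. `zlib.crc32` computes the same standard CRC-32 as `OTA_CalculateCRC32()` in 2.4: reflected polynomial 0xEDB88320, with initial value and final XOR 0xFFFFFFFF.

```python
import hashlib
import zlib


def image_digests(data: bytes) -> dict:
    """Size and digests of one binary, as the cloud could publish them."""
    return {
        "size": len(data),
        "md5": hashlib.md5(data).hexdigest(),
        "sha256": hashlib.sha256(data).hexdigest(),
        # Standard CRC-32, the same value OTA_CalculateCRC32() produces
        "crc32": f"{zlib.crc32(data) & 0xFFFFFFFF:08x}",
    }


def build_manifest(version, new_image, patch=None, base_version=None):
    """Describe a release; 'patch' is present only for differential packages."""
    manifest = {"version": version, "image": image_digests(new_image)}
    if patch is not None:
        manifest["patch"] = {"base_version": base_version, **image_digests(patch)}
    return manifest


m = build_manifest("1.1.0", b"123456789", patch=b"\x00" * 16, base_version="1.0.0")
print(m["image"]["crc32"], m["patch"]["size"])  # cbf43926 16
```

With this split, the patch size and MD5 guard the download, while the image CRC32 (and a signature over the image) guard what actually gets booted.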

2. Device-Side OTA Client

2.1 OTA Client Core Structures

// ota_client.h
#ifndef OTA_CLIENT_H
#define OTA_CLIENT_H

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>

// OTA state
typedef enum {
    OTA_STATE_IDLE = 0,
    OTA_STATE_CHECKING,
    OTA_STATE_DOWNLOADING,
    OTA_STATE_VERIFYING,
    OTA_STATE_INSTALLING,
    OTA_STATE_COMPLETE,
    OTA_STATE_ERROR
} OTA_State_t;

// OTA configuration
typedef struct {
    char device_id[64];
    char current_version[32];
    char hardware_version[32];
    char server_url[256];
    char mqtt_broker[128];
    uint16_t mqtt_port;
    uint32_t check_interval;  // Update-check interval (seconds)
} OTA_Config_t;

// 256 bytes holds an RSA-2048 signature
#define OTA_MAX_SIGNATURE_LEN 256

// Information about the pending update
typedef struct {
    OTA_State_t state;
    char target_version[32];
    char hardware_version[32];   // expected hardware version (checked in ota_verify.c)
    uint32_t firmware_size;
    uint32_t downloaded_size;
    uint32_t firmware_crc32;
    char download_url[256];
    bool is_differential;
    char base_version[32];
    uint8_t retry_count;
    bool has_signature;          // signature fields are used by ota_verify.c
    uint8_t signature[OTA_MAX_SIGNATURE_LEN];
    size_t signature_len;
} OTA_Info_t;

// OTA client API
void OTA_Init(OTA_Config_t *config);
void OTA_Task(void);
bool OTA_CheckUpdate(void);
bool OTA_StartDownload(void);
bool OTA_VerifyFirmware(void);
bool OTA_InstallFirmware(void);
void OTA_ReportProgress(uint8_t progress);
void OTA_ReportStatus(const char *status, const char *message);

#endif // OTA_CLIENT_H

2.2 MQTT Communication

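// ota_mqtt.c
#include "ota_mqtt.h"
#include "MQTTClient.h"
#include "cJSON.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>

#define MQTT_BUF_SIZE 1024

static MQTTClient mqtt_client;
// The Paho client keeps pointers to the network, the buffers and the topic
// filter, so all of them must outlive OTA_MQTT_Init() (hence static)
static Network mqtt_network;
static unsigned char mqtt_sendbuf[MQTT_BUF_SIZE];
static unsigned char mqtt_readbuf[MQTT_BUF_SIZE];
static char upgrade_topic[128];
static OTA_Config_t *ota_config;
static OTA_Info_t *ota_info;

// Copy a length-delimited MQTT field into a NUL-terminated buffer, truncating if needed
static void copy_bounded(char *dst, size_t dst_size, const void *src, size_t len) {
    if (len >= dst_size) {
        len = dst_size - 1;
    }
    memcpy(dst, src, len);
    dst[len] = '\0';
}

// MQTT message callback
void mqtt_message_callback(MessageData *data) {
    char topic[128];
    char payload[512];

    // Extract topic and payload without overflowing the stack buffers
    copy_bounded(topic, sizeof(topic), data->topicName->lenstring.data,
                 (size_t)data->topicName->lenstring.len);
    copy_bounded(payload, sizeof(payload), data->message->payload,
                 data->message->payloadlen);

    printf("MQTT message received on topic: %s\r\n", topic);
    printf("Payload: %s\r\n", payload);

    // Is this an upgrade notification?
    if (strcmp(topic, upgrade_topic) == 0) {
        // Parse the upgrade notification
        cJSON *json = cJSON_Parse(payload);
        if (json) {
            cJSON *action = cJSON_GetObjectItem(json, "action");
            cJSON *version = cJSON_GetObjectItem(json, "firmware_version");

            if (cJSON_IsString(action) && strcmp(action->valuestring, "upgrade") == 0 &&
                cJSON_IsString(version)) {
                printf("Upgrade notification received: %s\r\n",
                       version->valuestring);

                // Trigger an update check. This runs inside MQTTYield() and
                // blocks it for the whole HTTP request; a production client
                // would set a flag here and run the check from OTA_Task().
                OTA_CheckUpdate();
            }

            cJSON_Delete(json);
        }
    }
}

// Initialize MQTT
bool OTA_MQTT_Init(OTA_Config_t *config, OTA_Info_t *info) {
    ota_config = config;
    ota_info = info;

    // Open the TCP connection first (NetworkInit/NetworkConnect come from the
    // Paho platform port), then bind the client to it with real buffers
    NetworkInit(&mqtt_network);
    if (NetworkConnect(&mqtt_network, config->mqtt_broker, config->mqtt_port) != 0) {
        printf("Failed to connect to MQTT broker\r\n");
        return false;
    }

    MQTTClientInit(&mqtt_client, &mqtt_network, 1000,
                   mqtt_sendbuf, sizeof(mqtt_sendbuf),
                   mqtt_readbuf, sizeof(mqtt_readbuf));

    // MQTT CONNECT (MQTTVersion 3 = MQTT 3.1, 4 = MQTT 3.1.1)
    MQTTPacket_connectData connect_data = MQTTPacket_connectData_initializer;
    connect_data.MQTTVersion = 3;
    connect_data.clientID.cstring = config->device_id;

    if (MQTTConnect(&mqtt_client, &connect_data) != 0) {
        printf("Failed to connect to MQTT\r\n");
        return false;
    }

    printf("Connected to MQTT broker\r\n");

    // Subscribe to the upgrade topic
    snprintf(upgrade_topic, sizeof(upgrade_topic),
             "ota/device/%s/upgrade", config->device_id);

    if (MQTTSubscribe(&mqtt_client, upgrade_topic, QOS1, mqtt_message_callback) != 0) {
        printf("Failed to subscribe to topic: %s\r\n", upgrade_topic);
        return false;
    }

    printf("Subscribed to topic: %s\r\n", upgrade_topic);

    return true;
}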

// Publish device status
void OTA_MQTT_PublishStatus(const char *status, const char *message) {
    char topic[128];

    snprintf(topic, sizeof(topic), "ota/device/%s/status",
             ota_config->device_id);

    cJSON *json = cJSON_CreateObject();
    if (json == NULL) {
        return;
    }
    cJSON_AddStringToObject(json, "device_id", ota_config->device_id);
    cJSON_AddStringToObject(json, "version", ota_config->current_version);
    cJSON_AddStringToObject(json, "status", status);
    cJSON_AddStringToObject(json, "message", message);

    // Compact JSON: cJSON_Print() adds indentation that only costs bandwidth
    char *json_str = cJSON_PrintUnformatted(json);
    if (json_str != NULL) {
        MQTTMessage message_data = {0};
        message_data.qos = QOS1;
        message_data.retained = 0;
        message_data.payload = json_str;
        message_data.payloadlen = strlen(json_str);

        MQTTPublish(&mqtt_client, topic, &message_data);
        free(json_str);   // cJSON allocates with malloc() unless hooks are installed
    }
    cJSON_Delete(json);
}

// Publish upgrade progress
void OTA_MQTT_PublishProgress(uint8_t progress) {
    char topic[128];

    snprintf(topic, sizeof(topic), "ota/device/%s/progress",
             ota_config->device_id);

    cJSON *json = cJSON_CreateObject();
    if (json == NULL) {
        return;
    }
    cJSON_AddStringToObject(json, "device_id", ota_config->device_id);
    cJSON_AddNumberToObject(json, "progress", progress);
    cJSON_AddStringToObject(json, "status",
                           ota_info->state == OTA_STATE_DOWNLOADING ?
                           "downloading" : "installing");

    char *json_str = cJSON_PrintUnformatted(json);
    if (json_str != NULL) {
        // QoS 0: a lost progress update is harmless, the next one supersedes it
        MQTTMessage message = {0};
        message.qos = QOS0;
        message.retained = 0;
        message.payload = json_str;
        message.payloadlen = strlen(json_str);

        MQTTPublish(&mqtt_client, topic, &message);
        free(json_str);
    }
    cJSON_Delete(json);
}

// MQTT task: call periodically so incoming messages and keep-alives are processed
void OTA_MQTT_Task(void) {
    MQTTYield(&mqtt_client, 100);
}

2.3 HTTP Download (with Resume Support)

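// ota_download.c
#include "ota_download.h"
#include "http_client.h"
#include <string.h>
#include <stdio.h>

#define MAX_RETRY_COUNT 3

static OTA_Info_t *ota_info;
static uint8_t pending[4];          // bytes waiting to complete a 32-bit flash word
static uint32_t pending_len = 0;
static uint8_t last_progress = 0;   // last reported progress (%)

// HTTP download callback; *userdata counts bytes already committed to flash
static int http_download_callback(void *data, size_t size, void *userdata) {
    uint32_t *downloaded = (uint32_t*)userdata;
    const uint8_t *src = (const uint8_t*)data;

    // Flash_Write() programs whole 32-bit words, but an HTTP chunk can have any
    // length: commit complete words and keep the 0-3 leftover bytes for later
    for (size_t i = 0; i < size; i++) {
        pending[pending_len++] = src[i];
        if (pending_len == 4) {
            uint32_t word;
            memcpy(&word, pending, sizeof(word));
            uint32_t write_addr = FIRMWARE_BACKUP_ADDRESS + *downloaded;
            if (!Flash_Write(write_addr, &word, 1)) {
                printf("Flash write failed at 0x%08lX\r\n", (unsigned long)write_addr);
                return -1;
            }
            *downloaded += 4;
            pending_len = 0;
        }
    }

    // Progress in percent (64-bit intermediate so large images cannot overflow)
    uint8_t progress = (uint8_t)(((uint64_t)*downloaded * 100u) / ota_info->firmware_size);

    // Report progress every 5%
    if (progress >= last_progress + 5) {
        OTA_ReportProgress(progress);
        last_progress = progress;
    }

    return 0;
}

// Download the firmware (resumable)
bool OTA_DownloadFirmware(OTA_Info_t *info) {
    ota_info = info;

    if (info->firmware_size == 0) {
        printf("Invalid firmware size\r\n");
        return false;
    }

    printf("Starting firmware download...\r\n");
    printf("URL: %s\r\n", info->download_url);
    printf("Size: %lu bytes\r\n", (unsigned long)info->firmware_size);

    // Resume from the last committed offset (always a multiple of 4). Note that
    // downloaded_size lives in RAM; surviving a reset would require persisting it
    uint32_t downloaded = info->downloaded_size & ~3u;

    if (downloaded > 0) {
        printf("Resuming download from %lu bytes\r\n", (unsigned long)downloaded);
    }
    last_progress = (uint8_t)(((uint64_t)downloaded * 100u) / info->firmware_size);

    int retry_count = 0;
    bool success = false;

    while (retry_count < MAX_RETRY_COUNT && !success) {
        // Uncommitted bytes from a failed attempt are discarded and re-requested
        pending_len = 0;

        http_client_t client;
        http_client_init(&client);

        // Rebuild the Range header on every attempt so it matches the current
        // offset (the server must answer 206; a 200 full body would be written
        // at the wrong offset)
        if (downloaded > 0) {
            char range_header[64];
            snprintf(range_header, sizeof(range_header), "bytes=%lu-%lu",
                     (unsigned long)downloaded,
                     (unsigned long)(info->firmware_size - 1));
            http_client_set_header(&client, "Range", range_header);
        }

        http_client_set_callback(&client, http_download_callback, &downloaded);
        int result = http_client_get(&client, info->download_url);
        http_client_cleanup(&client);

        // Flush the final partial word, padded with 0xFF (the erased-flash value)
        if (result == 0 && pending_len > 0 &&
            downloaded + pending_len == info->firmware_size) {
            uint32_t word = 0xFFFFFFFFu;
            memcpy(&word, pending, pending_len);
            if (Flash_Write(FIRMWARE_BACKUP_ADDRESS + downloaded, &word, 1)) {
                downloaded += pending_len;
                pending_len = 0;
            }
        }

        // Save the committed offset for the next attempt or a later resume
        info->downloaded_size = downloaded;

        if (result == 0 && downloaded == info->firmware_size) {
            success = true;
            printf("Download complete!\r\n");
        } else {
            retry_count++;
            printf("Download failed, retry %d/%d\r\n",
                   retry_count, MAX_RETRY_COUNT);

            // Wait before retrying
            HAL_Delay(5000);
        }
    }

    return success;
}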
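Resuming only works if each request restarts exactly at the last byte committed to flash. Word-granular flash writes make that offset a multiple of 4. The host-side simulation below is a sketch rather than device code. It replays that logic against a toy server that honours a single `bytes=start-end` range, over a link that drops after a fixed number of bytes per request. `cut_after` and the helper names are hypothetical.

```python
def range_header(downloaded: int, total: int) -> str:
    """Range header for the remaining bytes, as built in OTA_DownloadFirmware()."""
    return f"bytes={downloaded}-{total - 1}"


def serve_range(image: bytes, header: str) -> bytes:
    """Toy server: honours a single 'bytes=start-end' range only."""
    start_s, end_s = header[len("bytes="):].split("-")
    start = int(start_s)
    end = int(end_s) if end_s else len(image) - 1
    return image[start:end + 1]


def resumable_download(image: bytes, cut_after: int, max_attempts: int = 10):
    """Simulate a link that drops after cut_after bytes per request.

    Returns (flash_contents, attempts). Only whole 32-bit words are committed,
    like Flash_Write(); leftover bytes of a failed attempt are requested again.
    """
    total = len(image)
    flash = bytearray(b"\xff" * total)       # erased flash reads as 0xFF
    downloaded = 0                           # bytes committed to "flash"
    for attempt in range(1, max_attempts + 1):
        body = serve_range(image, range_header(downloaded, total))
        received = body[:cut_after]          # the connection drops here
        whole = len(received) - len(received) % 4
        flash[downloaded:downloaded + whole] = received[:whole]
        downloaded += whole
        tail = received[whole:]              # 0-3 bytes not yet committed
        if downloaded + len(tail) == total:  # final partial word: flush it
            flash[downloaded:total] = tail
            return bytes(flash), attempt
    raise RuntimeError("download did not complete")


image = bytes(range(256)) * 10
data, attempts = resumable_download(image, cut_after=700)
print(data == image, attempts)
```

If the Range header were built only once before the retry loop, the second request would restart at offset 0 while the writes continued at the new offset. The simulation would then report a mismatch.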

2.4 Firmware Verification

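// ota_verify.c
#include "ota_verify.h"
#include "mbedtls/md.h"
#include "mbedtls/pk.h"
#include "mbedtls/sha256.h"
#include <stdio.h>
#include <string.h>

// Public key used to verify signatures (the key body is abbreviated)
static const char *public_key_pem = 
"-----BEGIN PUBLIC KEY-----\n"
"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA...\n"
"-----END PUBLIC KEY-----\n";

// Compute the firmware CRC32 (standard reflected CRC-32: polynomial 0xEDB88320,
// initial value and final XOR 0xFFFFFFFF; same result as zlib's crc32)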
uint32_t OTA_CalculateCRC32(uint32_t address, uint32_t size) {
    uint32_t crc = 0xFFFFFFFF;
    uint8_t *data = (uint8_t*)address;

    for (uint32_t i = 0; i < size; i++) {
        crc ^= data[i];
        for (uint8_t j = 0; j < 8; j++) {
            if (crc & 1) {
                crc = (crc >> 1) ^ 0xEDB88320;
            } else {
                crc = crc >> 1;
            }
        }
    }

    return ~crc;
}

// Compute the firmware SHA-256
bool OTA_CalculateSHA256(uint32_t address, uint32_t size, uint8_t *hash) {
    mbedtls_sha256_context ctx;

    mbedtls_sha256_init(&ctx);
    mbedtls_sha256_starts(&ctx, 0);  // 0 = SHA256

    // Hash in 1 KB chunks
    uint32_t remaining = size;
    uint32_t offset = 0;
    uint8_t buffer[1024];

    while (remaining > 0) {
        uint32_t chunk_size = (remaining > sizeof(buffer)) ? 
                              sizeof(buffer) : remaining;

        memcpy(buffer, (void*)(address + offset), chunk_size);
        mbedtls_sha256_update(&ctx, buffer, chunk_size);

        offset += chunk_size;
        remaining -= chunk_size;
    }

    mbedtls_sha256_finish(&ctx, hash);
    mbedtls_sha256_free(&ctx);

    return true;
}

// Verify the digital signature over the image's SHA-256 hash
bool OTA_VerifySignature(uint32_t address, uint32_t size, 
                        const uint8_t *signature, size_t sig_len) {
    mbedtls_pk_context pk;
    uint8_t hash[32];
    int ret;

    // Load the public key (for PEM input the length must include the terminating NUL)
    mbedtls_pk_init(&pk);

    ret = mbedtls_pk_parse_public_key(&pk, 
                                     (const unsigned char*)public_key_pem,
                                     strlen(public_key_pem) + 1);
    if (ret != 0) {
        printf("Failed to parse public key: -0x%04x\r\n", -ret);
        mbedtls_pk_free(&pk);
        return false;
    }

    // Hash the firmware
    if (!OTA_CalculateSHA256(address, size, hash)) {
        printf("Failed to calculate SHA256\r\n");
        mbedtls_pk_free(&pk);
        return false;
    }

    // Check the signature against the hash
    ret = mbedtls_pk_verify(&pk, MBEDTLS_MD_SHA256,
                           hash, sizeof(hash),
                           signature, sig_len);

    mbedtls_pk_free(&pk);

    if (ret != 0) {
        printf("Signature verification failed: -0x%04x\r\n", -ret);
        return false;
    }

    printf("Signature verification passed\r\n");
    return true;
}

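// Verify firmware integrity.
// Named OTA_VerifyImage() rather than OTA_VerifyFirmware(): ota_client.h
// already declares bool OTA_VerifyFirmware(void), implemented in
// ota_client.c, and one program cannot define both. Call it as
// OTA_VerifyImage(&ota_info).
bool OTA_VerifyImage(OTA_Info_t *info) {
    uint32_t address = FIRMWARE_BACKUP_ADDRESS;
    uint32_t size = info->firmware_size;

    printf("Verifying firmware...\r\n");

    // 1. CRC32 check
    printf("Calculating CRC32...\r\n");
    uint32_t calculated_crc = OTA_CalculateCRC32(address, size);

    if (calculated_crc != info->firmware_crc32) {
        printf("CRC32 mismatch: expected 0x%08X, got 0x%08X\r\n",
               info->firmware_crc32, calculated_crc);
        return false;
    }

    printf("CRC32 verification passed\r\n");

    // 2. Digital signature verification (if present). Skipping it when
    // has_signature is false lets unsigned images through; a production
    // build should treat a missing signature as a failure.
    if (info->has_signature) {
        printf("Verifying digital signature...\r\n");

        if (!OTA_VerifySignature(address, size, 
                                info->signature, info->signature_len)) {
            printf("Digital signature verification failed\r\n");
            return false;
        }
    }

    // 3. Version check. The image header is placed at offset 0x100; this must
    // match the linker script and must not overlap the vector table, which on
    // STM32F4 parts is longer than 0x100 bytes.
    FirmwareInfo_t *fw_info = (FirmwareInfo_t*)(address + 0x100);

    if (fw_info->magic != FIRMWARE_MAGIC) {
        printf("Invalid firmware magic number\r\n");
        return false;
    }

    char image_version[32];
    snprintf(image_version, sizeof(image_version), "%d.%d.%d",
             fw_info->major, fw_info->minor, fw_info->build);
    printf("Firmware version: %s\r\n", image_version);

    // The image must be the version the server announced
    if (strcmp(image_version, info->target_version) != 0) {
        printf("Version mismatch: expected %s\r\n", info->target_version);
        return false;
    }

    // 4. Hardware compatibility check
    if (strcmp(fw_info->hardware_version, info->hardware_version) != 0) {
        printf("Hardware version mismatch\r\n");
        return false;
    }

    printf("Firmware verification complete\r\n");
    return true;
}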
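The server computes the reference values and the device recomputes them, so both sides must agree bit for bit. Below is a quick host-side cross-check that treats Python's `zlib` and `hashlib` as the reference. It is a line-by-line port of the bitwise loop in `OTA_CalculateCRC32()` and of the 1 KB chunked hashing in `OTA_CalculateSHA256()`. The Python function names are hypothetical.

```python
import hashlib
import zlib


def crc32_bitwise(data: bytes) -> int:
    """Same steps as OTA_CalculateCRC32(): reflected, poly 0xEDB88320."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF          # the C code's ~crc on a 32-bit value


def sha256_chunked(data: bytes, chunk_size: int = 1024) -> bytes:
    """Same chunking as OTA_CalculateSHA256(); the digest is chunk-independent."""
    ctx = hashlib.sha256()
    for offset in range(0, len(data), chunk_size):
        ctx.update(data[offset:offset + chunk_size])
    return ctx.digest()


blob = bytes((i * 31 + 7) % 256 for i in range(5000))
print(hex(crc32_bitwise(b"123456789")))  # 0xcbf43926, the standard CRC-32 check value
```

Because the bitwise loop matches `zlib.crc32`, the server can publish `zlib.crc32(image)` and the device check in `OTA_VerifyImage()` will agree with it.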

2.5 Differential Upgrade Implementation

// ota_differential.c
#include "ota_differential.h"
#include "bspatch.h"
#include <string.h>
#include <stdio.h>

#define PATCH_BUFFER_SIZE 4096

// Format the running image's version as "major.minor.build"
static void get_active_version(char *buf, size_t buf_size) {
    FirmwareInfo_t *current_fw = (FirmwareInfo_t*)(FIRMWARE_ACTIVE_ADDRESS + 0x100);
    snprintf(buf, buf_size, "%d.%d.%d",
             current_fw->major, current_fw->minor, current_fw->build);
}

// Check whether a patch built against base_version can be applied
bool OTA_IsDifferentialSupported(const char *base_version) {
    char current_version[32];
    get_active_version(current_version, sizeof(current_version));
    return (strcmp(current_version, base_version) == 0);
}

// Apply a differential package
bool OTA_ApplyDifferentialPatch(const char *base_version, 
                               uint32_t patch_address,
                               uint32_t patch_size) {
    printf("Applying differential patch...\r\n");
    printf("Base version: %s\r\n", base_version);
    printf("Patch size: %u bytes\r\n", patch_size);

    // 1. A patch only applies to the exact image it was generated from
    if (!OTA_IsDifferentialSupported(base_version)) {
        char current_version[32];
        get_active_version(current_version, sizeof(current_version));
        printf("Base version mismatch: expected %s, got %s\r\n",
               base_version, current_version);
        return false;
    }

    // 2. The output must not overwrite the patch: bspatch reads the patch
    //    while it writes the new image to FIRMWARE_BACKUP_ADDRESS
    if (patch_address == FIRMWARE_BACKUP_ADDRESS) {
        printf("Patch must not be stored in the output partition\r\n");
        return false;
    }

    // 3. Either copy the current image into RAM (if RAM is large enough)
    //    or process it as a stream.

    // 4. Apply the bsdiff patch. bspatch() here is assumed to be a
    //    project-specific port; since flash cannot be written through a
    //    plain pointer, it must program new_data through the flash driver.
    FirmwareInfo_t *current_fw = (FirmwareInfo_t*)(FIRMWARE_ACTIVE_ADDRESS + 0x100);
    uint8_t *old_data = (uint8_t*)FIRMWARE_ACTIVE_ADDRESS;
    uint8_t *patch_data = (uint8_t*)patch_address;
    uint8_t *new_data = (uint8_t*)FIRMWARE_BACKUP_ADDRESS;

    int result = bspatch(old_data, current_fw->firmware_size,
                        patch_data, patch_size,
                        new_data);

    if (result != 0) {
        printf("Failed to apply patch: error %d\r\n", result);
        return false;
    }

    printf("Differential patch applied successfully\r\n");
    return true;
}
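In 2.6 the patch is downloaded to `FIRMWARE_BACKUP_ADDRESS`, which is the same address `OTA_ApplyDifferentialPatch()` writes the rebuilt image to. `Flash_EraseSectors()` also erases whole sectors. Both problems come down to the flash layout. STM32F4 flash erases sectors of unequal size, so partitions should start on sector boundaries, and the patch needs a region of its own. Below is a planning sketch for a 1 MB STM32F405/407-class part, which has 4 × 16 KB, 1 × 64 KB and 7 × 128 KB sectors starting at 0x08000000. The partition map itself is hypothetical.

```python
FLASH_BASE = 0x0800_0000
# Assumed device: STM32F405/407-class, 1 MB, single bank.
SECTOR_SIZES_KB = [16, 16, 16, 16, 64] + [128] * 7


def sector_table():
    """List of (index, start_address, size_bytes) for every flash sector."""
    table, addr = [], FLASH_BASE
    for index, size_kb in enumerate(SECTOR_SIZES_KB):
        table.append((index, addr, size_kb * 1024))
        addr += size_kb * 1024
    return table


def sectors_to_erase(start: int, size: int) -> list:
    """Indices of every sector touched by [start, start + size)."""
    end = start + size
    return [i for i, base, length in sector_table() if base < end and start < base + length]


def regions_overlap(a_start: int, a_size: int, b_start: int, b_size: int) -> bool:
    return a_start < b_start + b_size and b_start < a_start + a_size


# Hypothetical partition map (names and addresses are illustrative only);
# the 64 KB sector 4 is left free, e.g. for persistent OTA state.
PARTITIONS = {
    "bootloader": (0x0800_0000, 64 * 1024),    # sectors 0-3
    "active":     (0x0802_0000, 384 * 1024),   # sectors 5-7
    "backup":     (0x0808_0000, 384 * 1024),   # sectors 8-10
    "patch":      (0x080E_0000, 128 * 1024),   # sector 11
}

for name, (start, size) in PARTITIONS.items():
    print(f"{name:10s} 0x{start:08X} sectors {sectors_to_erase(start, size)}")
```

An erase request rounds up to whole sectors. Erasing 200 KB from 0x08020000 clears sectors 5 and 6, which is 256 KB, so any data sharing those sectors would be lost.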

2.6 OTA Main Control Logic

// ota_client.c
#include "ota_client.h"
#include "ota_mqtt.h"
#include "ota_download.h"
#include "ota_verify.h"
#include "ota_differential.h"
#include "http_client.h"
#include "cJSON.h"
#include <string.h>
#include <stdio.h>

static OTA_Config_t ota_config;
static OTA_Info_t ota_info;
static uint32_t last_check_time = 0;

// Initialize the OTA client
void OTA_Init(OTA_Config_t *config) {
    memcpy(&ota_config, config, sizeof(OTA_Config_t));
    memset(&ota_info, 0, sizeof(OTA_Info_t));

    ota_info.state = OTA_STATE_IDLE;
    // ota_verify.c compares the image's hardware version against this field
    strncpy(ota_info.hardware_version, ota_config.hardware_version,
            sizeof(ota_info.hardware_version) - 1);

    // Initialize MQTT
    if (!OTA_MQTT_Init(&ota_config, &ota_info)) {
        printf("Failed to initialize MQTT\r\n");
        return;
    }

    printf("OTA client initialized\r\n");
    printf("Device ID: %s\r\n", ota_config.device_id);
    printf("Current version: %s\r\n", ota_config.current_version);

    // Report device status
    OTA_ReportStatus("online", "Device started");
}

// Check for updates
bool OTA_CheckUpdate(void) {
    printf("Checking for updates...\r\n");

    ota_info.state = OTA_STATE_CHECKING;

    // Build the API URL (the Flask service in 1.2 does not define this endpoint yet)
    char url[512];
    snprintf(url, sizeof(url), 
             "%s/api/firmware/check?device_id=%s&version=%s&hardware=%s",
             ota_config.server_url,
             ota_config.device_id,
             ota_config.current_version,
             ota_config.hardware_version);

    // Send the HTTP request
    http_client_t client;
    http_client_init(&client);

    char response[1024];
    int result = http_client_get_string(&client, url, response, sizeof(response));

    http_client_cleanup(&client);

    if (result != 0) {
        printf("Failed to check for updates\r\n");
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }

    // Parse the response
    cJSON *json = cJSON_Parse(response);
    if (!json) {
        printf("Failed to parse response\r\n");
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }

    cJSON *update_available = cJSON_GetObjectItem(json, "update_available");

    if (!cJSON_IsTrue(update_available)) {
        printf("No update available\r\n");
        cJSON_Delete(json);
        ota_info.state = OTA_STATE_IDLE;
        return false;
    }

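    // Extract the firmware information. No signature is parsed here, so
    // ota_info.has_signature stays false and signature checks are skipped.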
    cJSON *version = cJSON_GetObjectItem(json, "version");
    cJSON *size = cJSON_GetObjectItem(json, "file_size");
    cJSON *crc32 = cJSON_GetObjectItem(json, "crc32");
    cJSON *download_url = cJSON_GetObjectItem(json, "download_url");
    cJSON *is_diff = cJSON_GetObjectItem(json, "is_differential");
    cJSON *base_ver = cJSON_GetObjectItem(json, "base_version");

    if (version) {
        strncpy(ota_info.target_version, version->valuestring, 
                sizeof(ota_info.target_version) - 1);
    }

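    if (cJSON_IsNumber(size)) {
        ota_info.firmware_size = (uint32_t)size->valuedouble;
    }

    if (cJSON_IsString(crc32)) {
        // Hex string such as "cbf43926". Scan into an unsigned int: %x does not
        // match uint32_t on toolchains where uint32_t is unsigned long
        unsigned int crc = 0;
        sscanf(crc32->valuestring, "%x", &crc);
        ota_info.firmware_crc32 = (uint32_t)crc;
    }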

    if (download_url) {
        strncpy(ota_info.download_url, download_url->valuestring,
                sizeof(ota_info.download_url) - 1);
    }

    if (is_diff && is_diff->valueint) {
        ota_info.is_differential = true;
        if (base_ver) {
            strncpy(ota_info.base_version, base_ver->valuestring,
                    sizeof(ota_info.base_version) - 1);
        }
    }

    cJSON_Delete(json);

    printf("Update available: %s\r\n", ota_info.target_version);
    printf("Size: %u bytes\r\n", ota_info.firmware_size);
    printf("Differential: %s\r\n", ota_info.is_differential ? "Yes" : "No");

    ota_info.state = OTA_STATE_IDLE;

    return true;
}

// Start the download
bool OTA_StartDownload(void) {
    if (ota_info.state != OTA_STATE_IDLE) {
        printf("OTA is busy\r\n");
        return false;
    }

    printf("Starting OTA download...\r\n");

    ota_info.state = OTA_STATE_DOWNLOADING;
    ota_info.downloaded_size = 0;
    ota_info.retry_count = 0;

    // Erase the backup partition
    printf("Erasing backup partition...\r\n");
    Flash_EraseSectors(FIRMWARE_BACKUP_ADDRESS, ota_info.firmware_size);

    // Download the firmware (blocking call)
    bool success = OTA_DownloadFirmware(&ota_info);

    if (!success) {
        printf("Download failed\r\n");
        ota_info.state = OTA_STATE_ERROR;
        OTA_ReportStatus("error", "Download failed");
        return false;
    }

    printf("Download complete\r\n");
    ota_info.state = OTA_STATE_VERIFYING;

    return true;
}

// Verify the firmware
bool OTA_VerifyFirmware(void) {
    if (ota_info.state != OTA_STATE_VERIFYING) {
        printf("Invalid state for verification\r\n");
        return false;
    }

    printf("Verifying firmware...\r\n");

    // For a differential package, apply the patch first
    if (ota_info.is_differential) {
        printf("Applying differential patch...\r\n");

        if (!OTA_ApplyDifferentialPatch(ota_info.base_version,
                                       FIRMWARE_BACKUP_ADDRESS,
                                       ota_info.firmware_size)) {
            printf("Failed to apply differential patch\r\n");
            ota_info.state = OTA_STATE_ERROR;
            OTA_ReportStatus("error", "Patch application failed");
            return false;
        }
    }

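    // Verify the image. C has no function overloading, so this cannot call
    // OTA_VerifyFirmware(&ota_info) (that is this function's own name and would not
    // compile). OTA_VerifyImage is a placeholder name for the CRC/signature check in ota_verify.c.
    bool success = OTA_VerifyImage(&ota_info);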

    if (!success) {
        printf("Firmware verification failed\r\n");
        ota_info.state = OTA_STATE_ERROR;
        OTA_ReportStatus("error", "Verification failed");
        return false;
    }

    printf("Firmware verification passed\r\n");
    ota_info.state = OTA_STATE_INSTALLING;

    return true;
}

// Install the firmware
bool OTA_InstallFirmware(void) {
    if (ota_info.state != OTA_STATE_INSTALLING) {
        printf("Invalid state for installation\r\n");
        return false;
    }

    printf("Installing firmware...\r\n");

    // Set the upgrade flag
    BankInfo_t bank_info;
    memset(&bank_info, 0, sizeof(BankInfo_t));

    bank_info.magic = BANK_MAGIC;
    bank_info.state = BANK_STATE_READY;
    bank_info.firmware_size = ota_info.firmware_size;
    bank_info.firmware_crc32 = ota_info.firmware_crc32;

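    // Write the bank info. Flash must be erased before it is programmed; this assumes
    // the bank-info sector was erased beforehand (or that Flash_Write erases it).
    Flash_Write(BANK_INFO_ADDRESS, (uint32_t*)&bank_info, 
                sizeof(BankInfo_t) / 4);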

    printf("Firmware installed, rebooting...\r\n");

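    // Report that the device is rebooting into the new firmware
    OTA_ReportStatus("installing", "Rebooting to new firmware");

    // Give the network stack time to send the status message before the reset
    HAL_Delay(1000);

    // Reset the system
    NVIC_SystemReset();

    return true;  // Not reached
}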

// Report progress
void OTA_ReportProgress(uint8_t progress) {
    OTA_MQTT_PublishProgress(progress);
}

// Report status
void OTA_ReportStatus(const char *status, const char *message) {
    OTA_MQTT_PublishStatus(status, message);
}

// OTA task (call periodically from the main loop)
void OTA_Task(void) {
    // Process MQTT messages
    OTA_MQTT_Task();

    // Periodic update check (unsigned subtraction stays correct when HAL_GetTick() wraps)
    uint32_t current_time = HAL_GetTick();

    if (ota_info.state == OTA_STATE_IDLE) {
        if (current_time - last_check_time > ota_config.check_interval * 1000) {
            // Only refreshes ota_info; the download starts when OTA_StartDownload() is called
            OTA_CheckUpdate();
            last_check_time = current_time;
        }
    }

    // State machine
    switch (ota_info.state) {
        case OTA_STATE_DOWNLOADING:
            // The download runs synchronously inside OTA_StartDownload(); nothing to do here
            break;

        case OTA_STATE_VERIFYING:
            OTA_VerifyFirmware();
            break;

        case OTA_STATE_INSTALLING:
            OTA_InstallFirmware();
            break;

        case OTA_STATE_ERROR:
            // Error handling: return to IDLE so the next periodic check can retry
            printf("OTA error, resetting state\r\n");
            ota_info.state = OTA_STATE_IDLE;
            break;

        default:
            break;
    }
}
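The client above implements a small state machine spread across several functions. The following is a minimal host-side sketch. It assumes you want to unit-test the transition rules on a PC before flashing. It models the same states as a Python enum and rejects any transition the C functions never make. `OtaState`, `TRANSITIONS` and `OtaStateMachine` are illustrative names and are not part of the project.

```python
from enum import Enum, auto


class OtaState(Enum):
    IDLE = auto()
    CHECKING = auto()
    DOWNLOADING = auto()
    VERIFYING = auto()
    INSTALLING = auto()
    ERROR = auto()


# Transitions mirroring the C code:
#   OTA_CheckUpdate:     IDLE -> CHECKING -> IDLE
#   OTA_StartDownload:   IDLE -> DOWNLOADING -> VERIFYING | ERROR
#   OTA_VerifyFirmware:  VERIFYING -> INSTALLING | ERROR
#   OTA_Task:            ERROR -> IDLE
TRANSITIONS = {
    OtaState.IDLE: {OtaState.CHECKING, OtaState.DOWNLOADING},
    OtaState.CHECKING: {OtaState.IDLE},
    OtaState.DOWNLOADING: {OtaState.VERIFYING, OtaState.ERROR},
    OtaState.VERIFYING: {OtaState.INSTALLING, OtaState.ERROR},
    OtaState.INSTALLING: {OtaState.IDLE},  # after the reset, the new image starts in IDLE
    OtaState.ERROR: {OtaState.IDLE},
}


class OtaStateMachine:
    def __init__(self):
        self.state = OtaState.IDLE
        self.history = [self.state]

    def go(self, new_state):
        """Move to new_state, refusing transitions the device code never performs."""
        if new_state not in TRANSITIONS[self.state]:
            raise ValueError(f"illegal transition {self.state.name} -> {new_state.name}")
        self.state = new_state
        self.history.append(new_state)
```

A test can replay the happy path (check, download, verify, install). It can also assert that a jump such as IDLE to VERIFYING raises `ValueError`.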

3. Security Mechanisms

3.1 Firmware Signing Tool

# firmware_signer.py
import os

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding, utils
from cryptography.hazmat.backends import default_backend

class FirmwareSigner:
    """Firmware signing tool (RSA-2048 key, RSA-PSS signature over SHA-256)"""

    def __init__(self, private_key_path=None):
        if private_key_path and os.path.exists(private_key_path):
            # Load the existing private key
            with open(private_key_path, 'rb') as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(),
                    password=None,
                    backend=default_backend()
                )
        else:
            # Generate a new key pair
            self.private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend()
            )
        self.public_key = self.private_key.public_key()

    def save_keys(self, private_key_path, public_key_path):
        """Save the key pair"""
        # Save the private key. It is written unencrypted: keep it offline, or pass
        # serialization.BestAvailableEncryption(password) in production.
        private_pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )

        with open(private_key_path, 'wb') as f:
            f.write(private_pem)

        # Save the public key (this is the key the device uses for verification)
        public_pem = self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

        with open(public_key_path, 'wb') as f:
            f.write(public_pem)

        print("Keys saved:")
        print(f"  Private key: {private_key_path}")
        print(f"  Public key: {public_key_path}")

    def sign_firmware(self, firmware_path, signature_path):
        """Sign a firmware image"""
        # Read the firmware
        with open(firmware_path, 'rb') as f:
            firmware_data = f.read()

        # Compute the SHA-256 digest
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(firmware_data)
        firmware_hash = digest.finalize()

        # Sign the digest. Prehashed tells the library that the input is already a
        # SHA-256 digest. With plain hashes.SHA256() it would hash again and sign
        # SHA256(SHA256(firmware)), which a device checking SHA256(firmware) rejects.
        signature = self.private_key.sign(
            firmware_hash,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            utils.Prehashed(hashes.SHA256())
        )

        # Save the signature
        with open(signature_path, 'wb') as f:
            f.write(signature)

        print("Firmware signed:")
        print(f"  Firmware: {firmware_path}")
        print(f"  Signature: {signature_path}")
        print(f"  Signature size: {len(signature)} bytes")

        return signature

    def verify_signature(self, firmware_path, signature_path):
        """Verify a signature"""
        # Read the firmware and the signature
        with open(firmware_path, 'rb') as f:
            firmware_data = f.read()

        with open(signature_path, 'rb') as f:
            signature = f.read()

        # Compute the SHA-256 digest
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(firmware_data)
        firmware_hash = digest.finalize()

        # Verify the signature over the precomputed digest
        try:
            self.public_key.verify(
                signature,
                firmware_hash,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                utils.Prehashed(hashes.SHA256())
            )
            print("Signature verification: PASSED")
            return True
        except InvalidSignature:
            print("Signature verification: FAILED")
            return False

# Usage example
if __name__ == '__main__':
    # No key path given: a new key pair is generated (and saved below,
    # overwriting any existing key files)
    signer = FirmwareSigner()

    # Save the key pair
    signer.save_keys('private_key.pem', 'public_key.pem')

    # Sign the firmware
    signer.sign_firmware('firmware_v1.0.0.bin', 'firmware_v1.0.0.sig')

    # Verify the signature
    signer.verify_signature('firmware_v1.0.0.bin', 'firmware_v1.0.0.sig')
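An MCU usually cannot load the whole image into RAM. The device side therefore has to hash the flash contents in chunks and then check the signature against that digest. The sketch below confirms on a PC that chunked hashing produces exactly the digest the signing tool computes over the whole file, for any chunk size. The `sha256_streaming` name and the 4 KB default chunk size are illustrative.

```python
import hashlib
import io


def sha256_streaming(stream, chunk_size=4096):
    """SHA-256 of a binary stream, read chunk by chunk like a device reading flash."""
    h = hashlib.sha256()
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            break
        h.update(chunk)
    return h.digest()


if __name__ == "__main__":
    data = bytes(range(256)) * 100
    # An odd chunk size still yields the one-shot digest
    assert sha256_streaming(io.BytesIO(data), chunk_size=7) == hashlib.sha256(data).digest()
    print("chunked digest matches one-shot digest")
```

On the device, the same pattern maps onto mbedTLS's incremental SHA-256 API: start, update once per flash chunk, then finish. The signature check then runs on the 32-byte result.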

3.2 Secure Communication (TLS/SSL)

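// ota_secure_comm.c
// Assumes mbedTLS is built with MBEDTLS_NET_C (or an equivalent socket port such as
// LwIP) and with a working entropy source on the MCU.
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "mbedtls/ssl.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/x509_crt.h"

typedef struct {
    mbedtls_net_context server_fd;
    mbedtls_ssl_context ssl;
    mbedtls_ssl_config conf;
    mbedtls_entropy_context entropy;
    mbedtls_ctr_drbg_context ctr_drbg;
    mbedtls_x509_crt cacert;
} SecureConnection_t;

// Free all mbedTLS objects; safe on objects that were only initialized
static void SecureConnection_FreeAll(SecureConnection_t *conn) {
    mbedtls_net_free(&conn->server_fd);
    mbedtls_x509_crt_free(&conn->cacert);
    mbedtls_ssl_free(&conn->ssl);
    mbedtls_ssl_config_free(&conn->conf);
    mbedtls_ctr_drbg_free(&conn->ctr_drbg);
    mbedtls_entropy_free(&conn->entropy);
}

// Open a secure connection; ca_cert is a NUL-terminated PEM string
bool SecureConnection_Init(SecureConnection_t *conn, 
                          const char *server_addr,
                          const char *server_port,
                          const char *ca_cert) {
    int ret;
    uint32_t flags;

    // Initialize every context first so the cleanup path is always safe
    mbedtls_net_init(&conn->server_fd);
    mbedtls_ssl_init(&conn->ssl);
    mbedtls_ssl_config_init(&conn->conf);
    mbedtls_entropy_init(&conn->entropy);
    mbedtls_ctr_drbg_init(&conn->ctr_drbg);
    mbedtls_x509_crt_init(&conn->cacert);

    // Seed the random number generator
    ret = mbedtls_ctr_drbg_seed(&conn->ctr_drbg, mbedtls_entropy_func,
                                &conn->entropy, NULL, 0);
    if (ret != 0) {
        printf("mbedtls_ctr_drbg_seed failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    // Load the CA certificate (for PEM input the length must include the terminating NUL)
    ret = mbedtls_x509_crt_parse(&conn->cacert, 
                                (const unsigned char*)ca_cert,
                                strlen(ca_cert) + 1);
    if (ret != 0) {
        printf("mbedtls_x509_crt_parse failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    // Connect to the server over TCP
    ret = mbedtls_net_connect(&conn->server_fd, server_addr, server_port,
                             MBEDTLS_NET_PROTO_TCP);
    if (ret != 0) {
        printf("mbedtls_net_connect failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    // Configure TLS as a client
    ret = mbedtls_ssl_config_defaults(&conn->conf,
                                     MBEDTLS_SSL_IS_CLIENT,
                                     MBEDTLS_SSL_TRANSPORT_STREAM,
                                     MBEDTLS_SSL_PRESET_DEFAULT);
    if (ret != 0) {
        printf("mbedtls_ssl_config_defaults failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    // Require a server certificate that chains to our CA
    mbedtls_ssl_conf_authmode(&conn->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
    mbedtls_ssl_conf_ca_chain(&conn->conf, &conn->cacert, NULL);
    mbedtls_ssl_conf_rng(&conn->conf, mbedtls_ctr_drbg_random, &conn->ctr_drbg);

    ret = mbedtls_ssl_setup(&conn->ssl, &conn->conf);
    if (ret != 0) {
        printf("mbedtls_ssl_setup failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    // Expected server name (also sent as SNI). Without it, older mbedTLS versions
    // skip the certificate hostname check entirely.
    ret = mbedtls_ssl_set_hostname(&conn->ssl, server_addr);
    if (ret != 0) {
        printf("mbedtls_ssl_set_hostname failed: -0x%04x\r\n", -ret);
        goto fail;
    }

    mbedtls_ssl_set_bio(&conn->ssl, &conn->server_fd,
                       mbedtls_net_send, mbedtls_net_recv, NULL);

    // TLS handshake (repeat while the transport reports WANT_READ/WANT_WRITE)
    while ((ret = mbedtls_ssl_handshake(&conn->ssl)) != 0) {
        if (ret != MBEDTLS_ERR_SSL_WANT_READ && 
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            printf("mbedtls_ssl_handshake failed: -0x%04x\r\n", -ret);
            goto fail;
        }
    }

    // Check the certificate verification result (with VERIFY_REQUIRED a failure
    // normally aborts the handshake already; this prints the reason)
    flags = mbedtls_ssl_get_verify_result(&conn->ssl);
    if (flags != 0) {
        char vrfy_buf[512];
        mbedtls_x509_crt_verify_info(vrfy_buf, sizeof(vrfy_buf), 
                                     "  ! ", flags);
        printf("Certificate verification failed:\r\n%s\r\n", vrfy_buf);
        goto fail;
    }

    printf("SSL connection established\r\n");
    return true;

fail:
    SecureConnection_FreeAll(conn);
    return false;
}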

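// Send data securely. mbedtls_ssl_write() may write fewer than len bytes; the return
// value is the number of bytes written, so callers must loop until everything is sent.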
int SecureConnection_Send(SecureConnection_t *conn, 
                         const uint8_t *data, size_t len) {
    int ret;

    while ((ret = mbedtls_ssl_write(&conn->ssl, data, len)) <= 0) {
        if (ret != MBEDTLS_ERR_SSL_WANT_READ && 
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            printf("mbedtls_ssl_write failed: -0x%04x\r\n", -ret);
            return -1;
        }
    }

    return ret;
}

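// Receive data securely: returns the number of bytes read, 0 if no data is ready yet
// (or the peer closed the connection), -1 on error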
int SecureConnection_Receive(SecureConnection_t *conn, 
                            uint8_t *buffer, size_t len) {
    int ret;

    ret = mbedtls_ssl_read(&conn->ssl, buffer, len);

    if (ret == MBEDTLS_ERR_SSL_WANT_READ || 
        ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
        return 0;
    }

    if (ret < 0) {
        printf("mbedtls_ssl_read failed: -0x%04x\r\n", -ret);
        return -1;
    }

    return ret;
}

// Close the secure connection
void SecureConnection_Close(SecureConnection_t *conn) {
    mbedtls_ssl_close_notify(&conn->ssl);
    mbedtls_net_free(&conn->server_fd);
    mbedtls_x509_crt_free(&conn->cacert);
    mbedtls_ssl_free(&conn->ssl);
    mbedtls_ssl_config_free(&conn->conf);
    mbedtls_ctr_drbg_free(&conn->ctr_drbg);
    mbedtls_entropy_free(&conn->entropy);
}
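Host-side tools, such as a script that probes the OTA server before any devices are pointed at it, can enforce the same policy as the mbedTLS client. Python's standard `ssl` module can express it: certificate verification required, the hostname checked against the certificate, and a TLS 1.2 floor. The TLS 1.2 minimum is an assumption; the original code does not set it. `make_ota_tls_context` is an illustrative name.

```python
import ssl


def make_ota_tls_context(cafile=None):
    """TLS client context: verified server certificate, hostname check, TLS >= 1.2.

    cafile: path to the CA certificate that signed the OTA server's certificate;
    None falls back to the system trust store.
    """
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=cafile)
    # create_default_context() already sets these two; stated explicitly for clarity
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.check_hostname = True
    # Assumed policy: refuse protocol versions older than TLS 1.2
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    return ctx
```

To use the context, wrap a connected socket with `ctx.wrap_socket(sock, server_hostname="ota.example.com")`. Passing `server_hostname` is what activates the hostname check, which matches the role of the expected server name on the device.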

4. Gradual Rollout

4.1 Gradual Rollout Strategy

# gradual_rollout.py
import hashlib
import random
from datetime import datetime

class GradualRollout:
    """Gradual rollout manager"""

    def __init__(self, task_id, total_devices, strategy='percentage', whitelist=None):
        self.task_id = task_id
        self.total_devices = total_devices
        self.strategy = strategy
        self.whitelist = set(whitelist or [])  # used by the 'whitelist' strategy
        self.current_percentage = 0
        self.deployed_devices = set()

    def _bucket(self, device_id):
        # Stable bucket 0-99. The built-in hash() of a str is salted per process,
        # so it would reshuffle which devices are in the rollout after every restart.
        key = f"{self.task_id}:{device_id}".encode('utf-8')
        return int.from_bytes(hashlib.sha256(key).digest()[:8], 'big') % 100

    def should_deploy(self, device_id):
        """Decide whether a device should receive the update"""
        if device_id in self.deployed_devices:
            return True

        if self.strategy == 'percentage':
            # Percentage-based rollout
            if self._bucket(device_id) < self.current_percentage:
                self.deployed_devices.add(device_id)
                return True

        elif self.strategy == 'whitelist':
            # Whitelist-based rollout
            if device_id in self.whitelist:
                self.deployed_devices.add(device_id)
                return True

        elif self.strategy == 'random':
            # Random rollout. Every call rolls again, so a device that is asked
            # repeatedly is eventually selected; use 'percentage' for a stable cohort.
            if random.random() < (self.current_percentage / 100.0):
                self.deployed_devices.add(device_id)
                return True

        return False

    def increase_percentage(self, increment):
        """Raise the rollout percentage"""
        self.current_percentage = min(100, self.current_percentage + increment)
        print(f"Gradual rollout percentage increased to {self.current_percentage}%")

    def get_statistics(self):
        """Return rollout statistics"""
        return {
            'task_id': self.task_id,
            'total_devices': self.total_devices,
            'deployed_devices': len(self.deployed_devices),
            'current_percentage': self.current_percentage,
            'deployment_rate': (len(self.deployed_devices) / self.total_devices * 100
                                if self.total_devices else 0.0)
        }

# Rollout scheduler
class RolloutScheduler:
    """Gradual rollout scheduler"""

    def __init__(self):
        self.active_rollouts = {}

    def create_rollout(self, task_id, device_list, schedule):
        """Create a gradual rollout task"""
        rollout = GradualRollout(task_id, len(device_list))

        # Rollout plan: (hours since start, percentage), e.g. [(1, 10), (24, 50), (48, 100)]
        rollout.schedule = schedule
        rollout.start_time = datetime.now()
        rollout.device_list = device_list

        self.active_rollouts[task_id] = rollout

        print(f"Created gradual rollout for task {task_id}")
        print(f"Total devices: {len(device_list)}")
        print(f"Schedule: {schedule}")

    def update_rollouts(self):
        """Advance every active rollout according to its schedule"""
        current_time = datetime.now()

        for task_id, rollout in self.active_rollouts.items():
            elapsed_hours = (current_time - rollout.start_time).total_seconds() / 3600

            # Highest percentage whose start time has been reached
            target = max((pct for hours, pct in rollout.schedule if elapsed_hours >= hours),
                         default=0)
            if target > rollout.current_percentage:
                rollout.current_percentage = target
                print(f"Task {task_id}: Updated to {target}%")

                # Notify the devices that have just entered the rollout
                self.notify_devices(rollout)

    def notify_devices(self, rollout):
        """Send the upgrade notification to newly selected devices"""
        for device_id in rollout.device_list:
            if device_id in rollout.deployed_devices:
                continue  # already notified in an earlier step
            if rollout.should_deploy(device_id):
                # send_upgrade_notification() is assumed to be defined elsewhere
                # (for example, publishing an upgrade command over MQTT)
                send_upgrade_notification(device_id, rollout.task_id)
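`hash(device_id) % 100` is not a stable basis for percentage rollouts. Python salts `str` hashes per process unless `PYTHONHASHSEED` is fixed, so the devices that fall into the first 10% change whenever the server restarts. The sketch below assumes SHA-256 over `task_id:device_id` as the bucketing key. It gives each device a fixed bucket per task and derives the current percentage from a schedule like `[(1, 10), (24, 50), (48, 100)]`. The function names are illustrative.

```python
import hashlib


def rollout_bucket(task_id, device_id):
    """Stable bucket in [0, 100) for one device within one rollout task."""
    key = f"{task_id}:{device_id}".encode("utf-8")
    digest = hashlib.sha256(key).digest()
    # First 8 bytes as an unsigned integer; the modulo bias over 100 buckets is negligible
    return int.from_bytes(digest[:8], "big") % 100


def target_percentage(schedule, elapsed_hours):
    """Highest percentage whose start time has been reached; schedule = [(hours, pct), ...]."""
    reached = [pct for hours, pct in schedule if elapsed_hours >= hours]
    return max(reached, default=0)


def selected_devices(task_id, device_ids, percentage):
    """Devices whose bucket falls below the current percentage."""
    return [d for d in device_ids if rollout_bucket(task_id, d) < percentage]
```

A device's bucket never changes, so every device selected at 10% is still selected at 50%. Raising the percentage therefore only adds devices. Mixing `task_id` into the key keeps different tasks from always picking the same early adopters.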

Practical Examples

Example 1: End-to-End OTA Upgrade Flow

# ota_demo.py
"""End-to-end OTA upgrade demo (cloud-side steps 1-3 in Python)"""
import requests

# FirmwareSigner comes from firmware_signer.py and RolloutScheduler from
# gradual_rollout.py; DifferentialPackageGenerator and get_all_devices() are
# assumed to be provided by the project's other modules.

# 1. The developer uploads new firmware
def upload_new_firmware():
    signer = FirmwareSigner('private_key.pem')

    # Sign the firmware
    signer.sign_firmware('firmware_v2.0.0.bin', 'firmware_v2.0.0.sig')

    # Read the image and its signature
    with open('firmware_v2.0.0.bin', 'rb') as f:
        firmware_data = f.read()

    with open('firmware_v2.0.0.sig', 'rb') as f:
        signature_data = f.read()

    # Upload through the API (multipart parts given as (filename, bytes))
    response = requests.post(
        'http://ota-server.com/api/firmware/upload',
        files={
            'file': ('firmware_v2.0.0.bin', firmware_data),
            'signature': ('firmware_v2.0.0.sig', signature_data)
        },
        data={
            'version': '2.0.0',
            'hardware_version': 'STM32F407',
            'release_notes': 'Bug fixes and new features'
        },
        timeout=60
    )
    response.raise_for_status()

    print(f"Firmware uploaded: {response.json()}")

# 2. Generate the differential package
def generate_differential_package():
    generator = DifferentialPackageGenerator()

    result = generator.generate_diff(
        'firmware_v1.0.0.bin',
        'firmware_v2.0.0.bin',
        'diff_v1.0.0_to_v2.0.0.patch'
    )

    print("Differential package generated")
    print(f"Size reduction: {result['compression_ratio']:.2f}%")

# 3. Create a gradual rollout task
def create_gradual_rollout_task():
    # Fetch all devices
    devices = get_all_devices()

    # Create the rollout
    scheduler = RolloutScheduler()
    scheduler.create_rollout(
        task_id=123,
        device_list=devices,
        schedule=[
            (1, 10),    # 10% after 1 hour
            (24, 50),   # 50% after 24 hours
            (48, 100)   # all devices after 48 hours
        ]
    )
    # Call scheduler.update_rollouts() periodically to advance the schedule
    return scheduler

# 4. Device-side upgrade. This step runs on the MCU in C, not in Python
#    (same pattern as main.c in the deployment guide):
#
#    OTA_Config_t config = {
#        .device_id        = "DEVICE_001",
#        .current_version  = "1.0.0",
#        .hardware_version = "STM32F407",
#        .server_url       = "http://ota-server.com",
#        .mqtt_broker      = "mqtt.ota-server.com",
#        .mqtt_port        = 1883,
#        .check_interval   = 3600,   // check once per hour
#    };
#    OTA_Init(&config);
#
#    while (1) {                    // main loop
#        OTA_Task();
#        HAL_Delay(100);
#    }

Example 2: Error Handling and Recovery

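// ota_error_handling.c
/* OTA error-handling examples */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Retry a failed download with exponential backoff
bool OTA_DownloadWithRetry(OTA_Info_t *info, uint8_t max_retry) {
    for (uint8_t attempt = 1; attempt <= max_retry; attempt++) {
        printf("Download attempt %d/%d\r\n", attempt, max_retry);

        if (OTA_DownloadFirmware(info)) {
            return true;
        }

        // Do not wait after the final attempt
        if (attempt == max_retry) {
            break;
        }

        // Exponential backoff: 2 s, 4 s, 8 s, ... (shift capped at 6, i.e. 64 s,
        // which also keeps 1u << attempt well defined)
        uint32_t delay = 1000u * (1u << (attempt < 6 ? attempt : 6));
        printf("Retrying in %lu ms\r\n", (unsigned long)delay);
        HAL_Delay(delay);
    }

    return false;
}

// Roll back after a failed upgrade
void OTA_RollbackOnFailure(void) {
    printf("Upgrade failed, rolling back...\r\n");

    // Flash cannot be modified through a plain pointer write: copy the bank info to
    // RAM, change it, then erase and reprogram. This assumes each BankInfo_t sits in
    // its own erasable sector; otherwise the erase would destroy neighbouring data.
    BankInfo_t bank;

    // 1. Mark the new firmware (bank B) as invalid
    memcpy(&bank, (const void *)BANK_B_INFO_ADDRESS, sizeof(bank));
    bank.state = BANK_STATE_INVALID;
    Flash_EraseSectors(BANK_B_INFO_ADDRESS, sizeof(bank));
    Flash_Write(BANK_B_INFO_ADDRESS, (uint32_t *)&bank, sizeof(bank) / 4);

    // 2. Make sure the old firmware (bank A) is marked active
    memcpy(&bank, (const void *)BANK_A_INFO_ADDRESS, sizeof(bank));
    bank.state = BANK_STATE_ACTIVE;
    Flash_EraseSectors(BANK_A_INFO_ADDRESS, sizeof(bank));
    Flash_Write(BANK_A_INFO_ADDRESS, (uint32_t *)&bank, sizeof(bank) / 4);

    // 3. Reboot into the old firmware
    NVIC_SystemReset();
}

// Watchdog protection
void OTA_WatchdogProtection(void) {
    // Start the independent watchdog with a 10 s timeout. IWDG_Init/IWDG_Feed are
    // assumed project wrappers around the STM32 IWDG; once started, the IWDG cannot
    // be stopped by software, so the application must keep feeding it afterwards.
    IWDG_Init(10000);

    // Feed the watchdog while the upgrade runs. ota_in_progress must be volatile and
    // the upgrade must run in another RTOS task or interrupt. In a bare-metal
    // super-loop, call IWDG_Feed() from the download/flash-write loop instead, so
    // that a stalled download still triggers a reset.
    while (ota_in_progress) {
        IWDG_Feed();
        HAL_Delay(1000);
    }
}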
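You can tabulate the retry delays before committing them to firmware. The sketch below reproduces the 2 s, 4 s, 8 s progression of `OTA_DownloadWithRetry` and skips any wait after the final attempt. It also caps the delay so that a large `max_retry` cannot stall the device for minutes. The `cap_ms` ceiling is an assumption and does not appear in the original; `backoff_delays_ms` is an illustrative name.

```python
def backoff_delays_ms(max_retry, base_ms=1000, cap_ms=60000):
    """Delays (ms) waited between attempts; there is no wait after the last attempt."""
    delays = []
    for attempt in range(1, max_retry):
        # attempt 1 -> 2 s, attempt 2 -> 4 s, ... limited to cap_ms
        delays.append(min(cap_ms, base_ms * (1 << attempt)))
    return delays
```

For example, `backoff_delays_ms(3)` gives two waits, 2000 ms and 4000 ms, for three attempts. Large fleets often also add random jitter to each delay. This keeps devices that lost connectivity at the same moment from retrying in lockstep.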

Deployment Guide

1. Cloud Platform Deployment

1.1 Deploying with Docker

# docker-compose.yml
version: '3.8'

services:
  # PostgreSQL database (demo credentials; change them for production)
  postgres:
    image: postgres:13
    environment:
      POSTGRES_DB: ota_db
      POSTGRES_USER: ota_user
      POSTGRES_PASSWORD: ota_password
    volumes:
      - postgres_data:/var/lib/postgresql/data
    ports:
      - "5432:5432"

  # MQTT Broker
  mosquitto:
    image: eclipse-mosquitto:2
    volumes:
      - ./mosquitto.conf:/mosquitto/config/mosquitto.conf
      - mosquitto_data:/mosquitto/data
      - mosquitto_log:/mosquitto/log
    ports:
      - "1883:1883"
      - "9001:9001"

  # MinIO object storage
  minio:
    image: minio/minio
    command: server /data --console-address ":9001"
    environment:
      MINIO_ROOT_USER: minioadmin
      MINIO_ROOT_PASSWORD: minioadmin
    volumes:
      - minio_data:/data
    ports:
      - "9000:9000"
      # Host port 9001 is already published by mosquitto, so expose the MinIO console on 9002
      - "9002:9001"

  # OTA API service
  ota-api:
    build: ./api
    environment:
      DATABASE_URL: postgresql://ota_user:ota_password@postgres:5432/ota_db
      MQTT_BROKER: mosquitto
      MINIO_ENDPOINT: minio:9000
    depends_on:
      - postgres
      - mosquitto
      - minio
    ports:
      - "5000:5000"

  # Web management UI
  ota-web:
    build: ./web
    environment:
      API_URL: http://ota-api:5000
    ports:
      - "8080:80"
    depends_on:
      - ota-api

volumes:
  postgres_data:
  mosquitto_data:
  mosquitto_log:
  minio_data:

1.2 Starting the Services

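# Build the images and start all services (--build rebuilds ota-api and ota-web)
docker-compose up -d --build

# Check service status
docker-compose ps

# Follow the API logs
docker-compose logs -f ota-api

# Stop the services
docker-compose down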
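`depends_on` only controls start order. It does not wait until PostgreSQL, Mosquitto or MinIO actually accept connections. Before running the tests, a small helper can poll the published ports. The following is a standard-library sketch with an illustrative name. An open TCP port shows only that the process is listening, not that the service has finished initializing.

```python
import socket
import time


def wait_for_port(host, port, timeout_s=60.0, interval_s=0.5):
    """Return True once host:port accepts a TCP connection, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval_s):
                return True
        except OSError:
            # Refused or timed out: the service is not listening yet
            time.sleep(interval_s)
    return False
```

For the compose file above, you would poll `localhost` on 5000 (API), 1883 (MQTT), 9000 (MinIO) and 5432 (PostgreSQL).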

2. Device-Side Integration

2.1 Adding the OTA Library to the Project

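# Makefile
# OTA sources
OTA_SOURCES = \
    ota/ota_client.c \
    ota/ota_mqtt.c \
    ota/ota_download.c \
    ota/ota_verify.c \
    ota/ota_differential.c

# Object files (add $(OTA_OBJECTS) to the firmware's link rule)
OTA_OBJECTS = $(OTA_SOURCES:%.c=$(BUILD_DIR)/%.o)

# Include path
INCLUDES += -Iota/include

# Link mbedTLS (static-library order matters: mbedtls -> mbedx509 -> mbedcrypto)
LIBS += -lmbedtls -lmbedx509 -lmbedcrypto

# Compile the OTA library (recipe lines must start with a tab, not spaces)
$(BUILD_DIR)/ota/%.o: ota/%.c
	@mkdir -p $(dir $@)
	$(CC) -c $(CFLAGS) $(INCLUDES) $< -o $@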

2.2 Application Integration

// main.c
#include <stdio.h>
#include "ota_client.h"

int main(void) {
    // System initialization
    HAL_Init();
    SystemClock_Config();

    // Peripheral initialization
    UART_Init();
    Network_Init();

    // OTA configuration
    OTA_Config_t ota_config = {
        .device_id = "DEVICE_001",
        .current_version = "1.0.0",
        .hardware_version = "STM32F407",
        .server_url = "https://ota.example.com",
        .mqtt_broker = "mqtt.example.com",
        .mqtt_port = 1883,          // plain MQTT; MQTT over TLS conventionally uses 8883
        .check_interval = 3600      // seconds between update checks
    };

    // Initialize the OTA client
    OTA_Init(&ota_config);

    printf("Application started\r\n");
    printf("Version: %s\r\n", ota_config.current_version);

    // Main loop
    while (1) {
        // Application logic
        Application_Task();

        // OTA task
        OTA_Task();

        HAL_Delay(100);
    }
}

Testing and Verification

1. Unit Tests

# test_ota_api.py
import unittest
import requests

class TestOTAAPI(unittest.TestCase):
    """OTA API tests"""

    def setUp(self):
        self.base_url = "http://localhost:5000/api"

    def test_firmware_upload(self):
        """Test firmware upload"""
        with open('test_firmware.bin', 'rb') as f:
            response = requests.post(
                f"{self.base_url}/firmware/upload",
                files={'file': f},
                data={
                    'version': '1.0.0',
                    'hardware_version': 'TEST'
                },
                timeout=30
            )

        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.json()['success'])

    def test_check_update(self):
        """Test the update check"""
        response = requests.get(
            f"{self.base_url}/firmware/check",
            params={
                'device_id': 'TEST_001',
                'version': '1.0.0',
                'hardware': 'TEST'
            },
            timeout=30
        )

        self.assertEqual(response.status_code, 200)

    def test_download_firmware(self):
        """Test a ranged firmware download (assumes firmware ID 1 exists)"""
        response = requests.get(
            f"{self.base_url}/firmware/download/1",
            headers={'Range': 'bytes=0-1023'},
            timeout=30
        )

        self.assertEqual(response.status_code, 206)  # Partial Content

if __name__ == '__main__':
    unittest.main()
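`test_download_firmware` only checks for status 206. A stricter test can also inspect the `Content-Range` header. A satisfied range response carries it in the form `bytes first-last/complete-length` (RFC 7233). A small standalone parser makes that check easy to call from the test. It is a sketch and handles only satisfied ranges.

```python
import re

_CONTENT_RANGE = re.compile(r"^bytes (\d+)-(\d+)/(\d+|\*)$")


def parse_content_range(value):
    """Parse 'bytes first-last/total' into (first, last, total); total is None for '*'."""
    m = _CONTENT_RANGE.match(value.strip())
    if not m:
        raise ValueError(f"unsupported Content-Range: {value!r}")
    first, last = int(m.group(1)), int(m.group(2))
    total = None if m.group(3) == "*" else int(m.group(3))
    # last is inclusive and must lie inside the resource
    if last < first or (total is not None and last >= total):
        raise ValueError(f"inconsistent Content-Range: {value!r}")
    return first, last, total
```

Inside the test, call `first, last, total = parse_content_range(response.headers['Content-Range'])`. Then assert `(first, last) == (0, 1023)` and `len(response.content) == 1024`.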

2. Integration Tests

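// test_ota_integration.c
/* OTA integration tests */
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include "ota_client.h"

void Test_OTA_FullUpgrade(void) {
    printf("=== OTA Full Upgrade Test ===\r\n");

    // 1. Initialize
    OTA_Config_t config = {
        .device_id = "TEST_001",
        .current_version = "1.0.0",
        .hardware_version = "TEST",
        .server_url = "http://localhost:5000",
        .mqtt_broker = "localhost",
        .mqtt_port = 1883,
        .check_interval = 60
    };

    OTA_Init(&config);

    // Keep each call outside assert(): with NDEBUG defined, assert() expands to
    // nothing and a call written inside it would never run
    bool ok;

    // 2. Check for updates
    printf("Checking for updates...\r\n");
    ok = OTA_CheckUpdate();
    assert(ok);

    // 3. Download the firmware
    printf("Downloading firmware...\r\n");
    ok = OTA_StartDownload();
    assert(ok);

    // 4. Verify the firmware
    printf("Verifying firmware...\r\n");
    ok = OTA_VerifyFirmware();
    assert(ok);

    // 5. Install the firmware: skipped, because OTA_InstallFirmware() resets the system
    printf("Skipping installation (it would reboot the device)\r\n");
    // OTA_InstallFirmware();

    (void)ok;  // silence the unused-variable warning when NDEBUG is defined
    printf("=== Test Passed ===\r\n");
}

void Test_OTA_DifferentialUpgrade(void) {
    printf("=== OTA Differential Upgrade Test ===\r\n");

    // Differential upgrade test (not implemented yet)
    // ...

    printf("=== Test Skipped (not implemented) ===\r\n");
}

void Test_OTA_ResumeDownload(void) {
    printf("=== OTA Resume Download Test ===\r\n");

    // Resumable download test (not implemented yet)
    // ...

    printf("=== Test Skipped (not implemented) ===\r\n");
}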

3. Performance Tests

# test_ota_performance.py
"""OTA performance tests

check_update(), download_firmware(), verify_firmware(), install_firmware() and
DifferentialPackageGenerator stand for the project's own client and tools.
"""

import time
import statistics

def test_download_speed():
    """Measure download throughput"""
    firmware_size = 256 * 1024  # 256 KB; must match the size of test_firmware.bin

    speeds = []

    for i in range(10):
        # perf_counter() is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()

        # Download the firmware
        download_firmware('test_firmware.bin')

        elapsed = time.perf_counter() - start_time
        speed = firmware_size / elapsed / 1024  # KB/s

        speeds.append(speed)
        print(f"Test {i+1}: {speed:.2f} KB/s")

    print(f"\nAverage speed: {statistics.mean(speeds):.2f} KB/s")
    print(f"Min speed: {min(speeds):.2f} KB/s")
    print(f"Max speed: {max(speeds):.2f} KB/s")

def test_upgrade_time():
    """Measure the total upgrade time"""
    start_time = time.perf_counter()

    # Run the full upgrade flow
    check_update()
    download_firmware('test_firmware.bin')
    verify_firmware()
    install_firmware()

    total_time = time.perf_counter() - start_time

    print(f"Total upgrade time: {total_time:.2f} seconds")

def test_differential_compression():
    """Measure the differential compression ratio"""
    old_firmware = 'firmware_v1.0.0.bin'
    new_firmware = 'firmware_v1.1.0.bin'

    generator = DifferentialPackageGenerator()
    result = generator.generate_diff(old_firmware, new_firmware, 'diff.patch')

    print(f"Old firmware: {result['old_size']} bytes")
    print(f"New firmware: {result['new_size']} bytes")
    print(f"Differential package: {result['diff_size']} bytes")
    print(f"Compression ratio: {result['compression_ratio']:.2f}%")

Monitoring and Operations

1. Monitoring Metrics

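# ota_monitoring.py
"""OTA system monitoring

Device and UpgradeRecord are assumed to be the platform's Flask-SQLAlchemy models;
the get_*/calculate_* helpers, send_alert() and log_error() are assumed to exist elsewhere.
"""

class OTAMonitoring:
    """OTA monitoring service"""

    def get_system_metrics(self):
        """Collect system-wide metrics"""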
        return {
            'total_devices': self.get_total_devices(),
            'online_devices': self.get_online_devices(),
            'upgrading_devices': self.get_upgrading_devices(),
            'failed_upgrades': self.get_failed_upgrades(),
            'success_rate': self.calculate_success_rate(),
            'average_upgrade_time': self.get_average_upgrade_time()
        }

    def get_device_status(self, device_id):
        """Get the status of one device"""
        device = Device.query.filter_by(device_id=device_id).first()

        if not device:
            return None

        return {
            'device_id': device.device_id,
            'current_version': device.current_version,
            'status': device.status,
            'last_online': device.last_online,
            'upgrade_history': self.get_upgrade_history(device_id)
        }

    def get_upgrade_statistics(self, task_id):
        """Get statistics for an upgrade task"""
        records = UpgradeRecord.query.filter_by(task_id=task_id).all()

        total = len(records)
        success = len([r for r in records if r.status == 'success'])
        failed = len([r for r in records if r.status == 'failed'])
        in_progress = len([r for r in records if r.status in ['downloading', 'installing']])

        return {
            'task_id': task_id,
            'total_devices': total,
            'success': success,
            'failed': failed,
            'in_progress': in_progress,
            'success_rate': success / total * 100 if total > 0 else 0
        }

    def alert_on_failure(self, device_id, error_message):
        """Raise an alert for a failed upgrade"""
        # Send an alert notification
        send_alert(
            title=f"OTA Upgrade Failed: {device_id}",
            message=error_message,
            severity='high'
        )

        # Write to the log
        log_error(f"Device {device_id} upgrade failed: {error_message}")
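`get_upgrade_statistics` divides successes by every record in the task. While a rollout is still running, the success rate therefore looks low even if nothing has failed. A pure function over the status strings can report two rates: one over all devices and one over devices that have reached a final state. The second is often the more useful signal for deciding whether to pause a rollout. The sketch below works without the database; `summarize_upgrade` is an illustrative name.

```python
from collections import Counter

IN_PROGRESS_STATUSES = {"downloading", "installing"}


def summarize_upgrade(statuses):
    """Summarize the per-device upgrade statuses of one task."""
    counts = Counter(statuses)
    total = sum(counts.values())
    success = counts["success"]
    failed = counts["failed"]
    in_progress = sum(counts[s] for s in IN_PROGRESS_STATUSES)
    finished = success + failed
    return {
        "total_devices": total,
        "success": success,
        "failed": failed,
        "in_progress": in_progress,
        # Share of all devices in the task (what get_upgrade_statistics reports)
        "success_rate": success / total * 100 if total else 0.0,
        # Share of devices that have reached a final state
        "completed_success_rate": success / finished * 100 if finished else 0.0,
    }
```

With the platform models, the input would be `[r.status for r in UpgradeRecord.query.filter_by(task_id=task_id)]`.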

2. Log Management

# ota_logging.py
"""OTA log management"""

import logging
import os
from logging.handlers import RotatingFileHandler

def setup_logging(log_dir='logs'):
    """Configure the logging system"""
    # Create (or fetch) the logger
    logger = logging.getLogger('ota')
    if logger.handlers:
        return logger  # already configured: do not attach duplicate handlers

    logger.setLevel(logging.DEBUG)

    # RotatingFileHandler does not create missing directories
    os.makedirs(log_dir, exist_ok=True)

    # File handler (rotates automatically)
    file_handler = RotatingFileHandler(
        os.path.join(log_dir, 'ota.log'),
        maxBytes=10*1024*1024,  # 10MB
        backupCount=10
    )
    file_handler.setLevel(logging.DEBUG)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    # Formatting
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Attach the handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger

# Using the logger
logger = setup_logging()

logger.info("OTA system started")
logger.debug("Device connected: DEVICE_001")
logger.warning("Download retry: attempt 2/3")
logger.error("Firmware verification failed")
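Logs from thousands of devices are easier to filter when every line carries the device ID. The standard library's `logging.LoggerAdapter` can add the ID without touching each call site. `DeviceLogAdapter` is an illustrative subclass and is not part of the project.

```python
import logging


class DeviceLogAdapter(logging.LoggerAdapter):
    """Prefix every message with the device ID passed in `extra`."""

    def process(self, msg, kwargs):
        return f"[{self.extra['device_id']}] {msg}", kwargs
```

Usage: `log = DeviceLogAdapter(setup_logging(), {"device_id": "DEVICE_001"})`. A call such as `log.warning("Download retry: attempt 2/3")` then records the message as `[DEVICE_001] Download retry: attempt 2/3`.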

FAQ

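Q1: How do I handle download failures caused by an unstable network?

A: Implement resumable downloads and retries:

  1. Resumable downloads: use HTTP Range requests to continue from where the previous download stopped
  2. Retry strategy: retry with exponential backoff instead of retrying in a tight loop
  3. Chunked downloads: split large files into small chunks so that a single failure costs only one chunk
  4. Integrity checks: verify every chunk to make sure the data is intact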
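Points 1 and 3 combine naturally. The device persists how many bytes are already in flash, and after a reconnect it requests only the missing chunks. The sketch below covers that bookkeeping and generates the `Range` header for each remaining chunk. It assumes a 4 KB chunk size; a multiple of the flash write unit is a reasonable choice, but the value is illustrative. The end offset in `bytes=first-last` is inclusive.

```python
def plan_ranges(total_size, downloaded, chunk_size=4096):
    """Yield HTTP Range header values for the bytes that are still missing."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= downloaded <= total_size:
        raise ValueError("downloaded offset out of range")
    offset = downloaded
    while offset < total_size:
        last = min(offset + chunk_size, total_size) - 1  # Range end offset is inclusive
        yield f"bytes={offset}-{last}"
        offset = last + 1
```

For a 10000-byte image with 8192 bytes already written, only `bytes=8192-9999` remains to be fetched.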

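Q2: How much bandwidth does a differential upgrade save?

A: It depends on how much the firmware has changed. Typical ranges (actual results vary with the toolchain, the linker layout and the diff algorithm):

  • Small changes (bug fixes): roughly 80-95% saved
  • Medium changes (feature updates): roughly 50-80% saved
  • Large changes (refactoring): roughly 20-50% saved

Recommendations:
  • Use differential upgrades for frequent small updates
  • Use full images for major releases
  • Offer both and let the device choose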
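The "offer both" recommendation needs a server-side rule for which package to send. One possible rule is sketched below with an assumed 20% minimum saving. It sends a patch only when the patch was built against the exact version the device runs and is meaningfully smaller than the full image. Otherwise it falls back to the full firmware. `choose_package` and the threshold are illustrative.

```python
def choose_package(device_version, full_size, diff=None, min_saving=0.2):
    """Return 'diff' or 'full'. diff is (base_version, diff_size) or None."""
    if diff is None or full_size <= 0:
        return "full"
    base_version, diff_size = diff
    if device_version != base_version:
        # A patch only applies to the exact image it was generated from
        return "full"
    saving = 1 - diff_size / full_size
    return "diff" if saving >= min_saving else "full"
```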

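Q3: How do I keep OTA upgrades secure?

A: Use several layers of protection:

  1. Transport security: encrypt communication with TLS/SSL
  2. Firmware signing: sign images with RSA or ECDSA digital signatures
  3. Certificate verification: verify the server certificate
  4. Firmware encryption: optionally encrypt the image
  5. Version control: prevent downgrade attacks
  6. Secure boot: the bootloader verifies the firmware signature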
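Point 5 only works if the device compares versions numerically before it accepts an image. A string comparison would rank "1.2.10" below "1.2.9". The sketch below assumes plain MAJOR.MINOR.PATCH version strings. On a real device, this check belongs in the bootloader and should use the version stored inside the signed image. An attacker can replay an older image that still carries a valid signature, so the signature alone does not stop a downgrade.

```python
def parse_version(version):
    """'1.2.10' -> (1, 2, 10); only plain numeric MAJOR.MINOR.PATCH is accepted."""
    parts = version.split(".")
    if len(parts) != 3 or not all(p.isdecimal() for p in parts):
        raise ValueError(f"unsupported version string: {version!r}")
    return tuple(int(p) for p in parts)


def is_upgrade_allowed(current, candidate):
    """Allow only strictly newer versions (rejects downgrades and re-installs)."""
    return parse_version(candidate) > parse_version(current)
```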

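Q4: What should I do if the upgrade failure rate is high?

A: Analyze the causes and address them by category:

  1. Network problems
     • Increase the number of retries
     • Tune the timeout settings
     • Use a CDN to speed up downloads
  2. Firmware problems
     • Strengthen testing
     • Use gradual rollout
     • Roll back quickly
  3. Device problems
     • Check the free storage space
     • Verify hardware compatibility
     • Reduce memory usage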

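Q5: How do I run OTA upgrades across a large fleet of devices?

A: Roll out in batches:

  1. Gradual rollout: widen the upgrade scope step by step
  2. Group management: group devices by region and version
  3. Load balancing: use a CDN and load balancers
  4. Rate limiting: cap the number of devices upgrading at the same time
  5. Monitoring and alerting: track upgrade progress and failure rates in real time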
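Point 4 can be as simple as tracking which devices are currently downloading. The `UpgradeLimiter` sketch below admits at most `max_concurrent` devices and queues the rest. When a device finishes, its slot goes to the next queued device. The class is illustrative and is not part of the platform.

```python
from collections import deque


class UpgradeLimiter:
    """Cap the number of devices upgrading at the same time."""

    def __init__(self, max_concurrent):
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be >= 1")
        self.max_concurrent = max_concurrent
        self.in_flight = set()   # devices currently downloading/installing
        self.waiting = deque()   # devices waiting for a free slot, in arrival order

    def request(self, device_id):
        """Return True if the device may start now; otherwise queue it and return False."""
        if device_id in self.in_flight:
            return True
        if len(self.in_flight) < self.max_concurrent:
            self.in_flight.add(device_id)
            return True
        if device_id not in self.waiting:
            self.waiting.append(device_id)
        return False

    def finish(self, device_id):
        """Release the device's slot; return the queued device admitted next, or None."""
        self.in_flight.discard(device_id)
        if self.waiting and len(self.in_flight) < self.max_concurrent:
            next_device = self.waiting.popleft()
            self.in_flight.add(next_device)
            return next_device
        return None
```

The device returned by `finish()` is the one the server should notify next, for example by sending it the upgrade command.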

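Project Summary

In this project you have built a complete enterprise-grade OTA upgrade system, including:

Core results:
  • ✅ Cloud management platform (firmware management, device management, task scheduling)
  • ✅ Device-side OTA client (download, verification, installation)
  • ✅ Differential upgrades (lower bandwidth use)
  • ✅ Resumable downloads (higher reliability)
  • ✅ Security mechanisms (signing, encryption, TLS)
  • ✅ Gradual rollout strategy (lower risk)
  • ✅ Monitoring and operations tools

Skills gained:
  • The end-to-end architecture of an OTA system
  • How the cloud and the device work together
  • Differential algorithms and resumable downloads
  • Secure communication and firmware signing
  • Gradual rollout, monitoring and operations

Practical value:
  • A basis for commercial projects
  • Support for managing large device fleets
  • End-to-end security
  • Enterprise-grade reliability

Further Learning

Recommended topics for further study:

Going deeper:
  • Secure Boot in depth
  • Firmware encryption and protection in practice
  • Dual-bank upgrades and A/B partitioning

Related areas:
  • Cloud-native application development (Kubernetes, microservices)
  • IoT platform architecture (AWS IoT, Azure IoT)
  • Edge computing and fog computing
  • Device management protocols (LwM2M, TR-069)

Open-source projects:
  • Eclipse hawkBit: open-source OTA update server
  • Mender: open-source OTA update solution
  • SWUpdate: update framework for embedded Linux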


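Congratulations on completing this complex project! 🎉

You now have the complete skill set for building an enterprise-grade OTA upgrade system. The system can be applied to real projects to give your products reliable remote-upgrade capability.

Suggested next steps:
  1. Deploy the system in a real environment and test it
  2. Customize and optimize it for your actual requirements
  3. Add more advanced features (such as A/B testing and automatic rollback)
  4. Share your experience to help other developers

Hands-on exercises:
  1. Build an OTA system that supports several communication links (WiFi, 4G, LoRa)
  2. Add firmware version dependency management
  3. Implement device grouping and batch upgrades
  4. Develop a mobile management app
  5. Integrate the system into an existing IoT platform

Wishing you continued success in embedded development! 💪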