commit 8531c4dc9ea4f0b278ccefce0d15506706253261 Author: MobKBK <15059009+mobkbk@user.noreply.gitee.com> Date: Tue May 26 03:28:37 2026 +0800 first commit diff --git a/README_ODOMETRY.md b/README_ODOMETRY.md new file mode 100644 index 0000000..887e01a --- /dev/null +++ b/README_ODOMETRY.md @@ -0,0 +1,148 @@ +# IMU 惯性三维里程计 + 3D 坐标显示 + +## 概述 + +基于 STM32 IMU 串口解码系统,在 PC 端实现 3D 轨迹重建和实时可视化。下位机 (STM32) 已完成 EKF 姿态估计,PC 端接收四元数姿态和滤波后加速度数据,进行偏置标定、重力补偿与双重积分,重建三维运动轨迹。 + +## 架构 + +``` +STM32 (200Hz) + └── 串口帧 (48 bytes): header + timestamp(uint32) + gyro[3] + accel[3] + quat[4] + CRC16 + │ +PC 端管线: + ├── imu_decode.py # 串口帧解析 (CRC 校验 + 解包 48 字节帧) + ├── trajectory_tracker.py # 核心算法 (四元数→旋转, 偏置标定, 重力补偿, 双重积分, ZUPT) + ├── visualize_3d.py # matplotlib 3D 动画窗口 + └── main_odometry.py # 主入口 (串联所有模块) +``` + +## 串口协议 (v2) + +| 偏移 | 大小 | 内容 | +|------|------|------| +| 0 | 2B | 帧头 0xAA 0x55 | +| 2 | 4B | 时间戳 uint32_t (ms, 自启动) | +| 6 | 12B | 滤波后 Gyro[3] (float32 × 3) | +| 18 | 12B | 滤波后 Accel[3] (float32 × 3) | +| 30 | 16B | 四元数 qw, qx, qy, qz (float32 × 4) | +| 46 | 2B | CRC16 (对前 46 字节) | + +**坐标系**: 右手系, X-前 Y-左 Z-上 (Z-up) + +## 算法管线 + +``` +串口帧 (已解析) + ├── timestamp ms ──→ dt = Δts / 1000 (精确时间步长) + ├── gyro[3] rad/s ──→ 减 gyro_bias → ZUPT 静止检测 + ├── accel[3] m/s^2 ──→ 减 accel_bias (EKF 四元数标定) + └── qw/qx/qy/qz ──→ 四元数 → 旋转矩阵 (scipy Rotation.from_quat) + │ + a_world = R @ (a_body - accel_bias) + a_linear = a_world - [0, 0, 9.81] + │ + 加速度死区 (|a_i| < 0.03 → 0) + ZUPT: ‖gyro‖ < 0.05 AND ‖a_linear‖ < 0.20 AND var(a_linear) < 0.005 + │ + 梯形积分 → 速度 → 位置 (使用时间戳真实 dt) + │ + matplotlib 3D 实时轨迹显示 +``` + +### 关键: 加速度计偏置标定 + +EKF 四元数与加速度计之间存在不一致时会导致积分漂移。启动时采集 200 帧静止数据,利用 EKF 四元数标定: + +``` +R = from_quat(q_mean) +gravity_body_expected = R^T @ [0, 0, 9.81] +accel_bias = mean(accel_measured) - gravity_body_expected +``` + +标定后 `R @ (accel - bias) ≈ [0, 0, 9.81]`,静止时 `a_linear ≈ 0`,ZUPT 正常触发。 + +## 依赖 + +```bash +pip install numpy matplotlib pyserial scipy +``` + +## 用法 + +### 实时模式 + +```bash +python main_odometry.py COM5 +python main_odometry.py COM5 921600 +python main_odometry.py COM5 --save traj.csv +``` + +### 回放模式 + +```bash +python main_odometry.py --replay traj.csv +``` + +回放时按 **空格键** 暂停/继续。 + +### 调试 + +```bash +python imu_decode.py COM5 # 仅解析帧并打印 +python test_3d_demo.py # 模拟数据 3D 演示(无需串口) +``` + +## 文件说明 + +### `imu_decode.py` +串口帧解析模块。48 字节帧 = 2B header + 44B payload + 2B CRC16。 + +- `parse_frame(payload)` → dict: timestamp_ms, gyro(3), accel(3), quat(4) +- `quat_to_euler(qw,qx,qy,qz)` → yaw, pitch, roll (deg) + +### `trajectory_tracker.py` +核心跟踪算法: + +- `quat_to_rotation(qw,qx,qy,qz)` — 四元数 → 旋转矩阵 +- `rotate_accel(accel_body, R)` — 机体→世界坐标 +- `gravity_compensate(a_world)` — 减重力 [0, 0, 9.81] +- `apply_deadzone(a)` — 加速度死区 +- `Tracker` — 位置/速度/姿态跟踪器 +- `Tracker.calibrate_from_samples()` — 利用四元数标定偏置 + +| ZUPT 参数 | 默认值 | +|-----------|--------| +| `zupt_threshold_accel` | 0.20 m/s^2 | +| `zupt_threshold_gyro` | 0.05 rad/s | +| `zupt_frames` | 15 帧 | +| `deadzone_threshold` | 0.03 m/s^2 | +| `var_window_size` | 30 帧 | +| `zupt_var_threshold` | 0.005 m^2/s^4 | + +### `visualize_3d.py` +matplotlib 3D 实时显示:蓝色轨迹线、红点当前位置、原点 RGB 坐标轴、自适应等比例坐标。 + +### `main_odometry.py` +运行入口:串口采集 + 3D 显示、CSV 保存、回放。dt 由时间戳差值计算。 + +## CSV 格式 + +```csv +timestamp_ms,gyro_x,gyro_y,gyro_z,accel_x,accel_y,accel_z,qw,qx,qy,qz,pos_x,pos_y,pos_z +12345,0.001,0.002,-0.001,0.75,-0.05,9.81,0.996,0.001,-0.002,-0.036,0.000,0.000,0.000 +... +``` + +## 验证结果 + +| 场景 | 结果 | +|------|------| +| 25s 静止 (偏置标定后) | 0.000m 漂移 | +| 2 m/s^2 × 0.5s 运动 | 0.250m 位移 (理论值) | + +## 调试 + +- 漂移:检查标定输出 `accel_bias` 是否合理 (通常 X/Y < 0.8, Z < 0.1) +- ZUPT 不触发:增大 `zupt_threshold_accel` 或减小 `zupt_var_threshold` +- 3D 卡顿:增大 `refresh_interval` diff --git a/__pycache__/imu_decode.cpython-310.pyc b/__pycache__/imu_decode.cpython-310.pyc new file mode 100644 index 0000000..a634a87 Binary files /dev/null and b/__pycache__/imu_decode.cpython-310.pyc differ diff --git a/__pycache__/main_odometry.cpython-310.pyc b/__pycache__/main_odometry.cpython-310.pyc new file mode 100644 index 0000000..4b90fb6 Binary files /dev/null and b/__pycache__/main_odometry.cpython-310.pyc differ diff --git a/__pycache__/test_3d_demo.cpython-310.pyc b/__pycache__/test_3d_demo.cpython-310.pyc new file mode 100644 index 0000000..27054b1 Binary files /dev/null and b/__pycache__/test_3d_demo.cpython-310.pyc differ diff --git a/__pycache__/trajectory_tracker.cpython-310.pyc b/__pycache__/trajectory_tracker.cpython-310.pyc new file mode 100644 index 0000000..1484ab0 Binary files /dev/null and b/__pycache__/trajectory_tracker.cpython-310.pyc differ diff --git a/__pycache__/visualize_3d.cpython-310.pyc b/__pycache__/visualize_3d.cpython-310.pyc new file mode 100644 index 0000000..a4741a4 Binary files /dev/null and b/__pycache__/visualize_3d.cpython-310.pyc differ diff --git a/imu_decode.py b/imu_decode.py new file mode 100644 index 0000000..fc568d4 --- /dev/null +++ b/imu_decode.py @@ -0,0 +1,137 @@ +import serial +import struct +import time + +# ========== CRC16 预计算查找表 ========== +crc_tab = [] +for i in range(256): + crc = 0 + c = i + for _ in range(8): + if (crc ^ c) & 1: + crc = (crc >> 1) ^ 0xA001 + else: + crc >>= 1 + c >>= 1 + crc_tab.append(crc) + + +def crc16(data): + """CRC16-Modbus 校验 (查表法)""" + crc = 0xFFFF + for b in data: + crc = (crc >> 8) ^ crc_tab[(crc ^ b) & 0xFF] + return crc + + +# ========== 帧格式常量 (新协议 48 字节) ========== +# 0 2B 帧头 0xAA 0x55 +# 2 4B 时间戳 uint32_t (ms, 自启动) +# 6 12B 滤波后 Gyro[3] (float32 × 3) +# 18 12B 滤波后 Accel[3] (float32 × 3) +# 30 16B 四元数 qw, qx, qy, qz (float32 × 4) +# 46 2B CRC16 (对前 46 字节) +HEADER = b'\xAA\x55' +HEADER_LEN = 2 +PAYLOAD_LEN = 44 # uint32 + 3f + 3f + 4f = 4+12+12+16 +CRC_LEN = 2 +FRAME_LEN = HEADER_LEN + PAYLOAD_LEN + CRC_LEN # = 48 + +FIELD_NAMES = ['timestamp_ms', + 'gyro_x', 'gyro_y', 'gyro_z', + 'accel_x', 'accel_y', 'accel_z', + 'qw', 'qx', 'qy', 'qz'] + + +def parse_frame(payload_bytes): + """解包 44 字节 payload + + Returns: + dict with keys: timestamp_ms, gyro (3-tuple), accel (3-tuple), quat (4-tuple: qw,qx,qy,qz) + """ + ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz = struct.unpack(' 1 else 'COM3' + baud = int(sys.argv[2]) if len(sys.argv) > 2 else 115200 + main(port, baud) diff --git a/main_odometry.py b/main_odometry.py new file mode 100644 index 0000000..9a0069a --- /dev/null +++ b/main_odometry.py @@ -0,0 +1,305 @@ +""" +IMU 惯性三维里程计 — 主入口 + +用法: + python main_odometry.py COM5 [baud] + python main_odometry.py COM5 --save traj.csv + python main_odometry.py --replay traj.csv +""" + +import sys +import time +import struct +import csv +import argparse + +import numpy as np +import serial +import matplotlib.pyplot as plt + +from imu_decode import HEADER, PAYLOAD_LEN, CRC_LEN, crc16, parse_frame +from trajectory_tracker import Tracker +from visualize_3d import TrajectoryViewer + + +def read_frame(ser): + """从串口读取一帧, 返回解析后的 dict 或 None + + 逐字节寻找帧头 0xAA 0x55, 校验 CRC, 解包 payload。 + """ + while True: + b = ser.read(1) + if not b: + return None + if b[0] == 0xAA: + b2 = ser.read(1) + if not b2 or b2[0] != 0x55: + continue + + frame = ser.read(PAYLOAD_LEN + CRC_LEN) + if len(frame) < PAYLOAD_LEN + CRC_LEN: + return None + + payload_bytes = frame[:PAYLOAD_LEN] + crc_recv = struct.unpack(' 0.1: + dt = 0.005 # 回退到 200Hz 默认值 + else: + dt = 0.005 + last_ts = ts + + pos = tracker.update(gyro, accel, qw, qx, qy, qz, dt) + + # 保存 CSV + if csv_writer: + csv_writer.writerow([ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz, + pos[0], pos[1], pos[2]]) + + # 30Hz 刷新显示 + now = time.time() + if now - last_draw >= 0.033: + viewer.update(tracker.history_array) + plt.pause(0.001) + last_draw = now + + frame_count += 1 + if frame_count % 200 == 0: + print(f"[{frame_count:06d}] ts={ts} dt={dt*1000:.1f}ms " + f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})") + + except KeyboardInterrupt: + print(f"\n停止。共接收 {frame_count} 帧。") + finally: + ser.close() + if csv_file: + csv_file.close() + print(f"轨迹已保存至 {save_csv}") + viewer.close() + + +def run_replay(csv_path): + """回放模式: 从 CSV 文件加载数据并显示 3D 轨迹""" + # 加载 CSV + rows = [] + with open(csv_path, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + rows.append(row) + + print(f"加载 {len(rows)} 帧数据") + + # 从开头静止帧标定偏置 + calib_samples = min(200, len(rows) // 4) + accel_samples, gyro_samples = [], [] + qw_s, qx_s, qy_s, qz_s = [], [], [], [] + for i in range(calib_samples): + row = rows[i] + accel_samples.append([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])]) + gyro_samples.append([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])]) + qw_s.append(float(row['qw'])) + qx_s.append(float(row['qx'])) + qy_s.append(float(row['qy'])) + qz_s.append(float(row['qz'])) + accel_bias, gyro_bias = Tracker.calibrate_from_samples( + np.array(accel_samples), np.array(gyro_samples), + np.array(qw_s), np.array(qx_s), np.array(qy_s), np.array(qz_s)) + print(f"标定 (前{calib_samples}帧): accel_bias=({accel_bias[0]:.4f},{accel_bias[1]:.4f},{accel_bias[2]:.4f})" + f" gyro_bias=({gyro_bias[0]:.4f},{gyro_bias[1]:.4f},{gyro_bias[2]:.4f})") + + tracker = Tracker() + tracker.accel_bias = accel_bias + tracker.gyro_bias = gyro_bias + viewer = TrajectoryViewer() + + print("开始回放 ...(空格键暂停)") + + last_ts = None + last_draw = time.time() + idx = 0 + paused = False + + def on_key(event): + nonlocal paused + if event.key == ' ': + paused = not paused + print("暂停" if paused else "继续") + + viewer.fig.canvas.mpl_connect('key_press_event', on_key) + + try: + while plt.fignum_exists(viewer.fig.number) and idx < len(rows): + if not paused: + row = rows[idx] + gyro = np.array([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])]) + accel = np.array([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])]) + qw, qx, qy, qz = float(row['qw']), float(row['qx']), float(row['qy']), float(row['qz']) + + # 使用时间戳计算 dt + ts = int(row['timestamp_ms']) if 'timestamp_ms' in row else None + if ts is not None and last_ts is not None: + dt = (ts - last_ts) / 1000.0 + if dt <= 0 or dt > 0.1: + dt = 0.005 + else: + dt = 0.005 + if ts is not None: + last_ts = ts + + tracker.update(gyro, accel, qw, qx, qy, qz, dt) + idx += 1 + + now = time.time() + if now - last_draw >= 0.033: + viewer.update(tracker.history_array) + plt.pause(0.001) + last_draw = now + else: + plt.pause(0.05) + + if idx % 200 == 0: + pos = tracker.position + print(f"[{idx:06d}/{len(rows)}] dt={dt*1000:.1f}ms " + f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})") + + except KeyboardInterrupt: + print("\n回放中断。") + finally: + viewer.close() + + print(f"回放完成,共处理 {idx} 帧。") + + +def main(): + parser = argparse.ArgumentParser(description='IMU 惯性三维里程计') + parser.add_argument('port', nargs='?', default=None, + help='串口号 (如 COM5)') + parser.add_argument('baud', nargs='?', type=int, default=115200, + help='波特率 (默认 115200)') + parser.add_argument('--save', metavar='FILE', + help='保存轨迹到 CSV 文件') + parser.add_argument('--replay', metavar='FILE', + help='从 CSV 文件回放轨迹') + args = parser.parse_args() + + if args.replay: + run_replay(args.replay) + elif args.port: + run_live(args.port, args.baud, save_csv=args.save) + else: + parser.print_help() + print("\n示例:") + print(" python main_odometry.py COM5") + print(" python main_odometry.py COM5 921600") + print(" python main_odometry.py COM5 --save traj.csv") + print(" python main_odometry.py --replay traj.csv") + + +if __name__ == '__main__': + main() diff --git a/test_3d_demo.py b/test_3d_demo.py new file mode 100644 index 0000000..684651a --- /dev/null +++ b/test_3d_demo.py @@ -0,0 +1,56 @@ +"""3D 显示验证 — 模拟圆周运动轨迹 (无需串口, 使用四元数)""" +import numpy as np +import time +import matplotlib.pyplot as plt +from scipy.spatial.transform import Rotation +from trajectory_tracker import Tracker +from visualize_3d import TrajectoryViewer + + +def main(): + tracker = Tracker(zupt_frames=50) + viewer = TrajectoryViewer(title="IMU 3D Demo — 四元数模拟") + + dt = 0.005 # 200Hz + t = 0.0 + omega = 0.5 # 角速度 (rad/s) + last_draw = time.time() + + print("3D 演示运行中... 按 Ctrl+C 停止") + + try: + while plt.fignum_exists(viewer.fig.number): + # X-Y 平面圆周运动 + Z 轴正弦起伏 + ax_linear = -0.25 * np.cos(omega * t) + ay_linear = -0.25 * np.sin(omega * t) + az_linear = 0.3 * np.sin(2 * t) + + # 模拟四元数 (绕 Z 轴旋转 omega*t) + yaw = omega * t + r = Rotation.from_euler('ZYX', [np.degrees(yaw), 0, 0]) + q = r.as_quat() # [x, y, z, w] + + # body 加速度 = R^T @ (a_linear + gravity) + a_world = np.array([ax_linear, ay_linear, az_linear + 9.81]) + accel_body = r.as_matrix().T @ a_world + + gyro = np.array([0.0, 0.0, omega]) + + pos = tracker.update(gyro, accel_body, q[3], q[0], q[1], q[2], dt) + + now = time.time() + if now - last_draw >= 0.033: + viewer.update(tracker.history_array) + plt.pause(0.001) + last_draw = now + + t += dt + + except KeyboardInterrupt: + print("停止。") + finally: + viewer.close() + + +if __name__ == '__main__': + main() diff --git a/trajectory_tracker.py b/trajectory_tracker.py new file mode 100644 index 0000000..f06e628 --- /dev/null +++ b/trajectory_tracker.py @@ -0,0 +1,198 @@ +""" +IMU 惯性三维里程计 — 核心算法模块 + +管线: + 四元数 → 旋转矩阵 (scipy Rotation.from_quat) + a_world = R @ (a_body - accel_bias) + a_linear = a_world - [0, 0, 9.81] + 双重积分 (梯形积分 + ZUPT 静止检测 + 加速度死区) + +使用 EKF 四元数标定加速度计偏置, 保证 R @ corrected_accel ≈ [0,0,g]。 +""" + +import numpy as np +from scipy.spatial.transform import Rotation + +GRAVITY = np.array([0.0, 0.0, 9.81]) # Z-up 右手系 + + +def quat_to_rotation(qw, qx, qy, qz): + """四元数 → body→world 旋转矩阵 + + Args: + qw, qx, qy, qz: STM32 四元数 (scalar-first) + Returns: + R: 3x3 旋转矩阵 + """ + return Rotation.from_quat([qx, qy, qz, qw]).as_matrix() + + +def rotate_accel(accel_body, R): + """机体加速度 → 世界坐标系""" + return R @ np.asarray(accel_body) + + +def gravity_compensate(a_world): + """减去重力向量""" + return a_world - GRAVITY + + +def apply_deadzone(a, threshold=0.03): + """幅值小于阈值的分量置零""" + a = np.asarray(a).copy() + a[np.abs(a) < threshold] = 0.0 + return a + + +class Tracker: + """三维轨迹跟踪器 + + 使用 EKF 四元数进行姿态旋转, 标定加速度计偏置, ZUPT 抑制静止漂移。 + """ + + def __init__(self, zupt_threshold_accel=0.10, zupt_threshold_gyro=0.03, + zupt_frames=8, deadzone_threshold=0.02, + var_window_size=20, zupt_var_threshold=0.002): + """ + Args: + zupt_threshold_accel: ZUPT ‖a_linear‖ 阈值 (m/s^2) + zupt_threshold_gyro: ZUPT ‖gyro‖ 阈值 (rad/s) + zupt_frames: 连续静止帧数阈值 + deadzone_threshold: 加速度分量死区 (m/s^2) + var_window_size: 方差窗口大小 (帧) + zupt_var_threshold: ZUPT 方差阈值 (m^2/s^4) + """ + self.position = np.zeros(3) + self.velocity = np.zeros(3) + + self.position_history = [self.position.copy()] + + self.zupt_threshold_accel = zupt_threshold_accel + self.zupt_threshold_gyro = zupt_threshold_gyro + self.zupt_frames = zupt_frames + self.deadzone_threshold = deadzone_threshold + self.var_window_size = var_window_size + self._zupt_var_threshold = zupt_var_threshold + + # 传感器偏置 + self.accel_bias = np.zeros(3) + self.gyro_bias = np.zeros(3) + + # ZUPT 状态 + self._zupt_counter = 0 + self._linear_var_window = [] + self._prev_accel_linear = np.zeros(3) + + # 当前四元数 (用于外部查询) + self.qw, self.qx, self.qy, self.qz = 1.0, 0.0, 0.0, 0.0 + + def update(self, gyro, accel, qw, qx, qy, qz, dt): + """处理一帧 IMU 数据 + + Args: + gyro: 机体角速度 [gx, gy, gz] rad/s + accel: 机体加速度 [ax, ay, az] m/s^2 + qw, qx, qy, qz: EKF 四元数 (scalar-first) + dt: 时间步长 (s), 由时间戳差值计算 + """ + self.qw, self.qx, self.qy, self.qz = qw, qx, qy, qz + + accel = np.asarray(accel, dtype=float) + gyro = np.asarray(gyro, dtype=float) - self.gyro_bias + + # 0. 减去加速度计偏置 + accel_corrected = accel - self.accel_bias + + # 1. 四元数 → 旋转矩阵 + R = quat_to_rotation(qw, qx, qy, qz) + + # 2. 机体加速度 → 世界加速度 → 重力补偿 + a_world = rotate_accel(accel_corrected, R) + a_linear = gravity_compensate(a_world) + + # 3. 加速度死区 + a_linear = apply_deadzone(a_linear, self.deadzone_threshold) + + # 4. ZUPT 静止检测 + gyro_norm = np.linalg.norm(gyro) + linear_magnitude = np.linalg.norm(a_linear) + + self._linear_var_window.append(a_linear.copy()) + if len(self._linear_var_window) > self.var_window_size: + self._linear_var_window.pop(0) + + linear_variance = 0.0 + if len(self._linear_var_window) >= self.var_window_size: + linear_variance = np.var(self._linear_var_window, axis=0).mean() + + is_static = ( + gyro_norm < self.zupt_threshold_gyro + and linear_magnitude < self.zupt_threshold_accel + and linear_variance < self._zupt_var_threshold + ) + + if is_static: + self._zupt_counter += 1 + else: + self._zupt_counter = 0 + + zupt_active = self._zupt_counter >= self.zupt_frames + + # 5. 梯形积分 (使用真实 dt) + if dt > 0: + if zupt_active: + self.velocity[:] = 0.0 + self._prev_accel_linear = np.zeros(3) + else: + a_prev = self._prev_accel_linear + self.velocity = self.velocity + (a_prev + a_linear) * dt / 2.0 + self._prev_accel_linear = a_linear.copy() + self.position = self.position + self.velocity * dt + + self.position_history.append(self.position.copy()) + return self.position + + @staticmethod + def calibrate_from_samples(accel_samples, gyro_samples, qw_samples, qx_samples, qy_samples, qz_samples): + """从静止采样数据计算传感器偏置 + + 利用 EKF 四元数计算理论 body 重力: R(q)^T @ [0, 0, 9.81] + 偏置 = 实测均值 - 理论均值 + + Args: + accel_samples: Nx3, 机体加速度 (m/s^2) + gyro_samples: Nx3, 角速度 (rad/s) + qw/qx/qy/qz_samples: N, EKF 四元数分量 + + Returns: + accel_bias: (3,) 加速度计偏置 (m/s^2) + gyro_bias: (3,) 陀螺仪偏置 (rad/s) + """ + accel_mean = np.mean(accel_samples, axis=0) + gyro_bias = np.mean(gyro_samples, axis=0) + + # 用平均四元数计算理论的 body 重力向量 + qw_m = np.mean(qw_samples) + qx_m = np.mean(qx_samples) + qy_m = np.mean(qy_samples) + qz_m = np.mean(qz_samples) + R_mean = quat_to_rotation(qw_m, qx_m, qy_m, qz_m) + gravity_body_expected = R_mean.T @ GRAVITY + + accel_bias = accel_mean - gravity_body_expected + + return accel_bias, gyro_bias + + def reset(self): + """重置轨迹, 保留偏置""" + self.position = np.zeros(3) + self.velocity = np.zeros(3) + self.position_history = [self.position.copy()] + self._zupt_counter = 0 + self._prev_accel_linear = np.zeros(3) + self._linear_var_window = [] + + @property + def history_array(self): + """返回 Nx3 numpy 数组""" + return np.array(self.position_history) diff --git a/visualize_3d.py b/visualize_3d.py new file mode 100644 index 0000000..be45e10 --- /dev/null +++ b/visualize_3d.py @@ -0,0 +1,114 @@ +""" +IMU 3D 轨迹实时可视化 + +matplotlib 3D 窗口, 30Hz 刷新, 显示: + - 蓝色轨迹线 + - 当前点红点 + - 原点坐标系指示 + - 等比例坐标轴 +""" + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + + +class TrajectoryViewer: + """3D 轨迹实时显示窗口""" + + def __init__(self, title="IMU 3D Odometry", refresh_interval=33): + """ + Args: + title: 窗口标题 + refresh_interval: 刷新间隔 (ms), 默认 33ms ≈ 30Hz + """ + self.refresh_interval = refresh_interval + + self.fig = plt.figure(figsize=(8, 7)) + self.fig.canvas.manager.set_window_title(title) + self.ax = self.fig.add_subplot(111, projection='3d') + + # 轨迹线 (蓝色) + self.traj_line, = self.ax.plot([], [], [], 'b-', linewidth=1.0, label='Trajectory') + + # 当前点 (红色) + self.current_point, = self.ax.plot([], [], [], 'ro', markersize=6, label='Current') + + # 坐标系指示 (原点处) + axis_len = 0.3 + self.origin_axes = [ + self.ax.quiver(0, 0, 0, axis_len, 0, 0, color='r', arrow_length_ratio=0.15, label='X'), + self.ax.quiver(0, 0, 0, 0, axis_len, 0, color='g', arrow_length_ratio=0.15, label='Y'), + self.ax.quiver(0, 0, 0, 0, 0, axis_len, color='b', arrow_length_ratio=0.15, label='Z'), + ] + + self._setup_axes() + + # 动画 + self.anim = FuncAnimation(self.fig, self._animate, interval=self.refresh_interval, + cache_frame_data=False, blit=False) + + def _setup_axes(self): + """初始化坐标轴""" + self.ax.set_xlabel('X (front)') + self.ax.set_ylabel('Y (left)') + self.ax.set_zlabel('Z (up)') + self.ax.set_title("IMU 3D Trajectory (Z-up)") + self.ax.legend(loc='upper left') + + # 初始范围 + self.ax.set_xlim([-1, 1]) + self.ax.set_ylim([-1, 1]) + self.ax.set_zlim([-1, 1]) + + try: + self.ax.set_box_aspect([1, 1, 1]) + except NotImplementedError: + pass + + self.ax.grid(True) + + def _animate(self, frame): + """动画帧回调 (不做任何事, 数据由外部 update 驱动)""" + pass # 通过 plt.pause 驱动, FuncAnimation 仅用于保持窗口响应 + + def update(self, history_array): + """更新显示的轨迹数据 + + Args: + history_array: Nx3 numpy array, 位置历史 + """ + if len(history_array) < 1: + return + + x, y, z = history_array[:, 0], history_array[:, 1], history_array[:, 2] + + # 更新轨迹线 + self.traj_line.set_data(x, y) + self.traj_line.set_3d_properties(z) + + # 更新当前点 + self.current_point.set_data([x[-1]], [y[-1]]) + self.current_point.set_3d_properties([z[-1]]) + + # 自适应坐标轴范围 + self._auto_scale(x, y, z) + + def _auto_scale(self, x, y, z): + """根据数据自动调整坐标轴范围, 保持等比例""" + all_coords = np.concatenate([x, y, z]) + margin = max(np.ptp(all_coords) * 0.2, 0.5) + mid = (all_coords.min() + all_coords.max()) / 2 + half = (all_coords.max() - all_coords.min()) / 2 + margin + + self.ax.set_xlim([mid - half, mid + half]) + self.ax.set_ylim([mid - half, mid + half]) + self.ax.set_zlim([mid - half, mid + half]) + + def show(self): + """阻塞显示窗口""" + plt.show() + + def close(self): + """关闭窗口""" + plt.close(self.fig)