first commit

This commit is contained in:
MobKBK
2026-05-26 03:28:37 +08:00
commit 8531c4dc9e
11 changed files with 958 additions and 0 deletions

148
README_ODOMETRY.md Normal file
View File

@@ -0,0 +1,148 @@
# IMU 惯性三维里程计 + 3D 坐标显示
## 概述
基于 STM32 IMU 串口解码系统,在 PC 端实现 3D 轨迹重建和实时可视化。下位机 (STM32) 已完成 EKF 姿态估计PC 端接收四元数姿态和滤波后加速度数据,进行偏置标定、重力补偿与双重积分,重建三维运动轨迹。
## 架构
```
STM32 (200Hz)
└── 串口帧 (48 bytes): header + timestamp(uint32) + gyro[3] + accel[3] + quat[4] + CRC16
PC 端管线:
├── imu_decode.py # 串口帧解析 (CRC 校验 + 解包 48 字节帧)
├── trajectory_tracker.py # 核心算法 (四元数→旋转, 偏置标定, 重力补偿, 双重积分, ZUPT)
├── visualize_3d.py # matplotlib 3D 动画窗口
└── main_odometry.py # 主入口 (串联所有模块)
```
## 串口协议 (v2)
| 偏移 | 大小 | 内容 |
|------|------|------|
| 0 | 2B | 帧头 0xAA 0x55 |
| 2 | 4B | 时间戳 uint32_t (ms, 自启动) |
| 6 | 12B | 滤波后 Gyro[3] (float32 × 3) |
| 18 | 12B | 滤波后 Accel[3] (float32 × 3) |
| 30 | 16B | 四元数 qw, qx, qy, qz (float32 × 4) |
| 46 | 2B | CRC16 (对前 46 字节) |
**坐标系**: 右手系, X-前 Y-左 Z-上 (Z-up)
## 算法管线
```
串口帧 (已解析)
├── timestamp ms ──→ dt = Δts / 1000 (精确时间步长)
├── gyro[3] rad/s ──→ 减 gyro_bias → ZUPT 静止检测
├── accel[3] m/s^2 ──→ 减 accel_bias (EKF 四元数标定)
└── qw/qx/qy/qz ──→ 四元数 → 旋转矩阵 (scipy Rotation.from_quat)
a_world = R @ (a_body - accel_bias)
a_linear = a_world - [0, 0, 9.81]
加速度死区 (|a_i| < 0.03 → 0)
ZUPT: ‖gyro‖ < 0.05 AND ‖a_linear‖ < 0.20 AND var(a_linear) < 0.005
梯形积分 → 速度 → 位置 (使用时间戳真实 dt)
matplotlib 3D 实时轨迹显示
```
### 关键: 加速度计偏置标定
EKF 四元数与加速度计之间存在不一致时会导致积分漂移。启动时采集 200 帧静止数据,利用 EKF 四元数标定:
```
R = from_quat(q_mean)
gravity_body_expected = R^T @ [0, 0, 9.81]
accel_bias = mean(accel_measured) - gravity_body_expected
```
标定后 `R @ (accel - bias) ≈ [0, 0, 9.81]`,静止时 `a_linear ≈ 0`ZUPT 正常触发。
## 依赖
```bash
pip install numpy matplotlib pyserial scipy
```
## 用法
### 实时模式
```bash
python main_odometry.py COM5
python main_odometry.py COM5 921600
python main_odometry.py COM5 --save traj.csv
```
### 回放模式
```bash
python main_odometry.py --replay traj.csv
```
回放时按 **空格键** 暂停/继续。
### 调试
```bash
python imu_decode.py COM5 # 仅解析帧并打印
python test_3d_demo.py # 模拟数据 3D 演示(无需串口)
```
## 文件说明
### `imu_decode.py`
串口帧解析模块。48 字节帧 = 2B header + 44B payload + 2B CRC16。
- `parse_frame(payload)` → dict: timestamp_ms, gyro(3), accel(3), quat(4)
- `quat_to_euler(qw,qx,qy,qz)` → yaw, pitch, roll (deg)
### `trajectory_tracker.py`
核心跟踪算法:
- `quat_to_rotation(qw,qx,qy,qz)` — 四元数 → 旋转矩阵
- `rotate_accel(accel_body, R)` — 机体→世界坐标
- `gravity_compensate(a_world)` — 减重力 [0, 0, 9.81]
- `apply_deadzone(a)` — 加速度死区
- `Tracker` — 位置/速度/姿态跟踪器
- `Tracker.calibrate_from_samples()` — 利用四元数标定偏置
| ZUPT 参数 | 默认值 |
|-----------|--------|
| `zupt_threshold_accel` | 0.20 m/s^2 |
| `zupt_threshold_gyro` | 0.05 rad/s |
| `zupt_frames` | 15 帧 |
| `deadzone_threshold` | 0.03 m/s^2 |
| `var_window_size` | 30 帧 |
| `zupt_var_threshold` | 0.005 m^2/s^4 |
### `visualize_3d.py`
matplotlib 3D 实时显示:蓝色轨迹线、红点当前位置、原点 RGB 坐标轴、自适应等比例坐标。
### `main_odometry.py`
运行入口:串口采集 + 3D 显示、CSV 保存、回放。dt 由时间戳差值计算。
## CSV 格式
```csv
timestamp_ms,gyro_x,gyro_y,gyro_z,accel_x,accel_y,accel_z,qw,qx,qy,qz,pos_x,pos_y,pos_z
12345,0.001,0.002,-0.001,0.75,-0.05,9.81,0.996,0.001,-0.002,-0.036,0.000,0.000,0.000
...
```
## 验证结果
| 场景 | 结果 |
|------|------|
| 25s 静止 (偏置标定后) | 0.000m 漂移 |
| 2 m/s^2 × 0.5s 运动 | 0.250m 位移 (理论值) |
## 调试
- 漂移:检查标定输出 `accel_bias` 是否合理 (通常 X/Y < 0.8, Z < 0.1)
- ZUPT 不触发:增大 `zupt_threshold_accel` 或减小 `zupt_var_threshold`
- 3D 卡顿:增大 `refresh_interval`

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

137
imu_decode.py Normal file
View File

@@ -0,0 +1,137 @@
import serial
import struct
import time
# ========== CRC16 预计算查找表 ==========
crc_tab = []
for i in range(256):
crc = 0
c = i
for _ in range(8):
if (crc ^ c) & 1:
crc = (crc >> 1) ^ 0xA001
else:
crc >>= 1
c >>= 1
crc_tab.append(crc)
def crc16(data):
"""CRC16-Modbus 校验 (查表法)"""
crc = 0xFFFF
for b in data:
crc = (crc >> 8) ^ crc_tab[(crc ^ b) & 0xFF]
return crc
# ========== 帧格式常量 (新协议 48 字节) ==========
# 0 2B 帧头 0xAA 0x55
# 2 4B 时间戳 uint32_t (ms, 自启动)
# 6 12B 滤波后 Gyro[3] (float32 × 3)
# 18 12B 滤波后 Accel[3] (float32 × 3)
# 30 16B 四元数 qw, qx, qy, qz (float32 × 4)
# 46 2B CRC16 (对前 46 字节)
HEADER = b'\xAA\x55'
HEADER_LEN = 2
PAYLOAD_LEN = 44 # uint32 + 3f + 3f + 4f = 4+12+12+16
CRC_LEN = 2
FRAME_LEN = HEADER_LEN + PAYLOAD_LEN + CRC_LEN # = 48
FIELD_NAMES = ['timestamp_ms',
'gyro_x', 'gyro_y', 'gyro_z',
'accel_x', 'accel_y', 'accel_z',
'qw', 'qx', 'qy', 'qz']
def parse_frame(payload_bytes):
"""解包 44 字节 payload
Returns:
dict with keys: timestamp_ms, gyro (3-tuple), accel (3-tuple), quat (4-tuple: qw,qx,qy,qz)
"""
ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz = struct.unpack('<I10f', payload_bytes)
return {
'timestamp_ms': ts,
'gyro': (gx, gy, gz),
'accel': (ax, ay, az),
'quat': (qw, qx, qy, qz),
}
def quat_to_euler(qw, qx, qy, qz):
"""四元数 → 欧拉角 (ZYX 内旋, deg)
scipy 格式: [x, y, z, w], STM32 格式: [w, x, y, z]
"""
from scipy.spatial.transform import Rotation
r = Rotation.from_quat([qx, qy, qz, qw])
yaw, pitch, roll = r.as_euler('ZYX', degrees=True)
return yaw, pitch, roll
def main(port='COM3', baud=115200):
print(f"打开串口 {port} @ {baud} baud ...")
ser = serial.Serial(port, baud, timeout=1)
print("等待 IMU 数据帧 (header: AA 55) ...\n")
frame_count = 0
err_count = 0
last_ts = None
try:
while True:
b = ser.read(1)
if not b:
continue
if b[0] == 0xAA:
b2 = ser.read(1)
if not b2 or b2[0] != 0x55:
continue
# 帧头匹配, 读取剩余 46 字节
frame = ser.read(PAYLOAD_LEN + CRC_LEN)
if len(frame) < PAYLOAD_LEN + CRC_LEN:
print("警告: 帧数据不完整")
continue
payload_bytes = frame[:PAYLOAD_LEN]
crc_recv = struct.unpack('<H', frame[PAYLOAD_LEN:])[0]
# CRC 校验 (覆盖 header + payload, 前 46 字节)
crc_calc = crc16(HEADER + payload_bytes)
if crc_calc == crc_recv:
data = parse_frame(payload_bytes)
ts = data['timestamp_ms']
gx, gy, gz = data['gyro']
ax, ay, az = data['accel']
qw, qx, qy, qz = data['quat']
yaw, pitch, roll = quat_to_euler(qw, qx, qy, qz)
dt_str = ""
if last_ts is not None:
dt = (ts - last_ts) / 1000.0
dt_str = f" dt={dt:.3f}s"
last_ts = ts
print(f"[#{frame_count:04d}] TS={ts:08d}{dt_str}\n"
f" Gyro: {gx:8.3f}, {gy:8.3f}, {gz:8.3f}\n"
f" Accel: {ax:7.3f}, {ay:7.3f}, {az:7.3f}\n"
f" Quat: w={qw:6.3f} x={qx:6.3f} y={qy:6.3f} z={qz:6.3f}\n"
f" Euler: Y{yaw:7.2f} P{pitch:7.2f} R{roll:7.2f}")
frame_count += 1
else:
err_count += 1
print(f"CRC 校验失败 (期望: 0x{crc_calc:04X}, 接收: 0x{crc_recv:04X}) "
f"[累计错误: {err_count}]")
except KeyboardInterrupt:
print(f"\n\n停止。共接收 {frame_count} 帧, CRC 错误 {err_count} 次。")
ser.close()
if __name__ == '__main__':
import sys
port = sys.argv[1] if len(sys.argv) > 1 else 'COM3'
baud = int(sys.argv[2]) if len(sys.argv) > 2 else 115200
main(port, baud)

305
main_odometry.py Normal file
View File

@@ -0,0 +1,305 @@
"""
IMU 惯性三维里程计 — 主入口
用法:
python main_odometry.py COM5 [baud]
python main_odometry.py COM5 --save traj.csv
python main_odometry.py --replay traj.csv
"""
import sys
import time
import struct
import csv
import argparse
import numpy as np
import serial
import matplotlib.pyplot as plt
from imu_decode import HEADER, PAYLOAD_LEN, CRC_LEN, crc16, parse_frame
from trajectory_tracker import Tracker
from visualize_3d import TrajectoryViewer
def read_frame(ser):
"""从串口读取一帧, 返回解析后的 dict 或 None
逐字节寻找帧头 0xAA 0x55, 校验 CRC, 解包 payload。
"""
while True:
b = ser.read(1)
if not b:
return None
if b[0] == 0xAA:
b2 = ser.read(1)
if not b2 or b2[0] != 0x55:
continue
frame = ser.read(PAYLOAD_LEN + CRC_LEN)
if len(frame) < PAYLOAD_LEN + CRC_LEN:
return None
payload_bytes = frame[:PAYLOAD_LEN]
crc_recv = struct.unpack('<H', frame[PAYLOAD_LEN:])[0]
crc_calc = crc16(HEADER + payload_bytes)
if crc_calc != crc_recv:
continue
return parse_frame(payload_bytes)
def calibrate(ser, num_samples=200, skip_first=10):
"""静止标定: 采集 N 帧静止数据, 计算传感器偏置
Args:
ser: 已打开的串口
num_samples: 标定采样帧数
skip_first: 跳过前 N 帧 (等待数据稳定)
Returns:
accel_bias: (3,) array
gyro_bias: (3,) array
"""
accel_samples = []
gyro_samples = []
qw_samples, qx_samples, qy_samples, qz_samples = [], [], [], []
print(f"静止标定: 请保持 IMU 静止, 采集 {num_samples} 帧 ...")
collected = 0
while collected < num_samples + skip_first:
frame = read_frame(ser)
if frame is None:
continue
collected += 1
if collected <= skip_first:
continue
gx, gy, gz = frame['gyro']
ax, ay, az = frame['accel']
qw, qx, qy, qz = frame['quat']
accel_samples.append([ax, ay, az])
gyro_samples.append([gx, gy, gz])
qw_samples.append(qw)
qx_samples.append(qx)
qy_samples.append(qy)
qz_samples.append(qz)
if collected % 50 == 0:
print(f" 标定进度: {min(collected - skip_first, num_samples)}/{num_samples}")
accel_arr = np.array(accel_samples)
gyro_arr = np.array(gyro_samples)
qw_arr = np.array(qw_samples)
qx_arr = np.array(qx_samples)
qy_arr = np.array(qy_samples)
qz_arr = np.array(qz_samples)
return Tracker.calibrate_from_samples(accel_arr, gyro_arr, qw_arr, qx_arr, qy_arr, qz_arr)
def run_live(port, baud, save_csv=None):
"""实时模式: 从串口读取数据并显示 3D 轨迹"""
tracker = Tracker()
viewer = TrajectoryViewer()
csv_file = None
csv_writer = None
if save_csv:
csv_file = open(save_csv, 'w', newline='')
csv_writer = csv.writer(csv_file)
csv_writer.writerow(['timestamp_ms',
'gyro_x', 'gyro_y', 'gyro_z',
'accel_x', 'accel_y', 'accel_z',
'qw', 'qx', 'qy', 'qz',
'pos_x', 'pos_y', 'pos_z'])
print(f"打开串口 {port} @ {baud} baud ...")
ser = serial.Serial(port, baud, timeout=1)
# 静止标定
accel_bias, gyro_bias = calibrate(ser)
tracker.accel_bias = accel_bias
tracker.gyro_bias = gyro_bias
print(f"标定完成: accel_bias=({accel_bias[0]:.4f}, {accel_bias[1]:.4f}, {accel_bias[2]:.4f}) m/s^2"
f" gyro_bias=({gyro_bias[0]:.4f}, {gyro_bias[1]:.4f}, {gyro_bias[2]:.4f}) rad/s")
print("开始跟踪 ...\n按 Ctrl+C 停止")
frame_count = 0
last_ts = None
last_draw = time.time()
try:
while plt.fignum_exists(viewer.fig.number):
frame = read_frame(ser)
if frame is None:
continue
ts = frame['timestamp_ms']
gx, gy, gz = frame['gyro']
ax, ay, az = frame['accel']
qw, qx, qy, qz = frame['quat']
gyro = np.array([gx, gy, gz])
accel = np.array([ax, ay, az])
# 使用时间戳计算 dt
if last_ts is not None:
dt = (ts - last_ts) / 1000.0
# 处理时间戳回绕 (uint32 溢出) 和异常值
if dt <= 0 or dt > 0.1:
dt = 0.005 # 回退到 200Hz 默认值
else:
dt = 0.005
last_ts = ts
pos = tracker.update(gyro, accel, qw, qx, qy, qz, dt)
# 保存 CSV
if csv_writer:
csv_writer.writerow([ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz,
pos[0], pos[1], pos[2]])
# 30Hz 刷新显示
now = time.time()
if now - last_draw >= 0.033:
viewer.update(tracker.history_array)
plt.pause(0.001)
last_draw = now
frame_count += 1
if frame_count % 200 == 0:
print(f"[{frame_count:06d}] ts={ts} dt={dt*1000:.1f}ms "
f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
except KeyboardInterrupt:
print(f"\n停止。共接收 {frame_count} 帧。")
finally:
ser.close()
if csv_file:
csv_file.close()
print(f"轨迹已保存至 {save_csv}")
viewer.close()
def run_replay(csv_path):
"""回放模式: 从 CSV 文件加载数据并显示 3D 轨迹"""
# 加载 CSV
rows = []
with open(csv_path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
print(f"加载 {len(rows)} 帧数据")
# 从开头静止帧标定偏置
calib_samples = min(200, len(rows) // 4)
accel_samples, gyro_samples = [], []
qw_s, qx_s, qy_s, qz_s = [], [], [], []
for i in range(calib_samples):
row = rows[i]
accel_samples.append([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])])
gyro_samples.append([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])])
qw_s.append(float(row['qw']))
qx_s.append(float(row['qx']))
qy_s.append(float(row['qy']))
qz_s.append(float(row['qz']))
accel_bias, gyro_bias = Tracker.calibrate_from_samples(
np.array(accel_samples), np.array(gyro_samples),
np.array(qw_s), np.array(qx_s), np.array(qy_s), np.array(qz_s))
print(f"标定 (前{calib_samples}帧): accel_bias=({accel_bias[0]:.4f},{accel_bias[1]:.4f},{accel_bias[2]:.4f})"
f" gyro_bias=({gyro_bias[0]:.4f},{gyro_bias[1]:.4f},{gyro_bias[2]:.4f})")
tracker = Tracker()
tracker.accel_bias = accel_bias
tracker.gyro_bias = gyro_bias
viewer = TrajectoryViewer()
print("开始回放 ...(空格键暂停)")
last_ts = None
last_draw = time.time()
idx = 0
paused = False
def on_key(event):
nonlocal paused
if event.key == ' ':
paused = not paused
print("暂停" if paused else "继续")
viewer.fig.canvas.mpl_connect('key_press_event', on_key)
try:
while plt.fignum_exists(viewer.fig.number) and idx < len(rows):
if not paused:
row = rows[idx]
gyro = np.array([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])])
accel = np.array([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])])
qw, qx, qy, qz = float(row['qw']), float(row['qx']), float(row['qy']), float(row['qz'])
# 使用时间戳计算 dt
ts = int(row['timestamp_ms']) if 'timestamp_ms' in row else None
if ts is not None and last_ts is not None:
dt = (ts - last_ts) / 1000.0
if dt <= 0 or dt > 0.1:
dt = 0.005
else:
dt = 0.005
if ts is not None:
last_ts = ts
tracker.update(gyro, accel, qw, qx, qy, qz, dt)
idx += 1
now = time.time()
if now - last_draw >= 0.033:
viewer.update(tracker.history_array)
plt.pause(0.001)
last_draw = now
else:
plt.pause(0.05)
if idx % 200 == 0:
pos = tracker.position
print(f"[{idx:06d}/{len(rows)}] dt={dt*1000:.1f}ms "
f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
except KeyboardInterrupt:
print("\n回放中断。")
finally:
viewer.close()
print(f"回放完成,共处理 {idx} 帧。")
def main():
parser = argparse.ArgumentParser(description='IMU 惯性三维里程计')
parser.add_argument('port', nargs='?', default=None,
help='串口号 (如 COM5)')
parser.add_argument('baud', nargs='?', type=int, default=115200,
help='波特率 (默认 115200)')
parser.add_argument('--save', metavar='FILE',
help='保存轨迹到 CSV 文件')
parser.add_argument('--replay', metavar='FILE',
help='从 CSV 文件回放轨迹')
args = parser.parse_args()
if args.replay:
run_replay(args.replay)
elif args.port:
run_live(args.port, args.baud, save_csv=args.save)
else:
parser.print_help()
print("\n示例:")
print(" python main_odometry.py COM5")
print(" python main_odometry.py COM5 921600")
print(" python main_odometry.py COM5 --save traj.csv")
print(" python main_odometry.py --replay traj.csv")
if __name__ == '__main__':
main()

56
test_3d_demo.py Normal file
View File

@@ -0,0 +1,56 @@
"""3D 显示验证 — 模拟圆周运动轨迹 (无需串口, 使用四元数)"""
import numpy as np
import time
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation
from trajectory_tracker import Tracker
from visualize_3d import TrajectoryViewer
def main():
tracker = Tracker(zupt_frames=50)
viewer = TrajectoryViewer(title="IMU 3D Demo — 四元数模拟")
dt = 0.005 # 200Hz
t = 0.0
omega = 0.5 # 角速度 (rad/s)
last_draw = time.time()
print("3D 演示运行中... 按 Ctrl+C 停止")
try:
while plt.fignum_exists(viewer.fig.number):
# X-Y 平面圆周运动 + Z 轴正弦起伏
ax_linear = -0.25 * np.cos(omega * t)
ay_linear = -0.25 * np.sin(omega * t)
az_linear = 0.3 * np.sin(2 * t)
# 模拟四元数 (绕 Z 轴旋转 omega*t)
yaw = omega * t
r = Rotation.from_euler('ZYX', [np.degrees(yaw), 0, 0])
q = r.as_quat() # [x, y, z, w]
# body 加速度 = R^T @ (a_linear + gravity)
a_world = np.array([ax_linear, ay_linear, az_linear + 9.81])
accel_body = r.as_matrix().T @ a_world
gyro = np.array([0.0, 0.0, omega])
pos = tracker.update(gyro, accel_body, q[3], q[0], q[1], q[2], dt)
now = time.time()
if now - last_draw >= 0.033:
viewer.update(tracker.history_array)
plt.pause(0.001)
last_draw = now
t += dt
except KeyboardInterrupt:
print("停止。")
finally:
viewer.close()
if __name__ == '__main__':
main()

198
trajectory_tracker.py Normal file
View File

@@ -0,0 +1,198 @@
"""
IMU 惯性三维里程计 — 核心算法模块
管线:
四元数 → 旋转矩阵 (scipy Rotation.from_quat)
a_world = R @ (a_body - accel_bias)
a_linear = a_world - [0, 0, 9.81]
双重积分 (梯形积分 + ZUPT 静止检测 + 加速度死区)
使用 EKF 四元数标定加速度计偏置, 保证 R @ corrected_accel ≈ [0,0,g]。
"""
import numpy as np
from scipy.spatial.transform import Rotation
GRAVITY = np.array([0.0, 0.0, 9.81]) # Z-up 右手系
def quat_to_rotation(qw, qx, qy, qz):
"""四元数 → body→world 旋转矩阵
Args:
qw, qx, qy, qz: STM32 四元数 (scalar-first)
Returns:
R: 3x3 旋转矩阵
"""
return Rotation.from_quat([qx, qy, qz, qw]).as_matrix()
def rotate_accel(accel_body, R):
"""机体加速度 → 世界坐标系"""
return R @ np.asarray(accel_body)
def gravity_compensate(a_world):
"""减去重力向量"""
return a_world - GRAVITY
def apply_deadzone(a, threshold=0.03):
"""幅值小于阈值的分量置零"""
a = np.asarray(a).copy()
a[np.abs(a) < threshold] = 0.0
return a
class Tracker:
"""三维轨迹跟踪器
使用 EKF 四元数进行姿态旋转, 标定加速度计偏置, ZUPT 抑制静止漂移。
"""
def __init__(self, zupt_threshold_accel=0.10, zupt_threshold_gyro=0.03,
zupt_frames=8, deadzone_threshold=0.02,
var_window_size=20, zupt_var_threshold=0.002):
"""
Args:
zupt_threshold_accel: ZUPT ‖a_linear‖ 阈值 (m/s^2)
zupt_threshold_gyro: ZUPT ‖gyro‖ 阈值 (rad/s)
zupt_frames: 连续静止帧数阈值
deadzone_threshold: 加速度分量死区 (m/s^2)
var_window_size: 方差窗口大小 (帧)
zupt_var_threshold: ZUPT 方差阈值 (m^2/s^4)
"""
self.position = np.zeros(3)
self.velocity = np.zeros(3)
self.position_history = [self.position.copy()]
self.zupt_threshold_accel = zupt_threshold_accel
self.zupt_threshold_gyro = zupt_threshold_gyro
self.zupt_frames = zupt_frames
self.deadzone_threshold = deadzone_threshold
self.var_window_size = var_window_size
self._zupt_var_threshold = zupt_var_threshold
# 传感器偏置
self.accel_bias = np.zeros(3)
self.gyro_bias = np.zeros(3)
# ZUPT 状态
self._zupt_counter = 0
self._linear_var_window = []
self._prev_accel_linear = np.zeros(3)
# 当前四元数 (用于外部查询)
self.qw, self.qx, self.qy, self.qz = 1.0, 0.0, 0.0, 0.0
def update(self, gyro, accel, qw, qx, qy, qz, dt):
"""处理一帧 IMU 数据
Args:
gyro: 机体角速度 [gx, gy, gz] rad/s
accel: 机体加速度 [ax, ay, az] m/s^2
qw, qx, qy, qz: EKF 四元数 (scalar-first)
dt: 时间步长 (s), 由时间戳差值计算
"""
self.qw, self.qx, self.qy, self.qz = qw, qx, qy, qz
accel = np.asarray(accel, dtype=float)
gyro = np.asarray(gyro, dtype=float) - self.gyro_bias
# 0. 减去加速度计偏置
accel_corrected = accel - self.accel_bias
# 1. 四元数 → 旋转矩阵
R = quat_to_rotation(qw, qx, qy, qz)
# 2. 机体加速度 → 世界加速度 → 重力补偿
a_world = rotate_accel(accel_corrected, R)
a_linear = gravity_compensate(a_world)
# 3. 加速度死区
a_linear = apply_deadzone(a_linear, self.deadzone_threshold)
# 4. ZUPT 静止检测
gyro_norm = np.linalg.norm(gyro)
linear_magnitude = np.linalg.norm(a_linear)
self._linear_var_window.append(a_linear.copy())
if len(self._linear_var_window) > self.var_window_size:
self._linear_var_window.pop(0)
linear_variance = 0.0
if len(self._linear_var_window) >= self.var_window_size:
linear_variance = np.var(self._linear_var_window, axis=0).mean()
is_static = (
gyro_norm < self.zupt_threshold_gyro
and linear_magnitude < self.zupt_threshold_accel
and linear_variance < self._zupt_var_threshold
)
if is_static:
self._zupt_counter += 1
else:
self._zupt_counter = 0
zupt_active = self._zupt_counter >= self.zupt_frames
# 5. 梯形积分 (使用真实 dt)
if dt > 0:
if zupt_active:
self.velocity[:] = 0.0
self._prev_accel_linear = np.zeros(3)
else:
a_prev = self._prev_accel_linear
self.velocity = self.velocity + (a_prev + a_linear) * dt / 2.0
self._prev_accel_linear = a_linear.copy()
self.position = self.position + self.velocity * dt
self.position_history.append(self.position.copy())
return self.position
@staticmethod
def calibrate_from_samples(accel_samples, gyro_samples, qw_samples, qx_samples, qy_samples, qz_samples):
"""从静止采样数据计算传感器偏置
利用 EKF 四元数计算理论 body 重力: R(q)^T @ [0, 0, 9.81]
偏置 = 实测均值 - 理论均值
Args:
accel_samples: Nx3, 机体加速度 (m/s^2)
gyro_samples: Nx3, 角速度 (rad/s)
qw/qx/qy/qz_samples: N, EKF 四元数分量
Returns:
accel_bias: (3,) 加速度计偏置 (m/s^2)
gyro_bias: (3,) 陀螺仪偏置 (rad/s)
"""
accel_mean = np.mean(accel_samples, axis=0)
gyro_bias = np.mean(gyro_samples, axis=0)
# 用平均四元数计算理论的 body 重力向量
qw_m = np.mean(qw_samples)
qx_m = np.mean(qx_samples)
qy_m = np.mean(qy_samples)
qz_m = np.mean(qz_samples)
R_mean = quat_to_rotation(qw_m, qx_m, qy_m, qz_m)
gravity_body_expected = R_mean.T @ GRAVITY
accel_bias = accel_mean - gravity_body_expected
return accel_bias, gyro_bias
def reset(self):
"""重置轨迹, 保留偏置"""
self.position = np.zeros(3)
self.velocity = np.zeros(3)
self.position_history = [self.position.copy()]
self._zupt_counter = 0
self._prev_accel_linear = np.zeros(3)
self._linear_var_window = []
@property
def history_array(self):
"""返回 Nx3 numpy 数组"""
return np.array(self.position_history)

114
visualize_3d.py Normal file
View File

@@ -0,0 +1,114 @@
"""
IMU 3D 轨迹实时可视化
matplotlib 3D 窗口, 30Hz 刷新, 显示:
- 蓝色轨迹线
- 当前点红点
- 原点坐标系指示
- 等比例坐标轴
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
class TrajectoryViewer:
"""3D 轨迹实时显示窗口"""
def __init__(self, title="IMU 3D Odometry", refresh_interval=33):
"""
Args:
title: 窗口标题
refresh_interval: 刷新间隔 (ms), 默认 33ms ≈ 30Hz
"""
self.refresh_interval = refresh_interval
self.fig = plt.figure(figsize=(8, 7))
self.fig.canvas.manager.set_window_title(title)
self.ax = self.fig.add_subplot(111, projection='3d')
# 轨迹线 (蓝色)
self.traj_line, = self.ax.plot([], [], [], 'b-', linewidth=1.0, label='Trajectory')
# 当前点 (红色)
self.current_point, = self.ax.plot([], [], [], 'ro', markersize=6, label='Current')
# 坐标系指示 (原点处)
axis_len = 0.3
self.origin_axes = [
self.ax.quiver(0, 0, 0, axis_len, 0, 0, color='r', arrow_length_ratio=0.15, label='X'),
self.ax.quiver(0, 0, 0, 0, axis_len, 0, color='g', arrow_length_ratio=0.15, label='Y'),
self.ax.quiver(0, 0, 0, 0, 0, axis_len, color='b', arrow_length_ratio=0.15, label='Z'),
]
self._setup_axes()
# 动画
self.anim = FuncAnimation(self.fig, self._animate, interval=self.refresh_interval,
cache_frame_data=False, blit=False)
def _setup_axes(self):
"""初始化坐标轴"""
self.ax.set_xlabel('X (front)')
self.ax.set_ylabel('Y (left)')
self.ax.set_zlabel('Z (up)')
self.ax.set_title("IMU 3D Trajectory (Z-up)")
self.ax.legend(loc='upper left')
# 初始范围
self.ax.set_xlim([-1, 1])
self.ax.set_ylim([-1, 1])
self.ax.set_zlim([-1, 1])
try:
self.ax.set_box_aspect([1, 1, 1])
except NotImplementedError:
pass
self.ax.grid(True)
def _animate(self, frame):
"""动画帧回调 (不做任何事, 数据由外部 update 驱动)"""
pass # 通过 plt.pause 驱动, FuncAnimation 仅用于保持窗口响应
def update(self, history_array):
"""更新显示的轨迹数据
Args:
history_array: Nx3 numpy array, 位置历史
"""
if len(history_array) < 1:
return
x, y, z = history_array[:, 0], history_array[:, 1], history_array[:, 2]
# 更新轨迹线
self.traj_line.set_data(x, y)
self.traj_line.set_3d_properties(z)
# 更新当前点
self.current_point.set_data([x[-1]], [y[-1]])
self.current_point.set_3d_properties([z[-1]])
# 自适应坐标轴范围
self._auto_scale(x, y, z)
def _auto_scale(self, x, y, z):
"""根据数据自动调整坐标轴范围, 保持等比例"""
all_coords = np.concatenate([x, y, z])
margin = max(np.ptp(all_coords) * 0.2, 0.5)
mid = (all_coords.min() + all_coords.max()) / 2
half = (all_coords.max() - all_coords.min()) / 2 + margin
self.ax.set_xlim([mid - half, mid + half])
self.ax.set_ylim([mid - half, mid + half])
self.ax.set_zlim([mid - half, mid + half])
def show(self):
"""阻塞显示窗口"""
plt.show()
def close(self):
"""关闭窗口"""
plt.close(self.fig)