Files
imu_decode/trajectory_tracker.py
2026-05-26 03:37:33 +08:00

199 lines
6.4 KiB
Python

"""
IMU 惯性三维里程计 — 核心算法模块
管线:
四元数 → 旋转矩阵 (scipy Rotation.from_quat)
a_world = R @ (a_body - accel_bias)
a_linear = a_world - [0, 0, 9.81]
双重积分 (梯形积分 + ZUPT 静止检测 + 加速度死区)
使用 EKF 四元数标定加速度计偏置, 保证 R @ corrected_accel ≈ [0,0,g]。
"""
import numpy as np
from scipy.spatial.transform import Rotation
GRAVITY = np.array([0.0, 0.0, 9.81]) # Z-up 右手系
def quat_to_rotation(qw, qx, qy, qz):
"""四元数 → body→world 旋转矩阵
Args:
qw, qx, qy, qz: STM32 四元数 (scalar-first)
Returns:
R: 3x3 旋转矩阵
"""
return Rotation.from_quat([qx, qy, qz, qw]).as_matrix()
def rotate_accel(accel_body, R):
"""机体加速度 → 世界坐标系"""
return R @ np.asarray(accel_body)
def gravity_compensate(a_world):
"""减去重力向量"""
return a_world - GRAVITY
def apply_deadzone(a, threshold=0.03):
"""幅值小于阈值的分量置零"""
a = np.asarray(a).copy()
a[np.abs(a) < threshold] = 0.0
return a
class Tracker:
"""三维轨迹跟踪器
使用 EKF 四元数进行姿态旋转, 标定加速度计偏置, ZUPT 抑制静止漂移。
"""
def __init__(self, zupt_threshold_accel=0.20, zupt_threshold_gyro=0.05,
zupt_frames=15, deadzone_threshold=0.03,
var_window_size=30, zupt_var_threshold=0.005):
"""
Args:
zupt_threshold_accel: ZUPT ‖a_linear‖ 阈值 (m/s^2)
zupt_threshold_gyro: ZUPT ‖gyro‖ 阈值 (rad/s)
zupt_frames: 连续静止帧数阈值
deadzone_threshold: 加速度分量死区 (m/s^2)
var_window_size: 方差窗口大小 (帧)
zupt_var_threshold: ZUPT 方差阈值 (m^2/s^4)
"""
self.position = np.zeros(3)
self.velocity = np.zeros(3)
self.position_history = [self.position.copy()]
self.zupt_threshold_accel = zupt_threshold_accel
self.zupt_threshold_gyro = zupt_threshold_gyro
self.zupt_frames = zupt_frames
self.deadzone_threshold = deadzone_threshold
self.var_window_size = var_window_size
self._zupt_var_threshold = zupt_var_threshold
# 传感器偏置
self.accel_bias = np.zeros(3)
self.gyro_bias = np.zeros(3)
# ZUPT 状态
self._zupt_counter = 0
self._linear_var_window = []
self._prev_accel_linear = np.zeros(3)
# 当前四元数 (用于外部查询)
self.qw, self.qx, self.qy, self.qz = 1.0, 0.0, 0.0, 0.0
def update(self, gyro, accel, qw, qx, qy, qz, dt):
"""处理一帧 IMU 数据
Args:
gyro: 机体角速度 [gx, gy, gz] rad/s
accel: 机体加速度 [ax, ay, az] m/s^2
qw, qx, qy, qz: EKF 四元数 (scalar-first)
dt: 时间步长 (s), 由时间戳差值计算
"""
self.qw, self.qx, self.qy, self.qz = qw, qx, qy, qz
accel = np.asarray(accel, dtype=float)
gyro = np.asarray(gyro, dtype=float) - self.gyro_bias
# 0. 减去加速度计偏置
accel_corrected = accel - self.accel_bias
# 1. 四元数 → 旋转矩阵
R = quat_to_rotation(qw, qx, qy, qz)
# 2. 机体加速度 → 世界加速度 → 重力补偿
a_world = rotate_accel(accel_corrected, R)
a_linear = gravity_compensate(a_world)
# 3. 加速度死区
a_linear = apply_deadzone(a_linear, self.deadzone_threshold)
# 4. ZUPT 静止检测
gyro_norm = np.linalg.norm(gyro)
linear_magnitude = np.linalg.norm(a_linear)
self._linear_var_window.append(a_linear.copy())
if len(self._linear_var_window) > self.var_window_size:
self._linear_var_window.pop(0)
linear_variance = 0.0
if len(self._linear_var_window) >= self.var_window_size:
linear_variance = np.var(self._linear_var_window, axis=0).mean()
is_static = (
gyro_norm < self.zupt_threshold_gyro
and linear_magnitude < self.zupt_threshold_accel
and linear_variance < self._zupt_var_threshold
)
if is_static:
self._zupt_counter += 1
else:
self._zupt_counter = 0
zupt_active = self._zupt_counter >= self.zupt_frames
# 5. 梯形积分 (使用真实 dt)
if dt > 0:
if zupt_active:
self.velocity[:] = 0.0
self._prev_accel_linear = np.zeros(3)
else:
a_prev = self._prev_accel_linear
self.velocity = self.velocity + (a_prev + a_linear) * dt / 2.0
self._prev_accel_linear = a_linear.copy()
self.position = self.position + self.velocity * dt
self.position_history.append(self.position.copy())
return self.position
@staticmethod
def calibrate_from_samples(accel_samples, gyro_samples, qw_samples, qx_samples, qy_samples, qz_samples):
"""从静止采样数据计算传感器偏置
利用 EKF 四元数计算理论 body 重力: R(q)^T @ [0, 0, 9.81]
偏置 = 实测均值 - 理论均值
Args:
accel_samples: Nx3, 机体加速度 (m/s^2)
gyro_samples: Nx3, 角速度 (rad/s)
qw/qx/qy/qz_samples: N, EKF 四元数分量
Returns:
accel_bias: (3,) 加速度计偏置 (m/s^2)
gyro_bias: (3,) 陀螺仪偏置 (rad/s)
"""
accel_mean = np.mean(accel_samples, axis=0)
gyro_bias = np.mean(gyro_samples, axis=0)
# 用平均四元数计算理论的 body 重力向量
qw_m = np.mean(qw_samples)
qx_m = np.mean(qx_samples)
qy_m = np.mean(qy_samples)
qz_m = np.mean(qz_samples)
R_mean = quat_to_rotation(qw_m, qx_m, qy_m, qz_m)
gravity_body_expected = R_mean.T @ GRAVITY
accel_bias = accel_mean - gravity_body_expected
return accel_bias, gyro_bias
def reset(self):
"""重置轨迹, 保留偏置"""
self.position = np.zeros(3)
self.velocity = np.zeros(3)
self.position_history = [self.position.copy()]
self._zupt_counter = 0
self._prev_accel_linear = np.zeros(3)
self._linear_var_window = []
@property
def history_array(self):
"""返回 Nx3 numpy 数组"""
return np.array(self.position_history)