diff --git a/__pycache__/main_odometry.cpython-310.pyc b/__pycache__/main_odometry.cpython-310.pyc index c3bd3ff..771f18b 100644 Binary files a/__pycache__/main_odometry.cpython-310.pyc and b/__pycache__/main_odometry.cpython-310.pyc differ diff --git a/__pycache__/trajectory_tracker.cpython-310.pyc b/__pycache__/trajectory_tracker.cpython-310.pyc index db840aa..9f5abed 100644 Binary files a/__pycache__/trajectory_tracker.cpython-310.pyc and b/__pycache__/trajectory_tracker.cpython-310.pyc differ diff --git a/__pycache__/visualize_3d.cpython-310.pyc b/__pycache__/visualize_3d.cpython-310.pyc index f34ddc2..ea345d4 100644 Binary files a/__pycache__/visualize_3d.cpython-310.pyc and b/__pycache__/visualize_3d.cpython-310.pyc differ diff --git a/main_odometry.py b/main_odometry.py index 1347a4f..86bbcb7 100644 --- a/main_odometry.py +++ b/main_odometry.py @@ -166,7 +166,7 @@ def run_live(port, baud, save_csv=None): # 30Hz 刷新显示 now = time.time() if now - last_draw >= 0.033: - viewer.update(tracker.history_array) + viewer.update(tracker.history_array, tracker.R) plt.pause(0.001) last_draw = now @@ -258,7 +258,7 @@ def run_replay(csv_path): now = time.time() if now - last_draw >= 0.033: - viewer.update(tracker.history_array) + viewer.update(tracker.history_array, tracker.R) plt.pause(0.001) last_draw = now else: diff --git a/trajectory_tracker.py b/trajectory_tracker.py index 8869255..5905429 100644 --- a/trajectory_tracker.py +++ b/trajectory_tracker.py @@ -66,6 +66,7 @@ class Tracker: self.velocity = np.zeros(3) self.position_history = [self.position.copy()] + self._max_history = 800 self.zupt_threshold_accel = zupt_threshold_accel self.zupt_threshold_gyro = zupt_threshold_gyro @@ -83,8 +84,9 @@ class Tracker: self._linear_var_window = [] self._prev_accel_linear = np.zeros(3) - # 当前四元数 (用于外部查询) + # 当前四元数和旋转矩阵 (用于外部查询) self.qw, self.qx, self.qy, self.qz = 1.0, 0.0, 0.0, 0.0 + self.R = np.eye(3) def update(self, gyro, accel, qw, qx, qy, qz, dt): """处理一帧 IMU 数据 @@ -104,7 +106,7 @@ class Tracker: accel_corrected = accel - self.accel_bias # 1. 四元数 → 旋转矩阵 - R = quat_to_rotation(qw, qx, qy, qz) + self.R = R = quat_to_rotation(qw, qx, qy, qz) # 2. 机体加速度 → 世界加速度 → 重力补偿 a_world = rotate_accel(accel_corrected, R) @@ -113,9 +115,10 @@ class Tracker: # 3. 加速度死区 a_linear = apply_deadzone(a_linear, self.deadzone_threshold) - # 4. ZUPT 静止检测 + # 4. ZUPT 静止/纯旋转检测 gyro_norm = np.linalg.norm(gyro) linear_magnitude = np.linalg.norm(a_linear) + accel_mag = np.linalg.norm(accel_corrected) self._linear_var_window.append(a_linear.copy()) if len(self._linear_var_window) > self.var_window_size: @@ -125,13 +128,22 @@ class Tracker: if len(self._linear_var_window) >= self.var_window_size: linear_variance = np.var(self._linear_var_window, axis=0).mean() + # 静止: gyro 低 + a_linear 小 + 方差低 is_static = ( gyro_norm < self.zupt_threshold_gyro and linear_magnitude < self.zupt_threshold_accel and linear_variance < self._zupt_var_threshold ) - if is_static: + # 纯旋转: a_linear 小 + 方差低 + body accel 幅值 ≈ 纯重力 (无平移) + is_rotation_only = ( + not is_static + and linear_magnitude < self.zupt_threshold_accel + and linear_variance < self._zupt_var_threshold + and abs(accel_mag - 9.81) < 0.3 + ) + + if is_static or is_rotation_only: self._zupt_counter += 1 else: self._zupt_counter = 0 @@ -150,6 +162,8 @@ class Tracker: self.position = self.position + self.velocity * dt self.position_history.append(self.position.copy()) + if len(self.position_history) > self._max_history: + self.position_history = self.position_history[-self._max_history:] return self.position @staticmethod diff --git a/visualize_3d.py b/visualize_3d.py index 868072e..4258c0d 100644 --- a/visualize_3d.py +++ b/visualize_3d.py @@ -2,9 +2,9 @@ IMU 3D 轨迹实时可视化 matplotlib 3D 窗口, 30Hz 刷新, 显示: - - 蓝色轨迹线 (自动降采样, 最多 800 点) - - 当前点红点 - - 原点坐标系指示 + - 蓝色轨迹线 (最近 800 点, Tracker 端已限长) + - 当前点红点 + 朝向指示线 (RGB = body XYZ) + - 原点世界坐标系指示 - 等比例坐标轴 """ @@ -15,15 +15,7 @@ import matplotlib.pyplot as plt class TrajectoryViewer: """3D 轨迹实时显示窗口""" - def __init__(self, title="IMU 3D Odometry", max_trail_points=800): - """ - Args: - title: 窗口标题 - max_trail_points: 轨迹线最多显示点数 (降采样, 避免渲染卡顿) - """ - self.max_trail_points = max_trail_points - self._full_count = 0 - + def __init__(self, title="IMU 3D Odometry"): self.fig = plt.figure(figsize=(8, 7)) self.fig.canvas.manager.set_window_title(title) self.ax = self.fig.add_subplot(111, projection='3d') @@ -34,27 +26,28 @@ class TrajectoryViewer: # 当前点 (红色) self.current_point, = self.ax.plot([], [], [], 'ro', markersize=6, label='Current') - # 坐标系指示 (原点处) + # 朝向指示线 (body XYZ, 用 plot 可原地更新) + self.body_x_line, = self.ax.plot([], [], [], 'r-', linewidth=2.0, alpha=0.9) + self.body_y_line, = self.ax.plot([], [], [], 'g-', linewidth=2.0, alpha=0.9) + self.body_z_line, = self.ax.plot([], [], [], 'b-', linewidth=2.0, alpha=0.9) + + # 世界坐标系指示 (固定在原点) axis_len = 0.3 - self.origin_axes = [ - self.ax.quiver(0, 0, 0, axis_len, 0, 0, color='r', arrow_length_ratio=0.15, label='X'), - self.ax.quiver(0, 0, 0, 0, axis_len, 0, color='g', arrow_length_ratio=0.15, label='Y'), - self.ax.quiver(0, 0, 0, 0, 0, axis_len, color='b', arrow_length_ratio=0.15, label='Z'), - ] + self.ax.quiver(0, 0, 0, axis_len, 0, 0, color='r', arrow_length_ratio=0.15) + self.ax.quiver(0, 0, 0, 0, axis_len, 0, color='g', arrow_length_ratio=0.15) + self.ax.quiver(0, 0, 0, 0, 0, axis_len, color='b', arrow_length_ratio=0.15) self._setup_axes() + self._frame = 0 - # 显示非阻塞窗口 plt.show(block=False) plt.pause(0.1) def _setup_axes(self): - """初始化坐标轴""" self.ax.set_xlabel('X (front)') self.ax.set_ylabel('Y (left)') self.ax.set_zlabel('Z (up)') self.ax.set_title("IMU 3D Trajectory (Z-up)") - self.ax.legend(loc='upper left') self.ax.set_xlim([-1, 1]) self.ax.set_ylim([-1, 1]) @@ -67,25 +60,18 @@ class TrajectoryViewer: self.ax.grid(True) - def update(self, history_array): - """更新显示的轨迹数据 (自动降采样) + def update(self, history_array, rotation_matrix=None): + """更新显示 Args: - history_array: Nx3 numpy array, 位置历史 + history_array: Nx3 位置历史 (已由 Tracker 限长为 800) + rotation_matrix: 3x3 body→world 旋转矩阵 (None 则不显示朝向) """ n = len(history_array) if n < 1: return - # 降采样: 超过 max_trail_points 时取等间隔子集 - if n > self.max_trail_points: - step = n // self.max_trail_points - indices = np.arange(0, n, step) - sampled = history_array[indices] - else: - sampled = history_array - - x, y, z = sampled[:, 0], sampled[:, 1], sampled[:, 2] + x, y, z = history_array[:, 0], history_array[:, 1], history_array[:, 2] self.traj_line.set_data(x, y) self.traj_line.set_3d_properties(z) @@ -95,13 +81,40 @@ class TrajectoryViewer: self.current_point.set_data([last[0]], [last[1]]) self.current_point.set_3d_properties([last[2]]) - # 自适应坐标轴 (每 20 次完整更新做一次, 减少开销) - self._full_count += 1 - if self._full_count % 5 == 0: + # 朝向 (原地更新, 不重建对象) + if rotation_matrix is not None: + self._update_orientation(last, rotation_matrix) + else: + self.body_x_line.set_data([], []) + self.body_x_line.set_3d_properties([]) + self.body_y_line.set_data([], []) + self.body_y_line.set_3d_properties([]) + self.body_z_line.set_data([], []) + self.body_z_line.set_3d_properties([]) + + # 坐标轴范围 (每 5 帧更新一次) + self._frame += 1 + if self._frame % 8 == 0: self._auto_scale(x, y, z) + def _update_orientation(self, position, R): + """更新 body 系朝向指示线 (不重建对象) + + Args: + position: (3,) 当前位置 + R: 3x3 body→world 旋转矩阵 + """ + px, py, pz = position + length = 0.15 + + for line, col in [(self.body_x_line, 0), + (self.body_y_line, 1), + (self.body_z_line, 2)]: + d = R[:, col] * length + line.set_data([px, px + d[0]], [py, py + d[1]]) + line.set_3d_properties([pz, pz + d[2]]) + def _auto_scale(self, x, y, z): - """自适应等比例坐标轴""" all_coords = np.concatenate([x, y, z]) margin = max(np.ptp(all_coords) * 0.2, 0.5) mid = (all_coords.min() + all_coords.max()) / 2 @@ -112,5 +125,4 @@ class TrajectoryViewer: self.ax.set_zlim([mid - half, mid + half]) def close(self): - """关闭窗口""" plt.close(self.fig)