Files
imu_decode/main_odometry.py
2026-05-26 03:37:33 +08:00

307 lines
9.9 KiB
Python

"""
IMU 惯性三维里程计 — 主入口
用法:
python main_odometry.py COM5 [baud]
python main_odometry.py COM5 --save traj.csv
python main_odometry.py --replay traj.csv
"""
import sys
import time
import struct
import csv
import argparse
import numpy as np
import serial
import matplotlib.pyplot as plt
from imu_decode import HEADER, PAYLOAD_LEN, CRC_LEN, crc16, parse_frame
from trajectory_tracker import Tracker
from visualize_3d import TrajectoryViewer
def read_frame(ser):
    """Read one frame from the serial port; return the parsed dict or None.

    Scans the byte stream one byte at a time for the 0xAA 0x55 header, then
    reads PAYLOAD_LEN payload bytes followed by CRC_LEN CRC bytes.  The
    received CRC (unpacked as a little-endian uint16) is compared with
    ``crc16(HEADER + payload)``.  Frames that fail the check are dropped and
    scanning resumes.

    The previous byte is remembered while scanning. A header preceded by a
    stray 0xAA (stream ``AA AA 55 ...``) is therefore still recognised.

    Args:
        ser: an open serial port.  Its ``read(n)`` returns ``bytes`` and may
            return fewer bytes than requested when the read times out.

    Returns:
        ``parse_frame(payload_bytes)`` for a valid frame.  Returns None when a
        read times out: either no byte is available, or the frame body
        arrives incomplete.
    """
    frame_len = PAYLOAD_LEN + CRC_LEN
    prev = None  # last byte seen while hunting for the header
    while True:
        b = ser.read(1)
        if not b:
            return None  # timeout with no data
        cur = b[0]
        if prev != 0xAA or cur != 0x55:
            # Not a header yet. Remember this byte: it may be the 0xAA that
            # starts the real header.
            prev = cur
            continue
        prev = None
        frame = ser.read(frame_len)
        if len(frame) < frame_len:
            return None  # incomplete frame; these bytes are discarded
        payload_bytes = frame[:PAYLOAD_LEN]
        # '<H' assumes CRC_LEN == 2 (little-endian uint16 CRC).
        crc_recv = struct.unpack('<H', frame[PAYLOAD_LEN:])[0]
        if crc16(HEADER + payload_bytes) != crc_recv:
            continue  # corrupted frame: resume scanning
        return parse_frame(payload_bytes)
def calibrate(ser, num_samples=200, skip_first=10):
    """Static calibration: collect stationary frames and estimate sensor biases.

    Frames are read with ``read_frame``; reads that return None are ignored.
    The first ``skip_first`` valid frames are discarded to let the data
    settle. The next ``num_samples`` frames are collected. Their
    accelerometer, gyroscope and quaternion components are then passed to
    ``Tracker.calibrate_from_samples``.

    Args:
        ser: an open serial port.
        num_samples: number of frames used for calibration.
        skip_first: number of leading valid frames to discard.

    Returns:
        ``(accel_bias, gyro_bias)`` as returned by
        ``Tracker.calibrate_from_samples`` (described as (3,) arrays).

    Note:
        Blocks until enough valid frames arrive; there is no overall timeout.
    """
    accel_samples = []
    gyro_samples = []
    quat_samples = []  # one [qw, qx, qy, qz] row per frame
    print(f"静止标定: 请保持 IMU 静止, 采集 {num_samples} 帧 ...")
    skipped = 0
    while len(accel_samples) < num_samples:
        frame = read_frame(ser)
        if frame is None:
            continue  # timeout or bad frame: keep waiting
        if skipped < skip_first:
            skipped += 1
            continue
        # Unpacking also checks each field has the expected length.
        gx, gy, gz = frame['gyro']
        ax, ay, az = frame['accel']
        qw, qx, qy, qz = frame['quat']
        accel_samples.append([ax, ay, az])
        gyro_samples.append([gx, gy, gz])
        quat_samples.append([qw, qx, qy, qz])
        # Report progress in collected samples (not frames read), so the
        # final report reaches num_samples/num_samples.
        n = len(accel_samples)
        if n % 50 == 0 or n == num_samples:
            print(f" 标定进度: {n}/{num_samples}")
    accel_arr = np.array(accel_samples)
    gyro_arr = np.array(gyro_samples)
    # reshape keeps an (0, 4) shape when nothing was collected, so the column
    # views below are always 1-D.
    quat_arr = np.array(quat_samples).reshape(-1, 4)
    qw_arr, qx_arr, qy_arr, qz_arr = quat_arr.T
    return Tracker.calibrate_from_samples(accel_arr, gyro_arr, qw_arr, qx_arr, qy_arr, qz_arr)
def run_live(port, baud, save_csv=None):
    """Live mode: read IMU frames from a serial port and show the 3-D trajectory.

    Opens the port and runs a static calibration (``calibrate``). Every
    decoded frame is then fed to ``Tracker.update``, and the viewer is
    redrawn at about 30 Hz. This continues until the figure window is closed
    or Ctrl+C is pressed.

    Args:
        port: serial port name, e.g. ``"COM5"``.
        baud: baud rate.
        save_csv: optional output path.  When given, each processed frame
            (timestamp, raw gyro/accel, quaternion and estimated position) is
            written as one CSV row.

    The serial port, the CSV file and the viewer are released on every exit
    path. This includes failing to open the port and Ctrl+C during
    calibration.
    """
    default_dt = 0.005  # fallback step (200 Hz nominal rate)
    max_dt = 0.1        # larger gaps are treated as glitches
    tracker = Tracker()
    viewer = TrajectoryViewer()
    csv_file = None
    csv_writer = None
    ser = None
    frame_count = 0
    try:
        if save_csv:
            csv_file = open(save_csv, 'w', newline='')
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['timestamp_ms',
                                 'gyro_x', 'gyro_y', 'gyro_z',
                                 'accel_x', 'accel_y', 'accel_z',
                                 'qw', 'qx', 'qy', 'qz',
                                 'pos_x', 'pos_y', 'pos_z'])
        print(f"打开串口 {port} @ {baud} baud ...")
        ser = serial.Serial(port, baud, timeout=0.01)  # short timeout so read(1) does not block
        # Static calibration. It runs inside the try so that Ctrl+C here
        # still closes everything.
        accel_bias, gyro_bias = calibrate(ser)
        tracker.accel_bias = accel_bias
        tracker.gyro_bias = gyro_bias
        print(f"标定完成: accel_bias=({accel_bias[0]:.4f}, {accel_bias[1]:.4f}, {accel_bias[2]:.4f}) m/s^2"
              f" gyro_bias=({gyro_bias[0]:.4f}, {gyro_bias[1]:.4f}, {gyro_bias[2]:.4f}) rad/s")
        print("开始跟踪 ...\n按 Ctrl+C 停止")
        last_ts = None
        last_draw = time.time()
        while plt.fignum_exists(viewer.fig.number):
            frame = read_frame(ser)
            if frame is None:
                time.sleep(0.001)  # no data: brief sleep to avoid busy-waiting
                continue
            ts = frame['timestamp_ms']
            gx, gy, gz = frame['gyro']
            ax, ay, az = frame['accel']
            qw, qx, qy, qz = frame['quat']
            gyro = np.array([gx, gy, gz])
            accel = np.array([ax, ay, az])
            # dt comes from device timestamps, with the ms counter treated
            # as uint32. The modulo turns a counter wrap-around into the true
            # small step. A backwards jump becomes huge and, like a zero or
            # over-max_dt gap, falls back to the nominal step.
            dt = default_dt
            if last_ts is not None:
                step = ((ts - last_ts) % 2**32) / 1000.0
                if 0 < step <= max_dt:
                    dt = step
            last_ts = ts
            pos = tracker.update(gyro, accel, qw, qx, qy, qz, dt)
            if csv_writer is not None:
                csv_writer.writerow([ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz,
                                     pos[0], pos[1], pos[2]])
            # Refresh the display at ~30 Hz.
            now = time.time()
            if now - last_draw >= 0.033:
                viewer.update(tracker.history_array)
                plt.pause(0.001)
                last_draw = now
            frame_count += 1
            if frame_count % 200 == 0:
                print(f"[{frame_count:06d}] ts={ts} dt={dt*1000:.1f}ms "
                      f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
    except KeyboardInterrupt:
        print(f"\n停止。共接收 {frame_count} 帧。")
    finally:
        if ser is not None:
            ser.close()
        if csv_file is not None:
            csv_file.close()
            print(f"轨迹已保存至 {save_csv}")
        viewer.close()
def run_replay(csv_path):
    """Replay mode: load a CSV recorded by ``run_live`` and re-run tracking.

    Biases are estimated from the first rows of the file: up to 200 rows, at
    most a quarter of the file, and at least one. Every row is then fed to
    ``Tracker.update``, and the trajectory is redrawn at about 30 Hz. The
    space bar toggles pause; closing the window or pressing Ctrl+C stops the
    replay.

    Args:
        csv_path: path to a CSV with the columns written by ``run_live``:
            ``gyro_*``, ``accel_*``, ``qw``..``qz`` and, optionally,
            ``timestamp_ms``. The stored ``pos_*`` columns are not used.
    """
    default_dt = 0.005  # fallback step (200 Hz nominal rate)
    max_dt = 0.1        # larger gaps are treated as glitches

    def vec3(row, prefix):
        # Read one sensor's x/y/z columns as floats.
        return [float(row[f'{prefix}_x']), float(row[f'{prefix}_y']),
                float(row[f'{prefix}_z'])]

    def quat(row):
        # Read the quaternion columns as (qw, qx, qy, qz) floats.
        return (float(row['qw']), float(row['qx']),
                float(row['qy']), float(row['qz']))

    # The csv docs ask for newline='' so embedded newlines/quotes parse correctly.
    with open(csv_path, 'r', newline='') as f:
        rows = list(csv.DictReader(f))
    print(f"加载 {len(rows)} 帧数据")
    if not rows:
        return  # nothing to calibrate or replay

    # Estimate biases from the leading frames, which are assumed to be static.
    # NOTE(review): run_live only writes rows after its own calibration has
    # finished, so the first rows of a recorded file may already contain
    # motion. Confirm the device was still at the start of recording.
    calib_samples = max(1, min(200, len(rows) // 4))
    calib_rows = rows[:calib_samples]
    quat_arr = np.array([quat(r) for r in calib_rows])
    accel_bias, gyro_bias = Tracker.calibrate_from_samples(
        np.array([vec3(r, 'accel') for r in calib_rows]),
        np.array([vec3(r, 'gyro') for r in calib_rows]),
        quat_arr[:, 0], quat_arr[:, 1], quat_arr[:, 2], quat_arr[:, 3])
    print(f"标定 (前{calib_samples}帧): accel_bias=({accel_bias[0]:.4f},{accel_bias[1]:.4f},{accel_bias[2]:.4f})"
          f" gyro_bias=({gyro_bias[0]:.4f},{gyro_bias[1]:.4f},{gyro_bias[2]:.4f})")
    tracker = Tracker()
    tracker.accel_bias = accel_bias
    tracker.gyro_bias = gyro_bias
    viewer = TrajectoryViewer()
    print("开始回放 ...(空格键暂停)")
    last_ts = None
    last_draw = time.time()
    idx = 0
    paused = False

    def on_key(event):
        # The space bar toggles pause.
        nonlocal paused
        if event.key == ' ':
            paused = not paused
            print("暂停" if paused else "继续")

    viewer.fig.canvas.mpl_connect('key_press_event', on_key)
    try:
        while plt.fignum_exists(viewer.fig.number) and idx < len(rows):
            if paused:
                plt.pause(0.05)  # keep the GUI responsive while paused
                continue
            row = rows[idx]
            gyro = np.array(vec3(row, 'gyro'))
            accel = np.array(vec3(row, 'accel'))
            qw, qx, qy, qz = quat(row)
            # dt comes from recorded timestamps, using the same rules as live
            # mode. The modulo handles a uint32 wrap-around of the ms
            # counter. Zero, backwards or over-max_dt gaps fall back to the
            # nominal step.
            ts = int(row['timestamp_ms']) if 'timestamp_ms' in row else None
            dt = default_dt
            if ts is not None and last_ts is not None:
                step = ((ts - last_ts) % 2**32) / 1000.0
                if 0 < step <= max_dt:
                    dt = step
            if ts is not None:
                last_ts = ts
            tracker.update(gyro, accel, qw, qx, qy, qz, dt)
            idx += 1
            now = time.time()
            if now - last_draw >= 0.033:
                viewer.update(tracker.history_array)
                plt.pause(0.001)
                last_draw = now
            # Report only on the processing path, so a pause never repeats it.
            if idx % 200 == 0:
                pos = tracker.position
                print(f"[{idx:06d}/{len(rows)}] dt={dt*1000:.1f}ms "
                      f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
    except KeyboardInterrupt:
        print("\n回放中断。")
    finally:
        viewer.close()
    print(f"回放完成,共处理 {idx} 帧。")
def main():
    """Command-line entry point for the IMU odometry tool.

    ``--replay FILE`` takes precedence and replays a recorded CSV. Otherwise,
    a positional serial port starts live tracking, optionally with
    ``--save FILE``. With neither, the argparse help is printed, followed by
    example invocations.
    """
    parser = argparse.ArgumentParser(description='IMU 惯性三维里程计')
    parser.add_argument('port', nargs='?', default=None,
                        help='串口号 (如 COM5)')
    parser.add_argument('baud', nargs='?', type=int, default=115200,
                        help='波特率 (默认 115200)')
    parser.add_argument('--save', metavar='FILE',
                        help='保存轨迹到 CSV 文件')
    parser.add_argument('--replay', metavar='FILE',
                        help='从 CSV 文件回放轨迹')
    args = parser.parse_args()

    if args.replay:
        run_replay(args.replay)
        return
    if args.port:
        run_live(args.port, args.baud, save_csv=args.save)
        return

    # No mode selected: show the help text plus concrete invocations.
    parser.print_help()
    print("\n示例:")
    examples = (
        " python main_odometry.py COM5",
        " python main_odometry.py COM5 921600",
        " python main_odometry.py COM5 --save traj.csv",
        " python main_odometry.py --replay traj.csv",
    )
    for example in examples:
        print(example)


if __name__ == '__main__':
    main()