307 lines
9.9 KiB
Python
307 lines
9.9 KiB
Python
"""
|
|
IMU 惯性三维里程计 — 主入口
|
|
|
|
用法:
|
|
python main_odometry.py COM5 [baud]
|
|
python main_odometry.py COM5 --save traj.csv
|
|
python main_odometry.py --replay traj.csv
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import struct
|
|
import csv
|
|
import argparse
|
|
|
|
import numpy as np
|
|
import serial
|
|
import matplotlib.pyplot as plt
|
|
|
|
from imu_decode import HEADER, PAYLOAD_LEN, CRC_LEN, crc16, parse_frame
|
|
from trajectory_tracker import Tracker
|
|
from visualize_3d import TrajectoryViewer
|
|
|
|
|
|
def read_frame(ser):
    """Read one frame from the serial port; return the parsed dict or None.

    Scans byte by byte for the 0xAA 0x55 frame header, then reads the
    fixed-size payload plus CRC, verifies the CRC over header + payload and
    hands the payload to ``parse_frame``. Frames whose CRC does not match are
    dropped and the header scan continues.

    Args:
        ser: an open serial port (anything with ``read(n) -> bytes``).

    Returns:
        ``parse_frame(payload)`` for the first frame with a valid CRC, or None
        when a read returns no byte or the frame body arrives incomplete
        (i.e. the port's read timeout expired).
    """
    body_len = PAYLOAD_LEN + CRC_LEN
    # NOTE(review): the byte scan hard-codes 0xAA 0x55 while the CRC uses
    # HEADER; presumably HEADER == b'\xaa\x55' — confirm in imu_decode.
    prev_was_aa = False  # True when the previous byte was 0xAA (header byte 1)

    while True:
        b = ser.read(1)
        if not b:
            return None
        byte = b[0]

        if not (prev_was_aa and byte == 0x55):
            # No complete header yet. A 0xAA here may itself be the start of
            # the header (stream "AA AA 55 ..."), so remember it rather than
            # discarding it.
            prev_was_aa = byte == 0xAA
            continue
        prev_was_aa = False

        frame = ser.read(body_len)
        if len(frame) < body_len:
            return None

        payload_bytes = frame[:PAYLOAD_LEN]
        # '<H' assumes a 2-byte little-endian CRC (CRC_LEN == 2).
        crc_recv = struct.unpack('<H', frame[PAYLOAD_LEN:])[0]
        # The CRC covers the header bytes as well as the payload.
        if crc16(HEADER + payload_bytes) != crc_recv:
            continue

        return parse_frame(payload_bytes)
|
|
|
|
|
|
def calibrate(ser, num_samples=200, skip_first=10):
    """Stationary calibration: collect N frames from a motionless IMU and
    estimate the sensor biases from them.

    Frames come from ``read_frame``. A None result (read timeout) is simply
    retried, so this call blocks until enough valid frames have arrived.

    Args:
        ser: an open serial port.
        num_samples: number of frames used for calibration (must be > 0).
        skip_first: number of leading valid frames discarded while the data
            settles.

    Returns:
        The result of ``Tracker.calibrate_from_samples``. Callers unpack it
        as ``(accel_bias, gyro_bias)`` (presumably two (3,) arrays).

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    accel_samples = []
    gyro_samples = []
    quat_samples = []

    print(f"静止标定: 请保持 IMU 静止, 采集 {num_samples} 帧 ...")

    # NOTE(review): there is no overall deadline; if the device never sends
    # valid frames this loops forever.
    skipped = 0
    while len(accel_samples) < num_samples:
        frame = read_frame(ser)
        if frame is None:
            continue

        if skipped < skip_first:
            skipped += 1
            continue

        # Unpacking enforces exactly 3 / 3 / 4 components per frame.
        gx, gy, gz = frame['gyro']
        ax, ay, az = frame['accel']
        qw, qx, qy, qz = frame['quat']
        accel_samples.append([ax, ay, az])
        gyro_samples.append([gx, gy, gz])
        quat_samples.append((qw, qx, qy, qz))

        # Report progress in calibration samples (skipped frames excluded),
        # and always report the final count.
        count = len(accel_samples)
        if count % 50 == 0 or count == num_samples:
            print(f" 标定进度: {count}/{num_samples}")

    accel_arr = np.array(accel_samples)
    gyro_arr = np.array(gyro_samples)
    # Split the (qw, qx, qy, qz) tuples into four contiguous 1-D arrays.
    qw_arr, qx_arr, qy_arr, qz_arr = (np.array(col) for col in zip(*quat_samples))

    return Tracker.calibrate_from_samples(accel_arr, gyro_arr, qw_arr, qx_arr, qy_arr, qz_arr)
|
|
|
|
|
|
def run_live(port, baud, save_csv=None):
    """Live mode: read IMU frames from a serial port and plot the 3-D trajectory.

    Runs a stationary bias calibration first. Tracking then continues until
    the plot window is closed or Ctrl+C is pressed.

    Args:
        port: serial port name passed to ``serial.Serial`` (e.g. "COM5").
        baud: baud rate.
        save_csv: optional output path. When given, each frame's timestamp,
            gyro, accel, quaternion and estimated position are written as one
            CSV row (the columns ``run_replay`` reads back).

    The serial port, CSV file and viewer are released on every exit path,
    including a failed port open and Ctrl+C during calibration.
    """
    tracker = Tracker()
    viewer = TrajectoryViewer()

    ser = None
    csv_file = None
    csv_writer = None
    frame_count = 0

    try:
        if save_csv:
            csv_file = open(save_csv, 'w', newline='')
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['timestamp_ms',
                                 'gyro_x', 'gyro_y', 'gyro_z',
                                 'accel_x', 'accel_y', 'accel_z',
                                 'qw', 'qx', 'qy', 'qz',
                                 'pos_x', 'pos_y', 'pos_z'])

        print(f"打开串口 {port} @ {baud} baud ...")
        # Short timeout so read(1) never blocks for long.
        ser = serial.Serial(port, baud, timeout=0.01)

        # Stationary calibration.
        accel_bias, gyro_bias = calibrate(ser)
        tracker.accel_bias = accel_bias
        tracker.gyro_bias = gyro_bias
        print(f"标定完成: accel_bias=({accel_bias[0]:.4f}, {accel_bias[1]:.4f}, {accel_bias[2]:.4f}) m/s^2"
              f" gyro_bias=({gyro_bias[0]:.4f}, {gyro_bias[1]:.4f}, {gyro_bias[2]:.4f}) rad/s")
        print("开始跟踪 ...\n按 Ctrl+C 停止")

        last_ts = None
        last_draw = time.time()

        while plt.fignum_exists(viewer.fig.number):
            frame = read_frame(ser)
            if frame is None:
                time.sleep(0.001)  # no data: back off briefly instead of spinning
                continue

            ts = frame['timestamp_ms']
            gx, gy, gz = frame['gyro']
            ax, ay, az = frame['accel']
            qw, qx, qy, qz = frame['quat']

            gyro = np.array([gx, gy, gz])
            accel = np.array([ax, ay, az])

            # dt from device timestamps (ms). Non-positive steps (e.g. counter
            # wrap-around) and gaps over 100 ms fall back to the nominal
            # 200 Hz period.
            if last_ts is not None:
                dt = (ts - last_ts) / 1000.0
                if dt <= 0 or dt > 0.1:
                    dt = 0.005
            else:
                dt = 0.005
            last_ts = ts

            pos = tracker.update(gyro, accel, qw, qx, qy, qz, dt)

            if csv_writer is not None:
                csv_writer.writerow([ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz,
                                     pos[0], pos[1], pos[2]])

            # Redraw at most ~30 Hz so plotting does not starve serial reads.
            now = time.time()
            if now - last_draw >= 0.033:
                viewer.update(tracker.history_array)
                plt.pause(0.001)
                last_draw = now

            frame_count += 1
            if frame_count % 200 == 0:
                print(f"[{frame_count:06d}] ts={ts} dt={dt*1000:.1f}ms "
                      f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")

    except KeyboardInterrupt:
        print(f"\n停止。共接收 {frame_count} 帧。")
    finally:
        if ser is not None:
            ser.close()
        if csv_file is not None:
            csv_file.close()
            print(f"轨迹已保存至 {save_csv}")
        viewer.close()
|
|
|
|
|
|
def _parse_replay_row(row, has_ts):
    """Convert one ``csv.DictReader`` row (columns as written by ``run_live``).

    Args:
        row: mapping of column name to string value.
        has_ts: whether the file has a ``timestamp_ms`` column.

    Returns:
        ``(ts, gyro, accel, quat)``: ``ts`` is an int, or None when ``has_ts``
        is false. ``gyro`` and ``accel`` are length-3 float arrays. ``quat``
        is the ``(qw, qx, qy, qz)`` tuple of floats.
    """
    ts = int(row['timestamp_ms']) if has_ts else None
    gyro = np.array([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])])
    accel = np.array([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])])
    quat = (float(row['qw']), float(row['qx']), float(row['qy']), float(row['qz']))
    return ts, gyro, accel, quat


def run_replay(csv_path):
    """Replay mode: load recorded frames from a CSV and plot the 3-D trajectory.

    Biases are estimated from the leading frames, which are assumed to be
    stationary. Frames are then fed through a fresh ``Tracker``. Space
    toggles pause. Playback ends when the data is exhausted or the window
    is closed.

    Args:
        csv_path: path to a CSV with the columns written by ``run_live``
            (``timestamp_ms`` is optional; without it a fixed 5 ms dt is used).
    """
    # Parse every row exactly once; newline='' as required by the csv module.
    with open(csv_path, 'r', newline='') as f:
        reader = csv.DictReader(f)
        has_ts = 'timestamp_ms' in (reader.fieldnames or ())
        samples = [_parse_replay_row(row, has_ts) for row in reader]

    print(f"加载 {len(samples)} 帧数据")
    if not samples:
        return

    tracker = Tracker()

    # Estimate biases from at most 200 leading frames, and from no more than
    # a quarter of the recording.
    calib_samples = min(200, len(samples) // 4)
    if calib_samples > 0:
        head = samples[:calib_samples]
        accel_arr = np.array([accel for _, _, accel, _ in head])
        gyro_arr = np.array([gyro for _, gyro, _, _ in head])
        quats = [quat for _, _, _, quat in head]
        qw_arr, qx_arr, qy_arr, qz_arr = (np.array(col) for col in zip(*quats))
        accel_bias, gyro_bias = Tracker.calibrate_from_samples(
            accel_arr, gyro_arr, qw_arr, qx_arr, qy_arr, qz_arr)
        print(f"标定 (前{calib_samples}帧): accel_bias=({accel_bias[0]:.4f},{accel_bias[1]:.4f},{accel_bias[2]:.4f})"
              f" gyro_bias=({gyro_bias[0]:.4f},{gyro_bias[1]:.4f},{gyro_bias[2]:.4f})")
        tracker.accel_bias = accel_bias
        tracker.gyro_bias = gyro_bias
    else:
        # Too few frames to calibrate; keep the Tracker's own default biases.
        print("数据不足, 跳过偏置标定")

    viewer = TrajectoryViewer()
    print("开始回放 ...(空格键暂停)")

    total = len(samples)
    last_ts = None
    last_draw = time.time()
    idx = 0
    paused = False

    def on_key(event):
        """Toggle pause when the space key is pressed in the plot window."""
        nonlocal paused
        if event.key == ' ':
            paused = not paused
            print("暂停" if paused else "继续")

    viewer.fig.canvas.mpl_connect('key_press_event', on_key)

    try:
        while plt.fignum_exists(viewer.fig.number) and idx < total:
            if paused:
                # Keep the GUI event loop running so the space key can resume.
                plt.pause(0.05)
                continue

            ts, gyro, accel, (qw, qx, qy, qz) = samples[idx]

            # dt from recorded timestamps (ms). Missing timestamps,
            # non-positive steps and gaps over 100 ms use the nominal 5 ms.
            if ts is not None and last_ts is not None:
                dt = (ts - last_ts) / 1000.0
                if dt <= 0 or dt > 0.1:
                    dt = 0.005
            else:
                dt = 0.005
            if ts is not None:
                last_ts = ts

            tracker.update(gyro, accel, qw, qx, qy, qz, dt)
            idx += 1

            # Redraw at most ~30 Hz.
            now = time.time()
            if now - last_draw >= 0.033:
                viewer.update(tracker.history_array)
                plt.pause(0.001)
                last_draw = now

            # Progress report. It lives in the playing branch so it is not
            # repeated every 50 ms while paused.
            if idx % 200 == 0:
                pos = tracker.position
                print(f"[{idx:06d}/{total}] dt={dt*1000:.1f}ms "
                      f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")

    except KeyboardInterrupt:
        print("\n回放中断。")
    finally:
        # NOTE(review): this runs as soon as playback ends; if
        # TrajectoryViewer.close() closes the window, the final trajectory is
        # not left on screen — confirm that is intended.
        viewer.close()

    print(f"回放完成,共处理 {idx} 帧。")
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Dispatch rules:
      * ``--replay FILE``: replay a recorded CSV (takes precedence over a port);
      * ``PORT [BAUD]``: live tracking, optionally recording via ``--save``;
      * no arguments: print the argparse help followed by usage examples.
    """
    cli = argparse.ArgumentParser(description='IMU 惯性三维里程计')
    cli.add_argument('port', nargs='?', default=None, help='串口号 (如 COM5)')
    cli.add_argument('baud', nargs='?', type=int, default=115200,
                     help='波特率 (默认 115200)')
    cli.add_argument('--save', metavar='FILE', help='保存轨迹到 CSV 文件')
    cli.add_argument('--replay', metavar='FILE', help='从 CSV 文件回放轨迹')
    opts = cli.parse_args()

    if opts.replay:
        run_replay(opts.replay)
        return
    if opts.port:
        run_live(opts.port, opts.baud, save_csv=opts.save)
        return

    # Nothing to run: show usage plus a few concrete invocations.
    cli.print_help()
    usage_examples = (
        'python main_odometry.py COM5',
        'python main_odometry.py COM5 921600',
        'python main_odometry.py COM5 --save traj.csv',
        'python main_odometry.py --replay traj.csv',
    )
    print("\n示例:")
    for example in usage_examples:
        print(' ' + example)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|