first commit
This commit is contained in:
305
main_odometry.py
Normal file
305
main_odometry.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
IMU 惯性三维里程计 — 主入口
|
||||
|
||||
用法:
|
||||
python main_odometry.py COM5 [baud]
|
||||
python main_odometry.py COM5 --save traj.csv
|
||||
python main_odometry.py --replay traj.csv
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import struct
|
||||
import csv
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import serial
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from imu_decode import HEADER, PAYLOAD_LEN, CRC_LEN, crc16, parse_frame
|
||||
from trajectory_tracker import Tracker
|
||||
from visualize_3d import TrajectoryViewer
|
||||
|
||||
|
||||
def read_frame(ser):
|
||||
"""从串口读取一帧, 返回解析后的 dict 或 None
|
||||
|
||||
逐字节寻找帧头 0xAA 0x55, 校验 CRC, 解包 payload。
|
||||
"""
|
||||
while True:
|
||||
b = ser.read(1)
|
||||
if not b:
|
||||
return None
|
||||
if b[0] == 0xAA:
|
||||
b2 = ser.read(1)
|
||||
if not b2 or b2[0] != 0x55:
|
||||
continue
|
||||
|
||||
frame = ser.read(PAYLOAD_LEN + CRC_LEN)
|
||||
if len(frame) < PAYLOAD_LEN + CRC_LEN:
|
||||
return None
|
||||
|
||||
payload_bytes = frame[:PAYLOAD_LEN]
|
||||
crc_recv = struct.unpack('<H', frame[PAYLOAD_LEN:])[0]
|
||||
crc_calc = crc16(HEADER + payload_bytes)
|
||||
|
||||
if crc_calc != crc_recv:
|
||||
continue
|
||||
|
||||
return parse_frame(payload_bytes)
|
||||
|
||||
|
||||
def calibrate(ser, num_samples=200, skip_first=10):
|
||||
"""静止标定: 采集 N 帧静止数据, 计算传感器偏置
|
||||
|
||||
Args:
|
||||
ser: 已打开的串口
|
||||
num_samples: 标定采样帧数
|
||||
skip_first: 跳过前 N 帧 (等待数据稳定)
|
||||
|
||||
Returns:
|
||||
accel_bias: (3,) array
|
||||
gyro_bias: (3,) array
|
||||
"""
|
||||
accel_samples = []
|
||||
gyro_samples = []
|
||||
qw_samples, qx_samples, qy_samples, qz_samples = [], [], [], []
|
||||
|
||||
print(f"静止标定: 请保持 IMU 静止, 采集 {num_samples} 帧 ...")
|
||||
collected = 0
|
||||
|
||||
while collected < num_samples + skip_first:
|
||||
frame = read_frame(ser)
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
collected += 1
|
||||
if collected <= skip_first:
|
||||
continue
|
||||
|
||||
gx, gy, gz = frame['gyro']
|
||||
ax, ay, az = frame['accel']
|
||||
qw, qx, qy, qz = frame['quat']
|
||||
accel_samples.append([ax, ay, az])
|
||||
gyro_samples.append([gx, gy, gz])
|
||||
qw_samples.append(qw)
|
||||
qx_samples.append(qx)
|
||||
qy_samples.append(qy)
|
||||
qz_samples.append(qz)
|
||||
|
||||
if collected % 50 == 0:
|
||||
print(f" 标定进度: {min(collected - skip_first, num_samples)}/{num_samples}")
|
||||
|
||||
accel_arr = np.array(accel_samples)
|
||||
gyro_arr = np.array(gyro_samples)
|
||||
qw_arr = np.array(qw_samples)
|
||||
qx_arr = np.array(qx_samples)
|
||||
qy_arr = np.array(qy_samples)
|
||||
qz_arr = np.array(qz_samples)
|
||||
return Tracker.calibrate_from_samples(accel_arr, gyro_arr, qw_arr, qx_arr, qy_arr, qz_arr)
|
||||
|
||||
|
||||
def run_live(port, baud, save_csv=None):
|
||||
"""实时模式: 从串口读取数据并显示 3D 轨迹"""
|
||||
tracker = Tracker()
|
||||
viewer = TrajectoryViewer()
|
||||
|
||||
csv_file = None
|
||||
csv_writer = None
|
||||
if save_csv:
|
||||
csv_file = open(save_csv, 'w', newline='')
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(['timestamp_ms',
|
||||
'gyro_x', 'gyro_y', 'gyro_z',
|
||||
'accel_x', 'accel_y', 'accel_z',
|
||||
'qw', 'qx', 'qy', 'qz',
|
||||
'pos_x', 'pos_y', 'pos_z'])
|
||||
|
||||
print(f"打开串口 {port} @ {baud} baud ...")
|
||||
ser = serial.Serial(port, baud, timeout=1)
|
||||
|
||||
# 静止标定
|
||||
accel_bias, gyro_bias = calibrate(ser)
|
||||
tracker.accel_bias = accel_bias
|
||||
tracker.gyro_bias = gyro_bias
|
||||
print(f"标定完成: accel_bias=({accel_bias[0]:.4f}, {accel_bias[1]:.4f}, {accel_bias[2]:.4f}) m/s^2"
|
||||
f" gyro_bias=({gyro_bias[0]:.4f}, {gyro_bias[1]:.4f}, {gyro_bias[2]:.4f}) rad/s")
|
||||
print("开始跟踪 ...\n按 Ctrl+C 停止")
|
||||
|
||||
frame_count = 0
|
||||
last_ts = None
|
||||
last_draw = time.time()
|
||||
|
||||
try:
|
||||
while plt.fignum_exists(viewer.fig.number):
|
||||
frame = read_frame(ser)
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
ts = frame['timestamp_ms']
|
||||
gx, gy, gz = frame['gyro']
|
||||
ax, ay, az = frame['accel']
|
||||
qw, qx, qy, qz = frame['quat']
|
||||
|
||||
gyro = np.array([gx, gy, gz])
|
||||
accel = np.array([ax, ay, az])
|
||||
|
||||
# 使用时间戳计算 dt
|
||||
if last_ts is not None:
|
||||
dt = (ts - last_ts) / 1000.0
|
||||
# 处理时间戳回绕 (uint32 溢出) 和异常值
|
||||
if dt <= 0 or dt > 0.1:
|
||||
dt = 0.005 # 回退到 200Hz 默认值
|
||||
else:
|
||||
dt = 0.005
|
||||
last_ts = ts
|
||||
|
||||
pos = tracker.update(gyro, accel, qw, qx, qy, qz, dt)
|
||||
|
||||
# 保存 CSV
|
||||
if csv_writer:
|
||||
csv_writer.writerow([ts, gx, gy, gz, ax, ay, az, qw, qx, qy, qz,
|
||||
pos[0], pos[1], pos[2]])
|
||||
|
||||
# 30Hz 刷新显示
|
||||
now = time.time()
|
||||
if now - last_draw >= 0.033:
|
||||
viewer.update(tracker.history_array)
|
||||
plt.pause(0.001)
|
||||
last_draw = now
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 200 == 0:
|
||||
print(f"[{frame_count:06d}] ts={ts} dt={dt*1000:.1f}ms "
|
||||
f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n停止。共接收 {frame_count} 帧。")
|
||||
finally:
|
||||
ser.close()
|
||||
if csv_file:
|
||||
csv_file.close()
|
||||
print(f"轨迹已保存至 {save_csv}")
|
||||
viewer.close()
|
||||
|
||||
|
||||
def run_replay(csv_path):
|
||||
"""回放模式: 从 CSV 文件加载数据并显示 3D 轨迹"""
|
||||
# 加载 CSV
|
||||
rows = []
|
||||
with open(csv_path, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
rows.append(row)
|
||||
|
||||
print(f"加载 {len(rows)} 帧数据")
|
||||
|
||||
# 从开头静止帧标定偏置
|
||||
calib_samples = min(200, len(rows) // 4)
|
||||
accel_samples, gyro_samples = [], []
|
||||
qw_s, qx_s, qy_s, qz_s = [], [], [], []
|
||||
for i in range(calib_samples):
|
||||
row = rows[i]
|
||||
accel_samples.append([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])])
|
||||
gyro_samples.append([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])])
|
||||
qw_s.append(float(row['qw']))
|
||||
qx_s.append(float(row['qx']))
|
||||
qy_s.append(float(row['qy']))
|
||||
qz_s.append(float(row['qz']))
|
||||
accel_bias, gyro_bias = Tracker.calibrate_from_samples(
|
||||
np.array(accel_samples), np.array(gyro_samples),
|
||||
np.array(qw_s), np.array(qx_s), np.array(qy_s), np.array(qz_s))
|
||||
print(f"标定 (前{calib_samples}帧): accel_bias=({accel_bias[0]:.4f},{accel_bias[1]:.4f},{accel_bias[2]:.4f})"
|
||||
f" gyro_bias=({gyro_bias[0]:.4f},{gyro_bias[1]:.4f},{gyro_bias[2]:.4f})")
|
||||
|
||||
tracker = Tracker()
|
||||
tracker.accel_bias = accel_bias
|
||||
tracker.gyro_bias = gyro_bias
|
||||
viewer = TrajectoryViewer()
|
||||
|
||||
print("开始回放 ...(空格键暂停)")
|
||||
|
||||
last_ts = None
|
||||
last_draw = time.time()
|
||||
idx = 0
|
||||
paused = False
|
||||
|
||||
def on_key(event):
|
||||
nonlocal paused
|
||||
if event.key == ' ':
|
||||
paused = not paused
|
||||
print("暂停" if paused else "继续")
|
||||
|
||||
viewer.fig.canvas.mpl_connect('key_press_event', on_key)
|
||||
|
||||
try:
|
||||
while plt.fignum_exists(viewer.fig.number) and idx < len(rows):
|
||||
if not paused:
|
||||
row = rows[idx]
|
||||
gyro = np.array([float(row['gyro_x']), float(row['gyro_y']), float(row['gyro_z'])])
|
||||
accel = np.array([float(row['accel_x']), float(row['accel_y']), float(row['accel_z'])])
|
||||
qw, qx, qy, qz = float(row['qw']), float(row['qx']), float(row['qy']), float(row['qz'])
|
||||
|
||||
# 使用时间戳计算 dt
|
||||
ts = int(row['timestamp_ms']) if 'timestamp_ms' in row else None
|
||||
if ts is not None and last_ts is not None:
|
||||
dt = (ts - last_ts) / 1000.0
|
||||
if dt <= 0 or dt > 0.1:
|
||||
dt = 0.005
|
||||
else:
|
||||
dt = 0.005
|
||||
if ts is not None:
|
||||
last_ts = ts
|
||||
|
||||
tracker.update(gyro, accel, qw, qx, qy, qz, dt)
|
||||
idx += 1
|
||||
|
||||
now = time.time()
|
||||
if now - last_draw >= 0.033:
|
||||
viewer.update(tracker.history_array)
|
||||
plt.pause(0.001)
|
||||
last_draw = now
|
||||
else:
|
||||
plt.pause(0.05)
|
||||
|
||||
if idx % 200 == 0:
|
||||
pos = tracker.position
|
||||
print(f"[{idx:06d}/{len(rows)}] dt={dt*1000:.1f}ms "
|
||||
f"pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f})")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n回放中断。")
|
||||
finally:
|
||||
viewer.close()
|
||||
|
||||
print(f"回放完成,共处理 {idx} 帧。")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='IMU 惯性三维里程计')
|
||||
parser.add_argument('port', nargs='?', default=None,
|
||||
help='串口号 (如 COM5)')
|
||||
parser.add_argument('baud', nargs='?', type=int, default=115200,
|
||||
help='波特率 (默认 115200)')
|
||||
parser.add_argument('--save', metavar='FILE',
|
||||
help='保存轨迹到 CSV 文件')
|
||||
parser.add_argument('--replay', metavar='FILE',
|
||||
help='从 CSV 文件回放轨迹')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.replay:
|
||||
run_replay(args.replay)
|
||||
elif args.port:
|
||||
run_live(args.port, args.baud, save_csv=args.save)
|
||||
else:
|
||||
parser.print_help()
|
||||
print("\n示例:")
|
||||
print(" python main_odometry.py COM5")
|
||||
print(" python main_odometry.py COM5 921600")
|
||||
print(" python main_odometry.py COM5 --save traj.csv")
|
||||
print(" python main_odometry.py --replay traj.csv")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user