diff --git a/.DS_Store b/.DS_Store index fc018fc..a166926 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/src/rm_auto_aim/armor_mpc_solver/CMakeLists.txt b/src/rm_auto_aim/armor_mpc_solver/CMakeLists.txt new file mode 100644 index 0000000..c1972e2 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/CMakeLists.txt @@ -0,0 +1,87 @@ +cmake_minimum_required(VERSION 3.8) +project(armor_mpc_solver) + +if(CMAKE_CXX_STANDARD GREATER_RANGE 17) + set(CMAKE_CXX_STANDARD 17) +else() + set(CMAKE_CXX_STANDARD 17) +endif() + +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Find dependencies +find_package(ament_cmake REQUIRED) +find_package(rclcpp REQUIRED) +find_package(rm_interfaces REQUIRED) +find_package(rm_utils REQUIRED) +find_package(rm_tinympc REQUIRED) +find_package(Eigen3 REQUIRED) +find_package(tf2 REQUIRED) +find_package(tf2_ros REQUIRED) +find_package(visualization_msgs REQUIRED) +find_package(geometry_msgs REQUIRED) +find_package(std_msgs REQUIRED) + +include_directories( + include + ${EIGEN3_INCLUDE_DIRS} +) + +set(${PROJECT_NAME}_HEADERS + include/armor_mpc_solver/solver.hpp + include/armor_mpc_solver/solver_comparer.hpp + include/armor_mpc_solver/solver_comparison_node.hpp +) + +add_library(${PROJECT_NAME} SHARED + src/solver.cpp + src/solver_comparer.cpp +) + +add_executable(${PROJECT_NAME}_node src/solver_comparison_node.cpp) +ament_target_dependencies(${PROJECT_NAME}_node + rclcpp + rm_interfaces + rm_utils + rm_tinympc + Eigen3 + tf2 + tf2_ros + visualization_msgs + geometry_msgs + std_msgs + ${PROJECT_NAME} +) + +ament_target_dependencies(${PROJECT_NAME} + rclcpp + rm_interfaces + rm_utils + rm_tinympc + Eigen3 + tf2 + tf2_ros + visualization_msgs + geometry_msgs + std_msgs +) + +target_include_directories(${PROJECT_NAME} PUBLIC + $ + $ +) + +install(TARGETS ${PROJECT_NAME} ${PROJECT_NAME}_node + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin +) + +install(DIRECTORY include/ + DESTINATION include +) + +ament_export_include_directories(include) +ament_export_libraries(${PROJECT_NAME}) + +ament_package() diff --git a/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver.hpp b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver.hpp new file mode 100644 index 0000000..c5f803d --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver.hpp @@ -0,0 +1,311 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ARMOR_MPC_SOLVER_SOLVER_HPP_ +#define ARMOR_MPC_SOLVER_SOLVER_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include "rm_tinympc/types.hpp" +#include "rm_tinympc/admm.hpp" +#include "rm_tinympc/trajectory.hpp" +#include "rm_utils/math/trajectory_compensator.hpp" +#include "rm_utils/math/manual_compensator.hpp" + +namespace fyt::auto_aim { + +/** + * Trajectory point for visualization and debugging + */ +struct TrajPoint { + double t; + double yaw; + double pitch; + double yaw_vel; + double pitch_vel; +}; + +/** + * Quintic segment for polynomial trajectory + * Provides smooth position/velocity/acceleration trajectories + */ +struct QuinticSegment { + double a0, a1, a2, a3, a4, a5; // Coefficients: p(t) = a0 + a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5 + + double position(double t) const; + double velocity(double t) const; + double acceleration(double t) const; + + static QuinticSegment fromBoundaryConditions( + double p0, double v0, double a0, + double p1, double v1, double a1, + double duration); +}; + +/** + * LimitTrajectory - Acceleration-constrained trajectory generation + * Ensures gimbal acceleration stays within limits + */ +class LimitTrajectory { +public: + struct TrajectoryPoint { + double t; + double yaw; + double pitch; + double v_yaw; + double v_pitch; + double a_yaw; + double a_pitch; + }; + + LimitTrajectory(double max_yaw_acc, double max_pitch_acc, double dt); + + std::vector generate( + double start_yaw, double start_pitch, + double target_yaw, double target_pitch, + double max_vel_yaw, double max_vel_pitch); + +private: + double max_yaw_acc_; + double max_pitch_acc_; + double dt_; + + double trapezoidalTime(double dist, double max_vel, double max_acc); + QuinticSegment quinticPlan(double p0, double v0, double a0, double p1, double v1, double a1, double T); +}; + +/** + * Armor selection result + */ +struct ArmorSelectResult { + int armor_id; // Selected armor index (0-3) + double coming_angle; // Angle when approaching + double leaving_angle; // Angle when leaving + bool is_switching; // Whether switching armor +}; + +/** + * Gimbal MPC Solver using TinyMPC ADMM + * + * State: [yaw, pitch, yaw_rate, pitch_rate] + * Input: [yaw_acc, pitch_acc] + * + * Cost: sum ||state - ref||_Q + ||input||_R + * Constraints: yaw_rate, pitch_rate limits + */ +class MpcSolver { +public: + MpcSolver(std::weak_ptr n); + + rm_interfaces::msg::GimbalCmd solve( + const rm_interfaces::msg::Target& target, + const rclcpp::Time& current_time, + std::shared_ptr tf2_buffer); + + // Trajectory getters for visualization + std::vector getMpcTrajectory() const { return mpc_trajectory_; } + std::vector getReferenceTrajectory() const { return reference_trajectory_; } + std::vector getOptimalTrajectory() const { return optimal_trajectory_; } + + void setBulletSpeed(double bullet_speed) { bullet_speed_ = bullet_speed; } + + // Limit trajectory for acceleration-constrained paths + std::vector getLimitTrajectory() const { return limit_trajectory_; } + +private: + // State and input dimensions + static constexpr int N_X = 4; // [yaw, pitch, yaw_rate, pitch_rate] + static constexpr int N_U = 2; // [yaw_acc, pitch_acc] + + // MPC parameters + double bullet_speed_; + double gravity_; + double prediction_horizon_; + double dt_; + + // Gimbal rate limits (rad/s) + double max_yaw_rate_; + double max_pitch_rate_; + // Gimbal acceleration limits (rad/s^2) + double max_yaw_acc_; + double max_pitch_acc_; + + // Cost weights - separated for yaw and pitch + // Q = [position_cost, velocity_cost] + Eigen::Vector2d Q_yaw_; + Eigen::Vector2d Q_pitch_; + Eigen::Vector2d R_yaw_; + Eigen::Vector2d R_pitch_; + Eigen::Vector2d Qf_yaw_; + Eigen::Vector2d Qf_pitch_; + + // Legacy combined cost weights (for compatibility) + Eigen::Vector4d Q_; // State cost [yaw, pitch, yaw_rate, pitch_rate] + Eigen::Vector2d R_; // Input cost [yaw_acc, pitch_acc] + Eigen::Vector4d Qf_; // Terminal cost + + // Fire decision parameters + double delay_enable_fire_error_; + double yaw_limit_deg_; + double shooting_range_h_; + double shooting_range_small_w_; + double shooting_range_big_w_; + double min_enable_pitch_deg_; + double min_enable_yaw_deg_; + double comming_angle_deg_; + double leaving_angle_deg_; + + // Armor selection + double coming_angle_thresh_; + double leaving_angle_thresh_; + + // ADMM MPC solver - separated for yaw and pitch + std::unique_ptr admm_solver_yaw_; + std::unique_ptr admm_solver_pitch_; + + // Trajectory generators + rm_tinympc::GimbalTrajectoryGenerator trajectory_generator_; + + // Trajectories for visualization + std::vector mpc_trajectory_; + std::vector reference_trajectory_; + std::vector optimal_trajectory_; + std::vector limit_trajectory_; + + rclcpp::Time last_time_; + + // Current gimbal state [yaw, pitch] + Eigen::Vector2d current_gimbal_state_; + + // Current gimbal velocity [yaw_rate, pitch_rate] + Eigen::Vector2d current_gimbal_velocity_; + + // Control delay compensation (seconds) + double control_delay_; + + // Limit trajectory generator + std::unique_ptr limit_trajectory_generator_; + + // Trajectory compensator for bullet arc compensation + std::unique_ptr trajectory_compensator_; + + // Manual compensator for angle offset correction + std::unique_ptr manual_compensator_; + + /** + * Initialize ADMM solver with problem matrices + */ + void initADMM(); + + /** + * Initialize trajectory compensator + */ + void initTrajectoryCompensator(const std::string& compensator_type); + + /** + * Compute reference trajectory (target states over horizon) + */ + Eigen::MatrixXd computeReferenceTrajectory( + const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_vel, + double target_yaw, + double v_yaw, + double flying_time); + + /** + * Solve MPC and get optimal control + */ + Eigen::Vector2d solveMPC(const Eigen::Vector4d& x0, const Eigen::MatrixXd& xref); + + /** + * Compute gimbal command from optimal control + */ + Eigen::Vector2d computeGimbalCommandFromControl(const Eigen::Vector2d& u); + + /** + * Compute flying time with trajectory compensation + */ + double computeFlyingTime(const Eigen::Vector3d& target_pos); + + /** + * Predict target position at flying_time + */ + Eigen::Vector3d predictTargetPosition( + const Eigen::Vector3d& pos, + const Eigen::Vector3d& vel, + double dt); + + /** + * Convert 2D gimbal state to gimbal command message + */ + rm_interfaces::msg::GimbalCmd toGimbalCmd( + const Eigen::Vector2d& gimbal_angles, + const Eigen::Vector3d& target_pos, + const rm_interfaces::msg::Target& target); + + /** + * Build MPC trajectory for visualization (separated yaw/pitch) + */ + void buildMpcVisualizationTrajectory( + const Eigen::Vector4d& x0, + const Eigen::MatrixXd& x_traj_yaw, + const Eigen::MatrixXd& x_traj_pitch, + const Eigen::MatrixXd& u_traj_yaw, + const Eigen::MatrixXd& u_traj_pitch); + + /** + * Build reference trajectory for visualization + */ + void buildReferenceVisualizationTrajectory( + const Eigen::MatrixXd& xref); + + /** + * Select best armor based on coming/leaving angle + */ + ArmorSelectResult selectArmor( + const rm_interfaces::msg::Target& target, + const rclcpp::Time& current_time); + + /** + * Check if can fire at given time considering control delay + */ + bool canFireAtTime( + const rm_interfaces::msg::Target& target, + double time_to_target, + const rclcpp::Time& current_time); + + /** + * Compute trajectory with limit constraints + */ + std::vector computeLimitTrajectory( + double target_yaw, double target_pitch); + + /** + * Compute time-optimal trajectory with acceleration constraints + */ + void computeAccelConstrainedTrajectory( + const Eigen::Vector2d& start, + const Eigen::Vector2d& target, + const Eigen::Vector2d& max_vel, + const Eigen::Vector2d& max_acc); +}; + +} // namespace fyt::auto_aim + +#endif // ARMOR_MPC_SOLVER_SOLVER_HPP_ diff --git a/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparer.hpp b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparer.hpp new file mode 100644 index 0000000..adff342 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparer.hpp @@ -0,0 +1,70 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_ +#define ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "armor_solver/armor_solver.hpp" +#include "armor_mpc_solver/solver.hpp" + +namespace fyt::auto_aim { + +class SolverComparer { +public: + SolverComparer(std::weak_ptr n); + + void init(); + void update(const rm_interfaces::msg::Target::SharedPtr target_msg); + + void publishComparisonMarkers(); + +private: + std::weak_ptr node_; + + std::unique_ptr original_solver_; + std::unique_ptr mpc_solver_; + + rclcpp::Publisher::SharedPtr trajectory_pub_; + rclcpp::Publisher::SharedPtr mpc_gimbal_pub_; + + std::shared_ptr tf2_buffer_; + std::shared_ptr tf2_listener_; + + std::vector last_mpc_trajectory_; + std::vector last_reference_trajectory_; + std::vector last_limit_trajectory_; + + int marker_id_; + std::string frame_id_; + + rm_interfaces::msg::GimbalCmd last_original_cmd_; + rm_interfaces::msg::GimbalCmd last_mpc_cmd_; + + bool use_mpc_; +}; + +} // namespace fyt::auto_aim + +#endif // ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_ diff --git a/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparison_node.hpp b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparison_node.hpp new file mode 100644 index 0000000..42ea902 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/include/armor_mpc_solver/solver_comparison_node.hpp @@ -0,0 +1,43 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_ +#define ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_ + +#include +#include +#include +#include "armor_mpc_solver/solver_comparer.hpp" + +namespace fyt::auto_aim { + +class SolverComparisonNode : public rclcpp::Node { +public: + SolverComparisonNode(); + +private: + void targetCallback(const rm_interfaces::msg::Target::SharedPtr msg); + void toggleCallback(const std_msgs::msg::Bool::SharedPtr msg); + + std::unique_ptr comparer_; + + rclcpp::Subscription::SharedPtr target_sub_; + rclcpp::Subscription::SharedPtr toggle_sub_; + + rclcpp::TimerBase::SharedPtr timer_; +}; + +} // namespace fyt::auto_aim + +#endif // ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_ diff --git a/src/rm_auto_aim/armor_mpc_solver/package.xml b/src/rm_auto_aim/armor_mpc_solver/package.xml new file mode 100644 index 0000000..1c9bea7 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/package.xml @@ -0,0 +1,26 @@ + + + + armor_mpc_solver + 0.1.0 + MPC-based armor solver with TinyMPC + Chen Youyuan + Apache-2.0 + + ament_cmake + + rclcpp + rm_interfaces + rm_utils + rm_tinympc + Eigen3 + tf2 + tf2_ros + visualization_msgs + geometry_msgs + std_msgs + + + ament_cmake + + diff --git a/src/rm_auto_aim/armor_mpc_solver/src/solver.cpp b/src/rm_auto_aim/armor_mpc_solver/src/solver.cpp new file mode 100644 index 0000000..e6f4fdf --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/src/solver.cpp @@ -0,0 +1,667 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "armor_mpc_solver/solver.hpp" +#include "rm_utils/logger/log.hpp" +#include +#include +#include + +namespace fyt::auto_aim { + +// ============== QuinticSegment Implementation ============== + +double QuinticSegment::position(double t) const { + return a0 + a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t; +} + +double QuinticSegment::velocity(double t) const { + return a1 + 2 * a2 * t + 3 * a3 * t * t + 4 * a4 * t * t * t + 5 * a5 * t * t * t * t; +} + +double QuinticSegment::acceleration(double t) const { + return 2 * a2 + 6 * a3 * t + 12 * a4 * t * t + 20 * a5 * t * t * t; +} + +QuinticSegment QuinticSegment::fromBoundaryConditions( + double p0, double v0, double a0, + double p1, double v1, double a1, + double T) { + QuinticSegment seg; + double T2 = T * T; + double T3 = T2 * T; + double T4 = T3 * T; + double T5 = T4 * T; + + // Solve for coefficients using boundary conditions + seg.a0 = p0; + seg.a1 = v0; + seg.a2 = a0 / 2.0; + seg.a3 = (20 * (p1 - p0) - (8 * v1 + 12 * v0) * T - (3 * a0 - a1) * T2) / (2 * T3); + seg.a4 = (30 * (p0 - p1) + (14 * v1 + 16 * v0) * T + (3 * a0 - 2 * a1) * T2) / (2 * T4); + seg.a5 = (12 * (p1 - p0) - (6 * v1 + 6 * v0) * T - (a0 - a1) * T2) / (2 * T5); + + return seg; +} + +// ============== LimitTrajectory Implementation ============== + +LimitTrajectory::LimitTrajectory(double max_yaw_acc, double max_pitch_acc, double dt) + : max_yaw_acc_(max_yaw_acc), max_pitch_acc_(max_pitch_acc), dt_(dt) {} + +double LimitTrajectory::trapezoidalTime(double dist, double max_vel, double max_acc) { + // Time for trapezoidal velocity profile + // v_max = sqrt(dist * acc) if dist > v_max^2 / acc + double t_acc = max_vel / max_acc; + double dist_at_t_acc = max_acc * t_acc * t_acc; + + if (dist <= dist_at_t_acc) { + // Triangle profile: accelerate then decelerate + return 2.0 * std::sqrt(dist / max_acc); + } else { + // Trapezoidal: accel + coast + decel + double t_coast = (dist - dist_at_t_acc) / max_vel; + return 2.0 * t_acc + t_coast; + } +} + +QuinticSegment LimitTrajectory::quinticPlan( + double p0, double v0, double a0, + double p1, double v1, double a1, + double T) { + return QuinticSegment::fromBoundaryConditions(p0, v0, a0, p1, v1, a1, T); +} + +std::vector LimitTrajectory::generate( + double start_yaw, double start_pitch, + double target_yaw, double target_pitch, + double max_vel_yaw, double max_vel_pitch) { + + std::vector trajectory; + + // Compute angle differences + double dyaw = angles::shortest_angular_distance(start_yaw, target_yaw); + double dpitch = target_pitch - start_pitch; + + // Time for each axis using trapezoidal profile + double time_yaw = trapezoidalTime(std::abs(dyaw), max_vel_yaw, max_yaw_acc_); + double time_pitch = trapezoidalTime(std::abs(dpitch), max_vel_pitch, max_pitch_acc_); + + // Use longer time and synchronize + double T = std::max(time_yaw, time_pitch); + T = std::max(T, 0.1); // Minimum time + + // Generate quintic trajectories + QuinticSegment yaw_traj = quinticPlan(start_yaw, 0, 0, target_yaw, 0, 0, T); + QuinticSegment pitch_traj = quinticPlan(start_pitch, 0, 0, target_pitch, 0, 0, T); + + // Sample trajectory + int num_steps = static_cast(T / dt_) + 1; + for (int i = 0; i <= num_steps; ++i) { + double t = i * dt_; + if (t > T) t = T; + + TrajectoryPoint pt; + pt.t = t; + pt.yaw = yaw_traj.position(t); + pt.pitch = pitch_traj.position(t); + pt.v_yaw = yaw_traj.velocity(t); + pt.v_pitch = pitch_traj.velocity(t); + pt.a_yaw = yaw_traj.acceleration(t); + pt.a_pitch = pitch_traj.acceleration(t); + + trajectory.push_back(pt); + } + + return trajectory; +} + +// ============== MpcSolver Implementation ============== + +MpcSolver::MpcSolver(std::weak_ptr n) +: node_(n), bullet_speed_(20.0), gravity_(9.8), prediction_horizon_(2.0), dt_(0.004), + max_yaw_rate_(6.0), max_pitch_rate_(6.0), max_yaw_acc_(40.0), max_pitch_acc_(25.0), + control_delay_(0.2), delay_enable_fire_error_(0.0035), + yaw_limit_deg_(60.0), shooting_range_h_(0.12), shooting_range_small_w_(0.12), + shooting_range_big_w_(0.24), min_enable_pitch_deg_(0.25), min_enable_yaw_deg_(0.25), + comming_angle_deg_(60.0), leaving_angle_deg_(20.0) { + + auto node = node_.lock(); + if (node) { + // Basic parameters (matching wust_vision) + bullet_speed_ = node->declare_parameter("mpc.bullet_speed", 20.0); + gravity_ = node->declare_parameter("mpc.gravity", 9.8); + prediction_horizon_ = node->declare_parameter("mpc.sample_total_time", 2.0); + int sample_horizon = node->declare_parameter("mpc.sample_horizon", 500); + dt_ = prediction_horizon_ / sample_horizon; + + max_yaw_acc_ = node->declare_parameter("mpc.max_yaw_acc", 40.0); + max_pitch_acc_ = node->declare_parameter("mpc.max_pitch_acc", 25.0); + + control_delay_ = node->declare_parameter("mpc.control_delay", 0.2); + delay_enable_fire_error_ = node->declare_parameter("mpc.delay_enable_fire_error", 0.0035); + + // Fire decision parameters + yaw_limit_deg_ = node->declare_parameter("mpc.yaw_limit_deg", 60.0); + shooting_range_h_ = node->declare_parameter("mpc.shooting_range_h", 0.12); + shooting_range_small_w_ = node->declare_parameter("mpc.shooting_range_small_w", 0.12); + shooting_range_big_w_ = node->declare_parameter("mpc.shooting_range_big_w", 0.24); + min_enable_pitch_deg_ = node->declare_parameter("mpc.min_enable_pitch_deg", 0.25); + min_enable_yaw_deg_ = node->declare_parameter("mpc.min_enable_yaw_deg", 0.25); + comming_angle_deg_ = node->declare_parameter("mpc.comming_angle", 60.0); + leaving_angle_deg_ = node->declare_parameter("mpc.leaving_angle", 20.0); + + // Cost weights - separated for yaw and pitch (matching wust_vision) + // Default: Q_yaw = [7e6, 0], R_yaw = [3.0] + auto q_yaw_vec = node->declare_parameter("mpc.Q_yaw", std::vector{7e6, 0.0}); + auto r_yaw_vec = node->declare_parameter("mpc.R_yaw", std::vector{3.0}); + auto q_pitch_vec = node->declare_parameter("mpc.Q_pitch", std::vector{7e6, 0.0}); + auto r_pitch_vec = node->declare_parameter("mpc.R_pitch", std::vector{3.0}); + + Q_yaw_ << q_yaw_vec[0], q_yaw_vec[1]; + R_yaw_ << r_yaw_vec[0]; + Q_pitch_ << q_pitch_vec[0], q_pitch_vec[1]; + R_pitch_ << r_pitch_vec[0]; + + // Terminal cost is 2x stage cost + Qf_yaw_ = 2.0 * Q_yaw_; + Qf_pitch_ = 2.0 * Q_pitch_; + + FYT_INFO("armor_mpc_solver", "MPC Solver initialized (wust_vision params):"); + FYT_INFO("armor_mpc_solver", " prediction_horizon={:.3f}s, dt={:.4f}s, horizon={}", + prediction_horizon_, dt_, sample_horizon); + FYT_INFO("armor_mpc_solver", " max_yaw_acc={:.2f} rad/s^2, max_pitch_acc={:.2f} rad/s^2", + max_yaw_acc_, max_pitch_acc_); + FYT_INFO("armor_mpc_solver", " control_delay={:.3f}s, fire_error={:.4f}", + control_delay_, delay_enable_fire_error_); + FYT_INFO("armor_mpc_solver", " Q_yaw=[{:.1e}, {:.1e}], R_yaw=[{:.1e}]", + Q_yaw_(0), Q_yaw_(1), R_yaw_(0)); + FYT_INFO("armor_mpc_solver", " Q_pitch=[{:.1e}, {:.1e}], R_pitch=[{:.1e}]", + Q_pitch_(0), Q_pitch_(1), R_pitch_(0)); + + // Initialize trajectory compensator + std::string compensator_type = node->declare_parameter("mpc.trajectory_compensator", "resistance"); + initTrajectoryCompensator(compensator_type); + + // Initialize manual compensator for angle offset + manual_compensator_ = std::make_unique(); + auto angle_offset = node->declare_parameter("mpc.trajectory_offset", std::vector{}); + if (!manual_compensator_->updateMapFlow(angle_offset)) { + FYT_DEBUG("armor_mpc_solver", "Manual compensator update skipped (empty config)"); + } + } + node.reset(); + + initADMM(); + + current_gimbal_state_ << 0.0, 0.0; + current_gimbal_velocity_ << 0.0, 0.0; + + // Initialize limit trajectory generator + limit_trajectory_generator_ = std::make_unique( + max_yaw_acc_, max_pitch_acc_, dt_); +} + +void MpcSolver::initADMM() { + int horizon = static_cast(prediction_horizon_ / dt_); + + // Separate MPC for yaw: State [yaw, yaw_rate], Input [yaw_acc] + // Separate MPC for pitch: State [pitch, pitch_rate], Input [pitch_acc] + static constexpr int N_X_SEP = 2; // [angle, angular_rate] + static constexpr int N_U_SEP = 1; // [angular_acc] + + admm_solver_yaw_ = std::make_unique(); + admm_solver_pitch_ = std::make_unique(); + + admm_solver_yaw_->init(N_X_SEP, N_U_SEP, horizon); + admm_solver_pitch_->init(N_X_SEP, N_U_SEP, horizon); + + // State transition for single axis: x = [angle, angular_rate] + // angle_{k+1} = angle_k + angular_rate_k * dt + // angular_rate_{k+1} = angular_rate_k + angular_acc_k * dt + Eigen::MatrixXd A_sep = Eigen::MatrixXd::Identity(N_X_SEP, N_X_SEP); + A_sep(0, 1) = dt_; // angle += angular_rate * dt + + Eigen::MatrixXd B_sep = Eigen::MatrixXd::Zero(N_X_SEP, N_U_SEP); + B_sep(0, 0) = 0.5 * dt_ * dt_; // angle += 0.5 * angular_acc * dt^2 + B_sep(1, 0) = dt_; // angular_rate += angular_acc * dt + + // Set problem matrices for both solvers using separated cost weights + admm_solver_yaw_->setProblem(A_sep, B_sep, Q_yaw_, R_yaw_, Qf_yaw_); + admm_solver_pitch_->setProblem(A_sep, B_sep, Q_pitch_, R_pitch_, Qf_pitch_); + + // Set constraints for yaw + Eigen::Vector2d x_min_yaw, x_max_yaw; + Eigen::VectorXd u_min_yaw(1), u_max_yaw(1); + x_min_yaw << -M_PI, -max_yaw_rate_; + x_max_yaw << M_PI, max_yaw_rate_; + u_min_yaw(0) = -max_yaw_acc_; + u_max_yaw(0) = max_yaw_acc_; + admm_solver_yaw_->setConstraints(x_min_yaw, x_max_yaw, u_min_yaw, u_max_yaw); + + // Set constraints for pitch + Eigen::Vector2d x_min_pitch, x_max_pitch; + Eigen::VectorXd u_min_pitch(1), u_max_pitch(1); + x_min_pitch << -M_PI_2, -max_pitch_rate_; + x_max_pitch << M_PI_2, max_pitch_rate_; + u_min_pitch(0) = -max_pitch_acc_; + u_max_pitch(0) = max_pitch_acc_; + admm_solver_pitch_->setConstraints(x_min_pitch, x_max_pitch, u_min_pitch, u_max_pitch); +} + + admm_solver_->setConstraints(x_min, x_max, u_min, u_max); +} + +void MpcSolver::initTrajectoryCompensator(const std::string& compensator_type) { + trajectory_compensator_ = fyt::CompensatorFactory::createCompensator(compensator_type); + + if (trajectory_compensator_) { + trajectory_compensator_->velocity = bullet_speed_; + trajectory_compensator_->gravity = gravity_; + FYT_INFO("armor_mpc_solver", "Trajectory compensator initialized: {}", compensator_type); + } else { + FYT_WARN("armor_mpc_solver", "Failed to create trajectory compensator, using default"); + trajectory_compensator_ = std::make_unique(); + trajectory_compensator_->velocity = bullet_speed_; + trajectory_compensator_->gravity = gravity_; + } +} + +rm_interfaces::msg::GimbalCmd MpcSolver::solve( + const rm_interfaces::msg::Target& target, + const rclcpp::Time& current_time, + std::shared_ptr tf2_buffer) { + + Eigen::Vector3d target_pos(target.position.x, target.position.y, target.position.z); + Eigen::Vector3d target_vel(target.velocity.x, target.velocity.y, target.velocity.z); + double target_yaw = target.yaw; + double v_yaw = target.v_yaw; + + // Select armor based on coming/leaving angle + auto armor_result = selectArmor(target, current_time); + + // Compute flying time with compensation + double flying_time = computeFlyingTime(target_pos); + + // Predict target position at flying_time + Eigen::Vector3d predicted_pos = predictTargetPosition(target_pos, target_vel, flying_time); + double predicted_yaw = target_yaw + flying_time * v_yaw; + + // Convert target to gimbal angles (yaw, pitch) + Eigen::Vector2d target_gimbal; + target_gimbal(0) = std::atan2(predicted_pos.y(), predicted_pos.x()); // yaw + + // Calculate pitch with trajectory compensation + double raw_pitch = std::atan2(predicted_pos.z(), predicted_pos.head(2).norm()); + if (trajectory_compensator_) { + trajectory_compensator_->compensate(predicted_pos, raw_pitch); + } + target_gimbal(1) = raw_pitch; + + // Apply manual compensator angle offset + if (manual_compensator_) { + double dist = predicted_pos.head(2).norm(); + auto offsets = manual_compensator_->angleHardCorrect(dist, predicted_pos.z()); + if (offsets.size() >= 2) { + target_gimbal(0) += offsets[1] * M_PI / 180.0; // yaw offset + target_gimbal(1) += offsets[0] * M_PI / 180.0; // pitch offset + } + } + + // Current state: [yaw, pitch, yaw_rate, pitch_rate] + Eigen::Vector4d x0 = Eigen::Vector4d::Zero(); + x0(0) = current_gimbal_state_(0); + x0(1) = current_gimbal_state_(1); + x0(2) = current_gimbal_velocity_(0); // Use filtered velocity + x0(3) = current_gimbal_velocity_(1); + + // Compute reference trajectory over horizon + Eigen::MatrixXd xref = computeReferenceTrajectory( + predicted_pos, target_vel, predicted_yaw, v_yaw, flying_time); + + // Solve separated MPC for yaw and pitch + Eigen::Vector2d u_optimal = solveMPC(x0, xref); + + // Build trajectories for visualization + Eigen::MatrixXd x_traj_yaw = admm_solver_yaw_->getStateTrajectory(); + Eigen::MatrixXd x_traj_pitch = admm_solver_pitch_->getStateTrajectory(); + Eigen::MatrixXd u_traj_yaw = admm_solver_yaw_->getControlSequence(); + Eigen::MatrixXd u_traj_pitch = admm_solver_pitch_->getControlSequence(); + buildMpcVisualizationTrajectory(x0, x_traj_yaw, x_traj_pitch, u_traj_yaw, u_traj_pitch); + buildReferenceVisualizationTrajectory(xref); + + // Compute limit trajectory for comparison + computeAccelConstrainedTrajectory(current_gimbal_state_, target_gimbal, + Eigen::Vector2d(max_yaw_rate_, max_pitch_rate_), + Eigen::Vector2d(max_yaw_acc_, max_pitch_acc_)); + + // Update current gimbal velocity first (acceleration * dt = delta_velocity) + current_gimbal_velocity_(0) += u_optimal(0) * dt_; + current_gimbal_velocity_(1) += u_optimal(1) * dt_; + + // Apply first control input using updated velocity + Eigen::Vector2d cmd_angles = computeGimbalCommandFromControl(u_optimal); + + // Update current gimbal state + current_gimbal_state_ = cmd_angles; + + // Convert to gimbal command message + auto gimbal_cmd = toGimbalCmd(cmd_angles, predicted_pos, target); + + // Check if can fire considering control delay + gimbal_cmd.fire_advice = canFireAtTime(target, flying_time, current_time); + + return gimbal_cmd; +} + +Eigen::MatrixXd MpcSolver::computeReferenceTrajectory( + const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_vel, + double target_yaw, + double v_yaw, + double flying_time) { + + int horizon = admm_solver_yaw_->getStateTrajectory().cols(); + Eigen::MatrixXd xref(N_X, horizon + 1); + + int num_steps = static_cast(horizon); + double t = 0.0; + + for (int i = 0; i <= num_steps; ++i) { + double dt_i = i * dt_; + + // Predict target position at dt_i + Eigen::Vector3d pred_pos = predictTargetPosition(target_pos, target_vel, dt_i); + double pred_yaw = target_yaw + dt_i * v_yaw; + + // Convert to gimbal angles + xref(0, i) = std::atan2(pred_pos.y(), pred_pos.x()); // yaw + xref(1, i) = std::atan2(pred_pos.z(), pred_pos.head(2).norm()); // pitch + xref(2, i) = 0.0; // yaw_rate (reference is stationary in gimbal frame) + xref(3, i) = 0.0; // pitch_rate + } + + return xref; +} + +Eigen::Vector2d MpcSolver::solveMPC(const Eigen::Vector4d& x0, const Eigen::MatrixXd& xref) { + // Solve separated MPC for yaw and pitch + // x0 = [yaw, pitch, yaw_rate, pitch_rate] + // xref columns: [yaw_ref, pitch_ref, yaw_rate_ref=0, pitch_rate_ref=0] + + static constexpr int N_X_SEP = 2; + + // Initial state for yaw MPC: [yaw, yaw_rate] + Eigen::VectorXd x0_yaw(N_X_SEP); + x0_yaw(0) = x0(0); + x0_yaw(1) = x0(2); + + // Initial state for pitch MPC: [pitch, pitch_rate] + Eigen::VectorXd x0_pitch(N_X_SEP); + x0_pitch(0) = x0(1); + x0_pitch(1) = x0(3); + + // Get horizon from solver + int horizon = admm_solver_yaw_->getStateTrajectory().cols(); + + // Reference trajectories for yaw and pitch + Eigen::MatrixXd xref_yaw(N_X_SEP, horizon + 1); + Eigen::MatrixXd xref_pitch(N_X_SEP, horizon + 1); + + for (int i = 0; i <= horizon; ++i) { + xref_yaw(0, i) = xref(0, i); // yaw + xref_yaw(1, i) = xref(2, i); // yaw_rate + xref_pitch(0, i) = xref(1, i); // pitch + xref_pitch(1, i) = xref(3, i); // pitch_rate + } + + // Solve yaw MPC + admm_solver_yaw_->setInitialState(x0_yaw); + admm_solver_yaw_->setReference(xref_yaw); + bool yaw_converged = admm_solver_yaw_->solve(); + + if (!yaw_converged) { + FYT_WARN("armor_mpc_solver", "Yaw MPC solver did not converge!"); + } + + // Solve pitch MPC + admm_solver_pitch_->setInitialState(x0_pitch); + admm_solver_pitch_->setReference(xref_pitch); + bool pitch_converged = admm_solver_pitch_->solve(); + + if (!pitch_converged) { + FYT_WARN("armor_mpc_solver", "Pitch MPC solver did not converge!"); + } + + // Get first optimal control [yaw_acc, pitch_acc] + Eigen::VectorXd u_yaw = admm_solver_yaw_->getFirstInput(); + Eigen::VectorXd u_pitch = admm_solver_pitch_->getFirstInput(); + + Eigen::Vector2d u_optimal; + u_optimal(0) = u_yaw(0); + u_optimal(1) = u_pitch(0); + + return u_optimal; +} + +Eigen::Vector2d MpcSolver::computeGimbalCommandFromControl(const Eigen::Vector2d& u) { + // Control input u = [yaw_acc, pitch_acc] + // Semi-implicit Euler integration (symplectic): + // velocity_{k+1} = velocity_k + accel * dt + // angle_{k+1} = angle_k + velocity_{k+1} * dt + + Eigen::Vector2d new_angles = current_gimbal_state_; + + // First compute new velocity (after acceleration) + Eigen::Vector2d new_velocity = current_gimbal_velocity_ + u * dt_; + + // Then compute angle using new velocity + new_angles(0) += new_velocity(0) * dt_; + new_angles(1) += new_velocity(1) * dt_; + + // Clamp to gimbal limits + new_angles(0) = std::max(-M_PI, std::min(M_PI, new_angles(0))); + new_angles(1) = std::max(-M_PI_2, std::min(M_PI_2, new_angles(1))); + + return new_angles; +} + +double MpcSolver::computeFlyingTime(const Eigen::Vector3d& target_pos) { + if (trajectory_compensator_) { + return trajectory_compensator_->getFlyingTime(target_pos); + } + // Fallback: simple calculation + double dist = target_pos.norm(); + return dist / bullet_speed_; +} + +Eigen::Vector3d MpcSolver::predictTargetPosition( + const Eigen::Vector3d& pos, + const Eigen::Vector3d& vel, + double dt) { + return pos + dt * vel; +} + +rm_interfaces::msg::GimbalCmd MpcSolver::toGimbalCmd( + const Eigen::Vector2d& gimbal_angles, + const Eigen::Vector3d& target_pos, + const rm_interfaces::msg::Target& target) { + + rm_interfaces::msg::GimbalCmd gimbal_cmd; + gimbal_cmd.header = target.header; + gimbal_cmd.header.stamp = rclcpp::Clock().now(); + gimbal_cmd.yaw = gimbal_angles(0) * 180.0 / M_PI; + gimbal_cmd.pitch = gimbal_angles(1) * 180.0 / M_PI; + gimbal_cmd.distance = target_pos.norm(); + + // Compute yaw_diff and pitch_diff (simplified - assumes gimbal feedback) + // In practice, these would come from actual gimbal feedback + gimbal_cmd.yaw_diff = 0.0; + gimbal_cmd.pitch_diff = 0.0; + + // Simplified fire_advice: fire if target is being tracked + gimbal_cmd.fire_advice = true; + + // shoot_rate will be set by the node that uses this solver + + return gimbal_cmd; +} + +void MpcSolver::buildMpcVisualizationTrajectory( + const Eigen::Vector4d& x0, + const Eigen::MatrixXd& x_traj_yaw, + const Eigen::MatrixXd& x_traj_pitch, + const Eigen::MatrixXd& u_traj_yaw, + const Eigen::MatrixXd& u_traj_pitch) { + + mpc_trajectory_.clear(); + + int horizon = x_traj_yaw.cols() - 1; + for (int i = 0; i <= horizon; ++i) { + TrajPoint pt; + pt.t = i * dt_; + pt.yaw = x_traj_yaw(0, i) * 180.0 / M_PI; // yaw in degrees + pt.pitch = x_traj_pitch(0, i) * 180.0 / M_PI; // pitch in degrees + pt.yaw_vel = x_traj_yaw(1, i) * 180.0 / M_PI; // yaw_rate in deg/s + pt.pitch_vel = x_traj_pitch(1, i) * 180.0 / M_PI; // pitch_rate in deg/s + mpc_trajectory_.push_back(pt); + } +} + +void MpcSolver::buildReferenceVisualizationTrajectory( + const Eigen::MatrixXd& xref) { + + reference_trajectory_.clear(); + + int horizon = xref.cols() - 1; + for (int i = 0; i <= horizon; ++i) { + TrajPoint pt; + pt.t = i * dt_; + pt.yaw = xref(0, i) * 180.0 / M_PI; // yaw in degrees + pt.pitch = xref(1, i) * 180.0 / M_PI; // pitch in degrees + pt.yaw_vel = xref(2, i) * 180.0 / M_PI; // yaw_rate in deg/s + pt.pitch_vel = xref(3, i) * 180.0 / M_PI; // pitch_rate in deg/s + reference_trajectory_.push_back(pt); + } +} + +ArmorSelectResult MpcSolver::selectArmor( + const rm_interfaces::msg::Target& target, + const rclcpp::Time& current_time) { + + ArmorSelectResult result; + result.armor_id = 0; + result.coming_angle = 0.0; + result.leaving_angle = 0.0; + result.is_switching = false; + + // Target provides position and velocity in Cartesian coordinates + Eigen::Vector3d target_pos(target.position.x, target.position.y, target.position.z); + Eigen::Vector3d target_vel(target.velocity.x, target.velocity.y, target.velocity.z); + + // Compute distance and approach angle + double dist = target_pos.norm(); + double approach_angle = std::atan2(target_pos.y(), target_pos.x()); + + // Estimate angular velocity from target velocity + // Angular velocity = (r × v) / |r|^2 where r is position vector + double omega = (target_pos.x() * target_vel.y() - target_pos.y() * target_vel.x()) / (dist * dist + 1e-6); + + // Flying time + double flying_time = computeFlyingTime(target_pos); + + // Coming angle: angle at which target is approaching + // Leaving angle: angle at which target will be when bullet arrives + double angle_at_flying_time = approach_angle + omega * flying_time; + double coming_angle = angles::shortest_angular_distance(approach_angle, angle_at_flying_time); + double leaving_angle = angles::shortest_angular_distance(angle_at_flying_time, approach_angle); + + result.coming_angle = coming_angle; + result.leaving_angle = leaving_angle; + + // Check if target is moving toward or away from robot + // Using configurable thresholds from wust_vision + double coming_thresh = comming_angle_deg_ * M_PI / 180.0; + double leaving_thresh = leaving_angle_deg_ * M_PI / 180.0; + + // is_switching: target is moving away fast or angle exceeds threshold + result.is_switching = (std::abs(coming_angle) > coming_thresh) || + (target_vel.norm() > 0.5 && std::abs(leaving_angle) < leaving_thresh); + + FYT_DEBUG("armor_mpc_solver", + "selectArmor: dist={:.2f}, coming={:.2f}deg, leaving={:.2f}deg, switching={}", + dist, coming_angle * 180 / M_PI, leaving_angle * 180 / M_PI, result.is_switching); + + return result; +} + +bool MpcSolver::canFireAtTime( + const rm_interfaces::msg::Target& target, + double time_to_target, + const rclcpp::Time& current_time) { + + // Get armor selection result + auto armor_result = selectArmor(target, current_time); + + // Account for control delay - add extra time buffer + double total_delay = control_delay_ + 0.05; // 50ms system latency buffer + + // Time when bullet will arrive + double fire_time = time_to_target; + + // Check if gimbal can reach target within fire_time considering acceleration limits + double dyaw = angles::shortest_angular_distance(current_gimbal_state_(0), + std::atan2(target.position.y, target.position.x)); + double dpitch = std::atan2(target.position.z, std::sqrt(target.position.x * target.position.x + + target.position.y * target.position.y)) - current_gimbal_state_(1); + + // Minimum time needed to reach target given acceleration constraints + double min_time_yaw = 2.0 * std::sqrt(std::abs(dyaw) / max_yaw_acc_); + double min_time_pitch = 2.0 * std::sqrt(std::abs(dpitch) / max_pitch_acc_); + double min_time_needed = std::max(min_time_yaw, min_time_pitch); + + // Can fire if we have enough time and target is not switching rapidly + bool can_fire = (fire_time >= min_time_needed + total_delay) && !armor_result.is_switching; + + FYT_DEBUG("armor_mpc_solver", + "canFireAtTime: time_to_target={:.3f}, min_time={:.3f}, can_fire={}", + fire_time, min_time_needed, can_fire); + + return can_fire; +} + +std::vector MpcSolver::computeLimitTrajectory( + double target_yaw, double target_pitch) { + + if (!limit_trajectory_generator_) { + return {}; + } + + return limit_trajectory_generator_->generate( + current_gimbal_state_(0), current_gimbal_state_(1), + target_yaw, target_pitch, + max_yaw_rate_, max_pitch_rate_); +} + +void MpcSolver::computeAccelConstrainedTrajectory( + const Eigen::Vector2d& start, + const Eigen::Vector2d& target, + const Eigen::Vector2d& max_vel, + const Eigen::Vector2d& max_acc) { + + // Generate limit trajectory for visualization + limit_trajectory_ = computeLimitTrajectory(target(0), target(1)); +} + +} // namespace fyt::auto_aim diff --git a/src/rm_auto_aim/armor_mpc_solver/src/solver_comparer.cpp b/src/rm_auto_aim/armor_mpc_solver/src/solver_comparer.cpp new file mode 100644 index 0000000..18e1094 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/src/solver_comparer.cpp @@ -0,0 +1,154 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "armor_mpc_solver/solver_comparer.hpp" +#include "rm_utils/logger/log.hpp" +#include +#include + +namespace fyt::auto_aim { + +SolverComparer::SolverComparer(std::weak_ptr n) +: node_(n), marker_id_(0), frame_id_("odom"), use_mpc_(false) { + auto node = node_.lock(); + if (node) { + use_mpc_ = node->declare_parameter("comparer.use_mpc", false); + } + node.reset(); +} + +void SolverComparer::init() { + auto node = node_.lock(); + if (!node) return; + + original_solver_ = std::make_unique(node_); + mpc_solver_ = std::make_unique(node_); + + tf2_buffer_ = std::make_shared(node->get_clock()); + tf2_listener_ = std::make_shared(*tf2_buffer_); + + trajectory_pub_ = node->create_publisher( + "/armor_solver/comparison_trajectory", 10); + mpc_gimbal_pub_ = node->create_publisher( + "/armor_solver/mpc_gimbal_cmd", 10); + + node.reset(); +} + +void SolverComparer::update(const rm_interfaces::msg::Target::SharedPtr target_msg) { + if (!target_msg) return; + + try { + auto current_time = rclcpp::Time(target_msg->header.stamp); + + if (original_solver_) { + last_original_cmd_ = original_solver_->solve(*target_msg, current_time, tf2_buffer_); + } + + if (mpc_solver_) { + last_mpc_cmd_ = mpc_solver_->solve(*target_msg, current_time, tf2_buffer_); + mpc_gimbal_pub_->publish(last_mpc_cmd_); + + last_mpc_trajectory_ = mpc_solver_->getMpcTrajectory(); + last_reference_trajectory_ = mpc_solver_->getReferenceTrajectory(); + last_limit_trajectory_ = mpc_solver_->getLimitTrajectory(); + } + + publishComparisonMarkers(); + + } catch (const std::exception& e) { + FYT_ERROR("armor_mpc_solver", "Solver comparison error: {}", e.what()); + } +} + +void SolverComparer::publishComparisonMarkers() { + visualization_msgs::msg::MarkerArray marker_array; + + if (!last_mpc_trajectory_.empty()) { + visualization_msgs::msg::Marker mpc_line; + mpc_line.header.frame_id = frame_id_; + mpc_line.header.stamp = rclcpp::Clock().now(); + mpc_line.ns = "mpc_trajectory"; + mpc_line.type = visualization_msgs::msg::Marker::LINE_STRIP; + mpc_line.action = visualization_msgs::msg::Marker::ADD; + mpc_line.scale.x = 0.02; + mpc_line.color.r = 1.0; + mpc_line.color.g = 0.0; + mpc_line.color.b = 0.0; + mpc_line.color.a = 1.0; + + for (const auto& pt : last_mpc_trajectory_) { + geometry_msgs::msg::Point p; + p.x = pt.t; + p.y = pt.yaw * 180.0 / M_PI; + p.z = pt.pitch * 180.0 / M_PI; + mpc_line.points.push_back(p); + } + mpc_line.id = marker_id_++; + marker_array.markers.push_back(mpc_line); + } + + if (!last_reference_trajectory_.empty()) { + visualization_msgs::msg::Marker ref_line; + ref_line.header.frame_id = frame_id_; + ref_line.header.stamp = rclcpp::Clock().now(); + ref_line.ns = "reference_trajectory"; + ref_line.type = visualization_msgs::msg::Marker::LINE_STRIP; + ref_line.action = visualization_msgs::msg::Marker::ADD; + ref_line.scale.x = 0.02; + ref_line.color.r = 0.0; + ref_line.color.g = 1.0; + ref_line.color.b = 0.0; + ref_line.color.a = 1.0; + + for (const auto& pt : last_reference_trajectory_) { + geometry_msgs::msg::Point p; + p.x = pt.t; + p.y = pt.yaw * 180.0 / M_PI; + p.z = pt.pitch * 180.0 / M_PI; + ref_line.points.push_back(p); + } + ref_line.id = marker_id_++; + marker_array.markers.push_back(ref_line); + } + + // Publish limit trajectory (acceleration-constrained) in blue + if (!last_limit_trajectory_.empty()) { + visualization_msgs::msg::Marker limit_line; + limit_line.header.frame_id = frame_id_; + limit_line.header.stamp = rclcpp::Clock().now(); + limit_line.ns = "limit_trajectory"; + limit_line.type = visualization_msgs::msg::Marker::LINE_STRIP; + limit_line.action = visualization_msgs::msg::Marker::ADD; + limit_line.scale.x = 0.02; + limit_line.color.r = 0.0; + limit_line.color.g = 0.0; + limit_line.color.b = 1.0; + limit_line.color.a = 1.0; + + for (const auto& pt : last_limit_trajectory_) { + geometry_msgs::msg::Point p; + p.x = pt.t; + p.y = pt.yaw * 180.0 / M_PI; + p.z = pt.pitch * 180.0 / M_PI; + limit_line.points.push_back(p); + } + limit_line.id = marker_id_++; + marker_array.markers.push_back(limit_line); + } + + trajectory_pub_->publish(marker_array); +} + +} // namespace fyt::auto_aim diff --git a/src/rm_auto_aim/armor_mpc_solver/src/solver_comparison_node.cpp b/src/rm_auto_aim/armor_mpc_solver/src/solver_comparison_node.cpp new file mode 100644 index 0000000..869a300 --- /dev/null +++ b/src/rm_auto_aim/armor_mpc_solver/src/solver_comparison_node.cpp @@ -0,0 +1,50 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "armor_mpc_solver/solver_comparison_node.hpp" + +namespace fyt::auto_aim { + +SolverComparisonNode::SolverComparisonNode() +: rclcpp::Node("solver_comparison_node") { + comparer_ = std::make_unique(shared_from_this()); + comparer_->init(); + + target_sub_ = this->create_subscription( + "/armor_solver/target", + 10, + std::bind(&SolverComparisonNode::targetCallback, this, std::placeholders::_1)); + + toggle_sub_ = this->create_subscription( + "/armor_solver/toggle_mpc", + 10, + std::bind(&SolverComparisonNode::toggleCallback, this, std::placeholders::_1)); +} + +void SolverComparisonNode::targetCallback(const rm_interfaces::msg::Target::SharedPtr msg) { + comparer_->update(msg); +} + +void SolverComparisonNode::toggleCallback(const std_msgs::msg::Bool::SharedPtr msg) { + // Toggle between original solver and MPC solver +} + +} // namespace fyt::auto_aim + +int main(int argc, char** argv) { + rclcpp::init(argc, argv); + rclcpp::spin(std::make_shared()); + rclcpp::shutdown(); + return 0; +} diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp index 12943f5..b4a597a 100644 --- a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp @@ -84,9 +84,11 @@ private: // Adaptive Q matrix parameters double s2qxyz_max_, s2qxyz_min_; // Position noise range double s2qyaw_max_, s2qyaw_min_; // Yaw noise range - double s2qr_, s2qd_zc_; // Radius and height offset noise - // R matrix parameters - double r_x_, r_y_, r_z_, r_yaw_; + double s2qr_, s2qd_zc_, s2qd_za_; // Radius and height offset noise + // R matrix parameters (Spherical coordinates: yaw, pitch, dist, ori_yaw) + double r_yaw_, r_pitch_, r_dist_, r_ori_yaw_; + // Adaptive R scaling for visibility (front vs side armor) + double r_front_scale_, r_side_scale_; double lost_time_thres_; std::unique_ptr tracker_; diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_tracker.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_tracker.hpp index 740b654..d15c141 100644 --- a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_tracker.hpp +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_tracker.hpp @@ -70,7 +70,8 @@ public: Armor tracked_armor; std::string tracked_id; ArmorsNum tracked_armors_num; - Eigen::VectorXd measurement; + Eigen::VectorXd measurement; // Cartesian [x, y, z, yaw] for debug publishing + Eigen::VectorXd spherical_measurement_; // Spherical [yaw, pitch, dist, ori_yaw] for EKF Eigen::VectorXd target_state; // To store another pair of armors message @@ -79,6 +80,9 @@ public: // To store offset relative to the reference plane double d_zc; + // Last armor type for adaptive noise + std::string last_armor_type_; + private: void initEKF(const Armor &a) noexcept; @@ -88,6 +92,12 @@ private: static Eigen::Vector3d getArmorPositionFromState(const Eigen::VectorXd &x) noexcept; + // Convert Cartesian measurement to spherical + static Eigen::Vector4d cartesianToSpherical(double x, double y, double z, double yaw); + + // Adaptive noise based on armor visibility + void updateAdaptiveNoise(const Armor &armor) noexcept; + double max_match_distance_; double max_match_yaw_diff_; diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/motion_model.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/motion_model.hpp index 809e255..687b2c1 100644 --- a/src/rm_auto_aim/armor_solver/include/armor_solver/motion_model.hpp +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/motion_model.hpp @@ -28,8 +28,10 @@ enum class MotionModel { CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity }; -// X_N: state dimension, Z_N: measurement dimension -constexpr int X_N = 10, Z_N = 4; +// X_N: state dimension (11), Z_N: measurement dimension (4) +// State: [xc, vxc, yc, vyc, zc, vzc, yaw, vyaw, r, d_zc, d_za] +// Measurement: [yaw, pitch, distance, ori_yaw] (spherical coordinates) +constexpr int X_N = 11, Z_N = 4; struct Predict { explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT) @@ -44,9 +46,9 @@ struct Predict { // v_xyz if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) { // linear velocity - x1[0] += x0[1] * dt; - x1[2] += x0[3] * dt; - x1[4] += x0[5] * dt; + x1[0] += x0[1] * dt; // xc += vxc * dt + x1[2] += x0[3] * dt; // yc += vyc * dt + x1[4] += x0[5] * dt; // zc += vzc * dt } else { // no velocity x1[1] *= 0.; @@ -57,7 +59,7 @@ struct Predict { // v_yaw if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) { // angular velocity - x1[6] += x0[7] * dt; + x1[6] += x0[7] * dt; // yaw += vyaw * dt } else { // no rotation x1[7] *= 0.; @@ -68,12 +70,31 @@ struct Predict { MotionModel model; }; +// Spherical measurement model +// z = [yaw, pitch, distance, ori_yaw] struct Measure { + template + void operator()(const T x[Z_N], T z[Z_N]) { + // x[0]: xc, x[2]: yc, x[4]: zc, x[6]: yaw, x[8]: r, x[9]: d_zc, x[10]: d_za + T xa = x[0] - ceres::cos(x[6]) * x[8]; // armor x + T ya = x[2] - ceres::sin(x[6]) * x[8]; // armor y + T za = x[4] + x[9] + x[10]; // armor z + + // Convert Cartesian to spherical + z[0] = ceres::atan2(ya, xa); // yaw (azimuth angle) + z[1] = ceres::atan2(za, ceres::sqrt(xa * xa + ya * ya)); // pitch (elevation angle) + z[2] = ceres::sqrt(xa * xa + ya * ya + za * za); // distance + z[3] = x[6]; // ori_yaw (same as yaw for armor) + } +}; + +// Cartesian measurement model for comparison +struct MeasureCartesian { template void operator()(const T x[Z_N], T z[Z_N]) { z[0] = x[0] - ceres::cos(x[6]) * x[8]; z[1] = x[2] - ceres::sin(x[6]) * x[8]; - z[2] = x[4] + x[9]; + z[2] = x[4] + x[9] + x[10]; z[3] = x[6]; } }; @@ -81,4 +102,4 @@ struct Measure { using RobotStateEKF = ExtendedKalmanFilter; } // namespace fyt::auto_aim -#endif +#endif // ARMOR_SOLVER_MOTION_MODEL_HPP_ diff --git a/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp b/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp index 6f465d8..d63bbb4 100644 --- a/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp +++ b/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp @@ -21,6 +21,8 @@ // std #include #include +// third party +#include // project #include "armor_solver/motion_model.hpp" #include "rm_utils/common.hpp" @@ -70,6 +72,7 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options) s2qyaw_min_ = declare_parameter("ekf.sigma2_q_yaw_min", 50.0); s2qr_ = declare_parameter("ekf.sigma2_q_r", 800.0); s2qd_zc_ = declare_parameter("ekf.sigma2_q_d_zc", 800.0); + s2qd_za_ = declare_parameter("ekf.sigma2_q_d_za", 800.0); nis_window_size_ = declare_parameter("ekf.nis_window_size", 20); nis_adapt_range_ = declare_parameter("ekf.nis_adapt_range", 2.0); @@ -99,7 +102,7 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options) s2q_xyz *= q_scale; s2q_yaw *= q_scale; - double r = s2qr_ * q_scale, d_zc = s2qd_zc_ * q_scale; + double r = s2qr_ * q_scale, d_zc = s2qd_zc_ * q_scale, d_za = s2qd_za_ * q_scale; // White noise integral model for position-velocity state double q_x_x = pow(t, 4) / 4 * s2q_xyz, q_x_vx = pow(t, 3) / 2 * s2q_xyz, q_vx_vx = pow(t, 2) * s2q_xyz; @@ -109,39 +112,65 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options) q_vyaw_vyaw = pow(t, 2) * s2q_yaw; double q_r = pow(t, 4) / 4 * r; double q_d_zc = pow(t, 4) / 4 * d_zc; + double q_d_za = pow(t, 4) / 4 * d_za; // clang-format off - // xc v_xc yc v_yc zc v_zc yaw v_yaw r d_zc - q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, - q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, - 0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, - 0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, - 0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_zc; + // xc v_xc yc v_yc zc v_zc yaw v_yaw r d_zc d_za + q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0, + q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0, + 0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0, + 0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0, + 0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_zc, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_za; // clang-format on return q; }; - // update_R - measurement noise covariance matrix - // R scales with distance: farther target -> larger measurement noise - r_x_ = declare_parameter("ekf.r_x", 0.05); - r_y_ = declare_parameter("ekf.r_y", 0.05); - r_z_ = declare_parameter("ekf.r_z", 0.05); + // update_R - measurement noise covariance matrix (Spherical coordinates) + // z = [yaw, pitch, distance, ori_yaw] + // R scales with distance: farther target -> larger angular noise r_yaw_ = declare_parameter("ekf.r_yaw", 0.02); + r_pitch_ = declare_parameter("ekf.r_pitch", 0.02); + r_dist_ = declare_parameter("ekf.r_dist", 0.05); + r_ori_yaw_ = declare_parameter("ekf.r_ori_yaw", 0.02); + + // Adaptive R scaling factor based on armor visibility + // Front armor: smaller noise (more trust) + // Side armor: larger noise (less trust) + r_front_scale_ = declare_parameter("ekf.r_front_scale", 0.5); + r_side_scale_ = declare_parameter("ekf.r_side_scale", 2.0); + auto u_r = [this](const Eigen::Matrix &z) { Eigen::Matrix r; - // Calculate distance for better noise scaling - double dist = std::sqrt(z[0] * z[0] + z[1] * z[1] + z[2] * z[2]); + // z[0] = yaw, z[1] = pitch, z[2] = distance, z[3] = ori_yaw + double dist = z[2]; // distance is the 3rd element in spherical measurement // Minimum distance to prevent numerical issues when target is very close dist = std::max(dist, 1.0); + + // Angular noise scales with distance (smaller at close range) + double r_yaw_scaled = r_yaw_ * dist * dist; // rad^2 * m^2 + double r_pitch_scaled = r_pitch_ * dist * dist; + double r_dist_scaled = r_dist_ * dist; // m^2 + double r_ori_yaw_scaled = r_ori_yaw_ * dist * dist; + + // Apply visibility-based scaling if armor visibility info is available + // This is updated in tracker_->updateAdaptiveNoise() + double visibility_scale = 1.0; + if (tracker_->last_armor_type_ == "front") { + visibility_scale = r_front_scale_; + } else if (tracker_->last_armor_type_ == "side") { + visibility_scale = r_side_scale_; + } + // clang-format off - r << r_x_ * dist, 0, 0, 0, - 0, r_y_ * dist, 0, 0, - 0, 0, r_z_ * dist, 0, - 0, 0, 0, r_yaw_; + r << r_yaw_scaled * visibility_scale, 0, 0, 0, + 0, r_pitch_scaled * visibility_scale, 0, 0, + 0, 0, r_dist_scaled * visibility_scale, 0, + 0, 0, 0, r_ori_yaw_scaled * visibility_scale; // clang-format on return r; }; @@ -150,6 +179,20 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options) p0.setIdentity(); tracker_->ekf = std::make_unique(f, h, u_q, u_r, p0); + // Set residual function for handling angle wraparound (yaw, pitch) + // The measurement z = [yaw, pitch, distance, ori_yaw] + // Use angles::shortest_angular_distance to handle -pi~pi discontinuity + auto residual_func = [](const Eigen::Matrix &z_meas, + const Eigen::Matrix &z_pred) { + Eigen::Matrix residual; + residual(0) = angles::shortest_angular_distance(z_pred(0), z_meas(0)); // yaw + residual(1) = angles::shortest_angular_distance(z_pred(1), z_meas(1)); // pitch + residual(2) = z_meas(2) - z_pred(2); // distance (no wraparound) + residual(3) = angles::shortest_angular_distance(z_pred(3), z_meas(3)); // ori_yaw + return residual; + }; + tracker_->ekf->setResidualFunc(residual_func); + // Subscriber with tf2 message_filter // tf2 relevant tf2_buffer_ = std::make_shared(this->get_clock()); diff --git a/src/rm_auto_aim/armor_solver/src/armor_tracker.cpp b/src/rm_auto_aim/armor_solver/src/armor_tracker.cpp index 09e3fb3..3bb5bd9 100644 --- a/src/rm_auto_aim/armor_solver/src/armor_tracker.cpp +++ b/src/rm_auto_aim/armor_solver/src/armor_tracker.cpp @@ -19,6 +19,7 @@ #include "armor_solver/armor_tracker.hpp" // std #include +#include #include #include // ros2 @@ -33,16 +34,19 @@ #include "rm_utils/logger/log.hpp" namespace fyt::auto_aim { + Tracker::Tracker(double max_match_distance, double max_match_yaw_diff) : tracker_state(LOST) , tracked_id(std::string("")) , measurement(Eigen::VectorXd::Zero(4)) -, target_state(Eigen::VectorXd::Zero(9)) +, spherical_measurement_(Eigen::VectorXd::Zero(4)) +, target_state(Eigen::VectorXd::Zero(X_N)) , max_match_distance_(max_match_distance) , max_match_yaw_diff_(max_match_yaw_diff) , detect_count_(0) , lost_count_(0) -, last_yaw_(0) {} +, last_yaw_(0) +, last_armor_type_("") {} void Tracker::setMatchThreshold(double max_match_distance, double max_match_yaw_diff) noexcept { max_match_distance_ = max_match_distance; @@ -69,6 +73,7 @@ void Tracker::init(const Armors::SharedPtr &armors_msg) noexcept { tracked_id = tracked_armor.number; tracker_state = DETECTING; + last_armor_type_ = tracked_armor.type; if (tracked_armor.type == "large" && (tracked_id == "3" || tracked_id == "4" || tracked_id == "5")) { @@ -111,6 +116,7 @@ void Tracker::update(const Armors::SharedPtr &armors_msg) noexcept { yaw_diff = abs(orientationToYaw(armor.pose.orientation) - ekf_prediction(6)); tracked_armor = armor; // Update tracked armor type + last_armor_type_ = tracked_armor.type; if (tracked_armor.type == "large" && (tracked_id == "3" || tracked_id == "4" || tracked_id == "5")) { tracked_armors_num = ArmorsNum::BALANCE_2; @@ -129,10 +135,17 @@ void Tracker::update(const Armors::SharedPtr &armors_msg) noexcept { // Matched armor found matched = true; auto p = tracked_armor.pose.position; - // Update EKF + + // Update adaptive noise based on armor visibility + updateAdaptiveNoise(tracked_armor); + + // Store Cartesian measurement for debug publishing double measured_yaw = orientationToYaw(tracked_armor.pose.orientation); measurement = Eigen::Vector4d(p.x, p.y, p.z, measured_yaw); - target_state = ekf->update(measurement); + + // Convert to spherical for EKF update + spherical_measurement_ = cartesianToSpherical(p.x, p.y, p.z, measured_yaw); + target_state = ekf->update(spherical_measurement_); } else if (same_id_armors_count == 1 && yaw_diff > max_match_yaw_diff_) { // Matched armor not found, but there is only one armor with the same id // and yaw has jumped, take this case as the target is spinning and armor @@ -203,7 +216,7 @@ void Tracker::initEKF(const Armor &a) noexcept { double yc = ya + r * sin(yaw); double zc = za; d_za = 0, d_zc = 0, another_r = r; - target_state << xc, 0, yc, 0, zc, 0, yaw, 0, r, d_zc; + target_state << xc, 0, yc, 0, zc, 0, yaw, 0, r, d_zc, d_za; ekf->setState(target_state); } @@ -217,10 +230,11 @@ void Tracker::handleArmorJump(const Armor ¤t_armor) noexcept { target_state(6) = yaw; // Only 4 armors has 2 radius and height if (tracked_armors_num == ArmorsNum::NORMAL_4) { - d_za = target_state(4) + target_state(9) - current_armor.pose.position.z; + d_za = target_state(4) + target_state(9) + target_state(10) - current_armor.pose.position.z; std::swap(target_state(8), another_r); d_zc = d_zc == 0 ? -d_za : 0; target_state(9) = d_zc; + target_state(10) = d_za; } FYT_DEBUG("armor_solver", "Armor Jump!"); } @@ -234,6 +248,7 @@ void Tracker::handleArmorJump(const Armor ¤t_armor) noexcept { // large, the state is wrong, reset center position and velocity in the // state d_zc = 0; + d_za = 0; double r = target_state(8); target_state(0) = p.x + r * cos(yaw); // xc target_state(1) = 0; // vxc @@ -242,6 +257,7 @@ void Tracker::handleArmorJump(const Armor ¤t_armor) noexcept { target_state(4) = p.z; // zc target_state(5) = 0; // vzc target_state(9) = d_zc; // d_zc + target_state(10) = d_za; // d_za FYT_WARN("armor_solver", "State wrong!"); } @@ -262,11 +278,54 @@ double Tracker::orientationToYaw(const geometry_msgs::msg::Quaternion &q) noexce Eigen::Vector3d Tracker::getArmorPositionFromState(const Eigen::VectorXd &x) noexcept { // Calculate predicted position of the current armor - double xc = x(0), yc = x(2), za = x(4) + x(9); + // x[0]: xc, x[2]: yc, x[4]: zc, x[6]: yaw, x[8]: r, x[9]: d_zc, x[10]: d_za + double xc = x(0), yc = x(2), za = x(4) + x(9) + x(10); double yaw = x(6), r = x(8); double xa = xc - r * cos(yaw); double ya = yc - r * sin(yaw); return Eigen::Vector3d(xa, ya, za); } -} // namespace fyt::auto_aim \ No newline at end of file +Eigen::Vector4d Tracker::cartesianToSpherical(double x, double y, double z, double yaw) noexcept { + Eigen::Vector4d spherical; + double dist_xy = std::sqrt(x * x + y * y); + spherical(0) = std::atan2(y, x); // yaw + spherical(1) = std::atan2(z, dist_xy); // pitch + spherical(2) = std::sqrt(x * x + y * y + z * z); // distance + spherical(3) = yaw; // ori_yaw + return spherical; +} + +void Tracker::updateAdaptiveNoise(const Armor &armor) noexcept { + // Adaptive R matrix based on armor visibility + // Classification based on distance_to_image_center: + // - Front armor: smaller distance_to_image_center (closer to image center) + // - Side armor: larger distance_to_image_center + // This is a heuristic - front armor appears more stable and should have lower measurement noise + + // Threshold for front/side classification + // Typical image center distance for front armor: < 0.3 (normalized) + // Typical image center distance for side armor: > 0.5 (normalized) + constexpr double front_threshold = 0.35; + constexpr double side_threshold = 0.55; + + std::string visibility_type; + if (armor.distance_to_image_center < front_threshold) { + visibility_type = "front"; + } else if (armor.distance_to_image_center > side_threshold) { + visibility_type = "side"; + } else { + // In between: use previous state or default to front + visibility_type = last_armor_type_.empty() ? "front" : last_armor_type_; + } + + last_armor_type_ = visibility_type; + + FYT_DEBUG("armor_solver", + "Adaptive noise: dist_to_center={:.3f}, visibility={}, armor_type={}", + armor.distance_to_image_center, + visibility_type, + armor.type); +} + +} // namespace fyt::auto_aim diff --git a/src/rm_bringup/config/node_params/armor_solver_params.yaml b/src/rm_bringup/config/node_params/armor_solver_params.yaml index bff3dac..cd89ffe 100644 --- a/src/rm_bringup/config/node_params/armor_solver_params.yaml +++ b/src/rm_bringup/config/node_params/armor_solver_params.yaml @@ -34,7 +34,7 @@ tracking_thres: 1 lost_time_thres: 1.0 - + solver: shoot_rate_min: 6 shoot_rate_max: 12 @@ -43,14 +43,14 @@ shooting_range_width: 0.10 #射击范围 shooting_range_height: 0.10 #射击范围 prediction_delay: 0.02 # 预测装甲板位置的延时,单位秒,+飞行时间 - controller_delay: 0.01 - max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心 - side_angle: 15.0 - compenstator_type: "resistance" + controller_delay: 0.01 + max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心 + side_angle: 15.0 + compenstator_type: "resistance" gravity: 9.792 - resistance: 0.038 + resistance: 0.038 iteration_times: 20 # 补偿的迭代次数 - + # ["距离下限, 距离上限, 高度下限, 高度下限, pitch轴补偿值"] # [dist_low, dist_high, height_low, height_high, pitch_offset_deg, yaw_offset_deg] angle_offset: [ @@ -64,3 +64,79 @@ "7.0 8.0 0.4 0.8 0.0 0.0", "7.0 8.0 0.8 1.2 0.0 0.0", ] + + # MPC solver parameters (matching wust_vision very_aimer) + # Type: "seg" (segment-based trajectory) or "mpc" (model predictive control) + mpc: + enabled: false # Set to true to use MPC solver instead of original solver + type: "seg" # "seg" or "mpc" + + # Trajectory parameters + sample_total_time: 2.0 # Total prediction time (seconds) + sample_horizon: 500 # Number of steps in horizon + + # Control parameters + control_delay: 0.2 # Control delay (seconds) + delay_enable_fire_error: 0.0035 # Fire error threshold + + # Acceleration limits (rad/s^2) + max_yaw_acc: 40.0 + max_pitch_acc: 25.0 + + # Fire decision thresholds + comming_angle: 60.0 # Coming angle threshold (degrees) + leaving_angle: 20.0 # Leaving angle threshold (degrees) + yaw_limit_deg: 60.0 # Yaw limit for fire decision + shooting_range_h: 0.12 # Height shooting range + shooting_range_small_w: 0.12 # Small target width range + shooting_range_big_w: 0.24 # Big target width range + min_enable_pitch_deg: 0.25 # Min pitch for fire + min_enable_yaw_deg: 0.25 # Min yaw for fire + + # MPC cost weights (matching wust_vision) + # Q = [position_cost, velocity_cost], R = [control_cost] + Q_yaw: [7.0e6, 0.0] # Yaw position and velocity cost + R_yaw: [3.0] # Yaw control cost + Q_pitch: [7.0e6, 0.0] # Pitch position and velocity cost + R_pitch: [3.0] # Pitch control cost + + # Trajectory compensator + trajectory_compensator: "resistance" # "ideal" or "resistance" + gravity: 9.8 + + # Trajectory offset (same format as solver.angle_offset) + trajectory_offset: [ + "0.0 4.5 -1.0 1.5 0.0 0.0", + ] + + # Alternative MPC configuration for comparison testing + mpc_compare: + enabled: false + type: "mpc" + + sample_total_time: 1.0 + sample_horizon: 250 + + control_delay: 0.15 + delay_enable_fire_error: 0.005 + + max_yaw_acc: 35.0 + max_pitch_acc: 20.0 + + comming_angle: 45.0 + leaving_angle: 30.0 + yaw_limit_deg: 50.0 + shooting_range_h: 0.15 + shooting_range_small_w: 0.10 + shooting_range_big_w: 0.20 + min_enable_pitch_deg: 0.3 + min_enable_yaw_deg: 0.3 + + Q_yaw: [5.0e6, 1.0e5] + R_yaw: [2.0] + Q_pitch: [5.0e6, 1.0e5] + R_pitch: [2.0] + + trajectory_compensator: "resistance" + gravity: 9.8 + trajectory_offset: [] diff --git a/src/rm_tinympc/CMakeLists.txt b/src/rm_tinympc/CMakeLists.txt new file mode 100644 index 0000000..aa89501 --- /dev/null +++ b/src/rm_tinympc/CMakeLists.txt @@ -0,0 +1,61 @@ +cmake_minimum_required(VERSION 3.8) +project(rm_tinympc) + +if(CMAKE_CXX_STANDARD GREATER_RANGE 17) + set(CMAKE_CXX_STANDARD 17) +else() + set(CMAKE_CXX_STANDARD 17) +endif() + +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Find dependencies +find_package(ament_cmake REQUIRED) +find_package(Eigen3 REQUIRED) +find_package(tf2 REQUIRED) + +include_directories( + include + ${EIGEN3_INCLUDE_DIRS} +) + +set(${PROJECT_NAME}_HEADERS + include/rm_tinympc/types.hpp + include/rm_tinympc/admm.hpp + include/rm_tinympc/tiny_api.hpp + include/rm_tinympc/codegen.hpp + include/rm_tinympc/error.hpp + include/rm_tinympc/trajectory.hpp +) + +add_library(${PROJECT_NAME} SHARED + src/admm.cpp + src/tiny_api.cpp + src/codegen.cpp + src/trajectory.cpp +) + +ament_target_dependencies(${PROJECT_NAME} + Eigen3 + tf2 +) + +target_include_directories(${PROJECT_NAME} PUBLIC + $ + $ +) + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin +) + +install(DIRECTORY include/ + DESTINATION include +) + +ament_export_include_directories(include) +ament_export_libraries(${PROJECT_NAME}) + +ament_package() diff --git a/src/rm_tinympc/include/rm_tinympc/admm.hpp b/src/rm_tinympc/include/rm_tinympc/admm.hpp new file mode 100644 index 0000000..6a33081 --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/admm.hpp @@ -0,0 +1,162 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_ADMM_HPP_ +#define RM_TINYMPC_ADMM_HPP_ + +#include +#include + +namespace rm_tinympc { + +/** + * ADMM-based MPC Solver (TinyMPC port) + * + * Minimizes: sum_{k=0}^{T-1} (x_k - xref_k)' Q (x_k - xref_k) + u_k' R u_k + * Subject to: x_{k+1} = A x_k + B u_k (dynamics) + * x_min <= x_k <= x_max (state constraints) + * u_min <= u_k <= u_max (input constraints) + */ +class ADMM { +public: + ADMM() = default; + ~ADMM() = default; + + /** + * Initialize the solver + * @param n State dimension + * @param m Input dimension + * @param T Prediction horizon + */ + void init(int n, int m, int T); + + /** + * Set the problem matrices + * @param A State transition matrix (n x n) + * @param B Input matrix (n x m) + * @param Q State cost weight (n x n) + * @param R Input cost weight (m x m) + * @param Qf Terminal cost weight (n x n) + */ + void setProblem(const Eigen::MatrixXd& A, + const Eigen::MatrixXd& B, + const Eigen::VectorXd& Q, + const Eigen::VectorXd& R, + const Eigen::VectorXd& Qf); + + /** + * Set state and input constraints + */ + void setConstraints(const Eigen::VectorXd& x_min, + const Eigen::VectorXd& x_max, + const Eigen::VectorXd& u_min, + const Eigen::VectorXd& u_max); + + /** + * Set reference trajectory + */ + void setReference(const Eigen::MatrixXd& xref); + + /** + * Set initial state + */ + void setInitialState(const Eigen::VectorXd& x0); + + /** + * Solve the MPC problem + * @return true if converged, false otherwise + */ + bool solve(); + + /** + * Get the optimal control input (first input of the sequence) + */ + const Eigen::VectorXd& getFirstInput() const { return u_solution_.col(0); } + + /** + * Get the full state trajectory + */ + const Eigen::MatrixXd& getStateTrajectory() const { return x_solution_; } + + /** + * Get the full control sequence + */ + const Eigen::MatrixXd& getControlSequence() const { return u_solution_; } + + /** + * Configure solver parameters + */ + void setMaxIterations(int max_iter) { max_iterations_ = max_iter; } + void setRho(double rho) { rho_ = rho; } + void setAbsTol(double tol) { abs_tol_ = tol; } + void setRelTol(double tol) { rel_tol_ = tol; } + void setVerbose(bool verbose) { verbose_ = verbose; } + +private: + int n_; // State dimension + int m_; // Input dimension + int T_; // Horizon + + int max_iterations_ = 100; + double rho_ = 0.1; + double abs_tol_ = 1e-4; + double rel_tol_ = 1e-3; + bool verbose_ = false; + + // System matrices + Eigen::MatrixXd A_; + Eigen::MatrixXd B_; + Eigen::VectorXd Q_; + Eigen::VectorXd R_; + Eigen::VectorXd Qf_; + + // Constraints + Eigen::VectorXd x_min_; + Eigen::VectorXd x_max_; + Eigen::VectorXd u_min_; + Eigen::VectorXd u_max_; + bool has_constraints_ = false; + + // Reference trajectory + Eigen::MatrixXd xref_; + + // Initial state + Eigen::VectorXd x0_; + + // Solution + Eigen::MatrixXd x_solution_; // n x (T+1) + Eigen::MatrixXd u_solution_; // m x T + + // ADMM variables + Eigen::MatrixXd z_; // n x (T+1) - state consensus + Eigen::MatrixXd u_; // m x T - control consensus + Eigen::MatrixXd z_old_; // n x (T+1) + Eigen::MatrixXd u_old_; // m x T + + // Lagrange multipliers + Eigen::MatrixXd lambda_x_; // n x (T+1) + Eigen::MatrixXd lambda_u_; // m x T + + // Check convergence + bool checkConvergence(const Eigen::MatrixXd& x, const Eigen::MatrixXd& u); + + // Proximal operators + Eigen::VectorXd proxGradient(const Eigen::VectorXd& x, const Eigen::VectorXd& grad, + double step, const Eigen::VectorXd& x_min, + const Eigen::VectorXd& x_max); +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_ADMM_HPP_ diff --git a/src/rm_tinympc/include/rm_tinympc/codegen.hpp b/src/rm_tinympc/include/rm_tinympc/codegen.hpp new file mode 100644 index 0000000..9310f7c --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/codegen.hpp @@ -0,0 +1,53 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_CODEGEN_HPP_ +#define RM_TINYMPC_CODEGEN_HPP_ + +#include + +namespace rm_tinympc { + +class CodeGen { +public: + CodeGen() = default; + ~CodeGen() = default; + + void generateSolverCode(const std::string& output_dir); + + static Eigen::Vector2d computeGimbalCommand( + const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_vel, + double target_yaw, + double v_yaw, + double bullet_speed, + double gravity); + + static Eigen::Vector3d predictTargetPosition( + const Eigen::Vector3d& pos, + const Eigen::Vector3d& vel, + double dt); + + static double computeFlyingTime( + const Eigen::Vector3d& target_pos, + double bullet_speed, + double gravity = 9.8); + +private: + static constexpr double kPi = 3.14159265358979323846; +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_CODEGEN_HPP_ diff --git a/src/rm_tinympc/include/rm_tinympc/error.hpp b/src/rm_tinympc/include/rm_tinympc/error.hpp new file mode 100644 index 0000000..2e9baed --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/error.hpp @@ -0,0 +1,40 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_ERROR_HPP_ +#define RM_TINYMPC_ERROR_HPP_ + +#include +#include + +namespace rm_tinympc { + +class TinyMPCException : public std::runtime_error { +public: + explicit TinyMPCException(const std::string& msg) : std::runtime_error(msg) {} +}; + +class ConvergenceError : public TinyMPCException { +public: + explicit ConvergenceError(const std::string& msg) : TinyMPCException(msg) {} +}; + +class InvalidSizeError : public TinyMPCException { +public: + explicit InvalidSizeError(const std::string& msg) : TinyMPCException(msg) {} +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_ERROR_HPP_ diff --git a/src/rm_tinympc/include/rm_tinympc/tiny_api.hpp b/src/rm_tinympc/include/rm_tinympc/tiny_api.hpp new file mode 100644 index 0000000..50f602a --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/tiny_api.hpp @@ -0,0 +1,71 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_TINY_API_HPP_ +#define RM_TINYMPC_TINY_API_HPP_ + +#include +#include "types.hpp" + +namespace rm_tinympc { + +class TinyMPC { +public: + TinyMPC() = default; + ~TinyMPC() = default; + + void init(int n, int m, int T); + + void setProblem( + const Eigen::MatrixXd& A, + const Eigen::MatrixXd& B, + const Eigen::VectorXd& x0, + const Eigen::VectorXd& xref, + const Eigen::VectorXd& uref); + + void setWeights(const Eigen::VectorXd& Q, const Eigen::VectorXd& R, const Eigen::VectorXd& Qf); + + void solve(); + + const Eigen::VectorXd& getSolution() const { return u_solution_; } + const Eigen::MatrixXd& getFullSolution() const { return x_solution_; } + + void setMaxIterations(int max_iter) { max_iterations_ = max_iter; } + void setRho(double rho) { rho_ = rho; } + +private: + int n_; // state dimension + int m_; // input dimension + int T_; // horizon + + int max_iterations_ = 100; + double rho_ = 0.1; + + Eigen::MatrixXd A_; + Eigen::MatrixXd B_; + Eigen::VectorXd x0_; + Eigen::VectorXd xref_; + Eigen::VectorXd uref_; + + Eigen::VectorXd Q_; + Eigen::VectorXd R_; + Eigen::VectorXd Qf_; + + Eigen::MatrixXd x_solution_; + Eigen::VectorXd u_solution_; +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_TINY_API_HPP_ diff --git a/src/rm_tinympc/include/rm_tinympc/trajectory.hpp b/src/rm_tinympc/include/rm_tinympc/trajectory.hpp new file mode 100644 index 0000000..70ff76b --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/trajectory.hpp @@ -0,0 +1,165 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_TRAJECTORY_HPP_ +#define RM_TINYMPC_TRAJECTORY_HPP_ + +#include + +namespace rm_tinympc { + +/** + * Quintic Polynomial Trajectory Generator + * + * Generates smooth trajectories from start state to goal state + * with specified boundary velocities and accelerations. + * + * Coefficients: a0 + a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5 + */ +class QuinticPolynomial { +public: + QuinticPolynomial() = default; + ~QuinticPolynomial() = default; + + /** + * Compute quintic polynomial coefficients + * @param p0 Initial position + * @param v0 Initial velocity + * @param a0 Initial acceleration + * @param pf Final position + * @param vf Final velocity + * @param af Final acceleration + * @param T Trajectory duration + */ + void compute(double p0, double v0, double a0, + double pf, double vf, double af, double T); + + /** + * Evaluate position at time t + */ + double evaluate(double t) const; + + /** + * Evaluate velocity at time t + */ + double evaluateVelocity(double t) const; + + /** + * Evaluate acceleration at time t + */ + double evaluateAcceleration(double t) const; + + /** + * Get polynomial coefficients + */ + const Eigen::Vector6d& getCoefficients() const { return coef_; } + +private: + Eigen::Vector6d coef_; // a0, a1, a2, a3, a4, a5 +}; + +/** + * Trajectory Point with position, velocity, acceleration + */ +struct TrajectoryPoint1D { + double t; // Time + double p; // Position + double v; // Velocity + double a; // Acceleration +}; + +/** + * 2D Trajectory Point (yaw, pitch) + */ +struct TrajectoryPoint2D { + double t; + double yaw; + double pitch; + double yaw_vel; + double pitch_vel; + double yaw_acc; + double pitch_acc; +}; + +/** + * Generate smooth 1D trajectory with quintic polynomial + */ +class TrajectoryGenerator1D { +public: + TrajectoryGenerator1D() = default; + ~TrajectoryGenerator1D() = default; + + /** + * Generate trajectory from start to goal + * @param p0 Initial position + * @param pf Final position + * @param max_v Maximum velocity (for constraint) + * @param max_a Maximum acceleration (for constraint) + * @param T Total duration + */ + std::vector generate(double p0, double pf, + double max_v, double max_a, double T); + + /** + * Generate minimum time trajectory with velocity/acceleration constraints + */ + double generateMinTime(double p0, double pf, double max_v, double max_a); + +private: + double T_ = 0.0; + QuinticPolynomial poly_; +}; + +/** + * 2D Gimbal Trajectory Generator (yaw and pitch simultaneously) + */ +class GimbalTrajectoryGenerator { +public: + GimbalTrajectoryGenerator() = default; + ~GimbalTrajectoryGenerator() = default; + + /** + * Generate 2D gimbal trajectory + * @param yaw0 Initial yaw + * @param pitch0 Initial pitch + * @param yawf Target yaw + * @param pitchf Target pitch + * @param max_yaw_rate Maximum yaw rate (rad/s) + * @param max_pitch_rate Maximum pitch rate (rad/s) + * @param max_yaw_acc Maximum yaw acceleration (rad/s^2) + * @param max_pitch_acc Maximum pitch acceleration (rad/s^2) + * @param T Trajectory duration + */ + std::vector generate(double yaw0, double pitch0, + double yawf, double pitchf, + double max_yaw_rate, double max_pitch_rate, + double max_yaw_acc, double max_pitch_acc, + double T); + + /** + * Minimum time trajectory with constraints + */ + double generateMinTime(double yaw0, double pitch0, + double yawf, double pitchf, + double max_yaw_rate, double max_pitch_rate, + double max_yaw_acc, double max_pitch_acc); + +private: + QuinticPolynomial poly_yaw_; + QuinticPolynomial poly_pitch_; +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_TRAJECTORY_HPP_ diff --git a/src/rm_tinympc/include/rm_tinympc/types.hpp b/src/rm_tinympc/include/rm_tinympc/types.hpp new file mode 100644 index 0000000..ec03fc9 --- /dev/null +++ b/src/rm_tinympc/include/rm_tinympc/types.hpp @@ -0,0 +1,37 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RM_TINYMPC_TYPES_HPP_ +#define RM_TINYMPC_TYPES_HPP_ + +#include + +namespace rm_tinympc { + +using vec2 = Eigen::Vector2d; +using vec3 = Eigen::Vector3d; +using vec4 = Eigen::Vector4d; +using mat2 = Eigen::Matrix2d; +using mat3 = Eigen::Matrix3d; +using mat4 = Eigen::Matrix4d; + +struct TrajectoryPoint { + double yaw; + double pitch; + double t; +}; + +} // namespace rm_tinympc + +#endif // RM_TINYMPC_TYPES_HPP_ diff --git a/src/rm_tinympc/package.xml b/src/rm_tinympc/package.xml new file mode 100644 index 0000000..7988ba6 --- /dev/null +++ b/src/rm_tinympc/package.xml @@ -0,0 +1,19 @@ + + + + rm_tinympc + 0.1.0 + TinyMPC - ADMM-based small MPC solver + Chen Youyuan + Apache-2.0 + + ament_cmake + ament_cmake_python + + Eigen3 + tf2 + + + ament_cmake + + diff --git a/src/rm_tinympc/src/admm.cpp b/src/rm_tinympc/src/admm.cpp new file mode 100644 index 0000000..83048cd --- /dev/null +++ b/src/rm_tinympc/src/admm.cpp @@ -0,0 +1,204 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rm_tinympc/admm.hpp" +#include +#include + +namespace rm_tinympc { + +void ADMM::init(int n, int m, int T) { + n_ = n; + m_ = m; + T_ = T; + + // Allocate solution matrices + x_solution_.resize(n_, T_ + 1); + u_solution_.resize(m_, T_); + + // ADMM variables + z_.resize(n_, T_ + 1); + u_.resize(m_, T_); + z_old_.resize(n_, T_ + 1); + u_old_.resize(m_, T_); + lambda_x_.resize(n_, T_ + 1); + lambda_u_.resize(m_, T_); + + x_solution_.setZero(); + u_solution_.setZero(); + z_.setZero(); + u_.setZero(); + lambda_x_.setZero(); + lambda_u_.setZero(); +} + +void ADMM::setProblem(const Eigen::MatrixXd& A, + const Eigen::MatrixXd& B, + const Eigen::VectorXd& Q, + const Eigen::VectorXd& R, + const Eigen::VectorXd& Qf) { + A_ = A; + B_ = B; + Q_ = Q; + R_ = R; + Qf_ = Qf; +} + +void ADMM::setConstraints(const Eigen::VectorXd& x_min, + const Eigen::VectorXd& x_max, + const Eigen::VectorXd& u_min, + const Eigen::VectorXd& u_max) { + x_min_ = x_min; + x_max_ = x_max; + u_min_ = u_min; + u_max_ = u_max; + has_constraints_ = true; +} + +void ADMM::setReference(const Eigen::MatrixXd& xref) { + xref_ = xref; +} + +void ADMM::setInitialState(const Eigen::VectorXd& x0) { + x0_ = x0; + x_solution_.col(0) = x0; + z_.col(0) = x0; +} + +bool ADMM::solve() { + if (xref_.rows() != n_ || xref_.cols() != T_ + 1) { + return false; + } + + // Initialize state trajectory by forward simulation + x_solution_.col(0) = x0_; + for (int t = 0; t < T_; ++t) { + x_solution_.col(t + 1) = A_ * x_solution_.col(t) + B_ * u_solution_.col(t); + } + + double primal_res_x = 0.0, primal_res_u = 0.0; + double dual_res_x = 0.0, dual_res_u = 0.0; + + for (int iter = 0; iter < max_iterations_; ++iter) { + z_old_ = z_; + u_old_ = u_; + + // === X-update (state update) === + // Solve: min_{x} rho/2 * ||x - z + lambda||^2 + sum (x_k - xref_k)' Q (x_k - xref_k) + // Subject to: x_{k+1} = A x_k + B u_k + for (int t = 0; t <= T_; ++t) { + Eigen::VectorXd x_t = z_.col(t) - lambda_x_.col(t) / rho_; + + if (t == 0) { + // Initial state constraint + x_t = x0_; + } else if (t <= T_) { + // Cost gradient: 2 * Q * (x_t - xref_t) + Eigen::VectorXd grad = Q_.asDiagonal() * (x_t - xref_.col(t)); + x_t = x_t - (1.0 / (rho_ + 2.0 * Q_(t % Q_.size()))) * grad; + } + + // Apply state constraints + if (has_constraints_ && t > 0) { + x_t = x_t.cwiseMax(x_min_).cwiseMin(x_max_); + } + + z_.col(t) = x_t; + } + + // === U-update (control update) === + // Solve: min_{u} rho/2 * ||u - v + lambda||^2 + sum u_k' R u_k + for (int t = 0; t < T_; ++t) { + // Compute the implied state from previous state and control + Eigen::VectorXd x_implied = A_ * x_solution_.col(t) + B_ * u_solution_.col(t); + + // Gradient of cost w.r.t. u: 2 * R * u_k + B' * lambda_u + Eigen::VectorXd grad_u = (2.0 * R_.asDiagonal() * u_solution_.col(t) + + B_.transpose() * lambda_u_.col(t)) / rho_; + + Eigen::VectorXd u_t = u_solution_.col(t) - grad_u; + + // Apply control constraints + if (has_constraints_) { + u_t = u_t.cwiseMax(u_min_).cwiseMin(u_max_); + } + + u_.col(t) = u_t; + + // Update state trajectory using new control + if (t < T_) { + x_solution_.col(t + 1) = A_ * x_solution_.col(t) + B_ * u_t; + } + } + + // === Lagrange multiplier update === + // lambda_x += rho * (x_solution - z_) + // lambda_u += rho * (u_solution - u_) + lambda_x_ += rho_ * (x_solution_ - z_); + lambda_u_ += rho_ * (u_solution_ - u_); + + // === Check convergence === + if (checkConvergence(x_solution_, u_solution_)) { + if (verbose_) { + std::cout << "ADMM converged at iteration " << iter << std::endl; + } + return true; + } + + // === Adaptive rho adjustment (optional) === + // Increase rho if primal residual is large, decrease if dual residual is large + // This is simplified for TinyMPC + + // Update solution for next iteration + z_ = x_solution_; + u_ = u_solution_; + } + + if (verbose_) { + std::cout << "ADMM reached max iterations " << max_iterations_ << std::endl; + } + + return false; +} + +bool ADMM::checkConvergence(const Eigen::MatrixXd& x, const Eigen::MatrixXd& u) { + // Compute primal residuals + double eps_pri_x = 0.0, eps_pri_u = 0.0; + double eps_dual_x = 0.0, eps_dual_u = 0.0; + + // Primal residual for x: ||x - z|| + for (int t = 0; t <= T_; ++t) { + eps_pri_x += (x.col(t) - z_.col(t)).squaredNorm(); + eps_dual_x += (rho_ * (z_.col(t) - z_old_.col(t))).squaredNorm(); + } + + // Primal residual for u: ||u - u_old|| + for (int t = 0; t < T_; ++t) { + eps_pri_u += (u.col(t) - u_.col(t)).squaredNorm(); + eps_dual_u += (rho_ * (u_.col(t) - u_old_.col(t))).squaredNorm(); + } + + double eps_pri = std::sqrt(eps_pri_x + eps_pri_u); + double eps_dual = std::sqrt(eps_dual_x + eps_dual_u); + + // Compute tolerance + double tol_pri = std::sqrt(static_cast((T_ + 1) * n_ + T_ * m_)) * abs_tol_ + + rel_tol_ * std::max(std::sqrt(eps_pri_x), std::sqrt(eps_dual_x)); + double tol_dual = std::sqrt(static_cast((T_ + 1) * n_ + T_ * m_)) * abs_tol_ + + rel_tol_ * std::sqrt(eps_dual_x + eps_dual_u); + + return (eps_pri < tol_pri * tol_pri) && (eps_dual < tol_dual * tol_dual); +} + +} // namespace rm_tinympc diff --git a/src/rm_tinympc/src/codegen.cpp b/src/rm_tinympc/src/codegen.cpp new file mode 100644 index 0000000..dac4d0b --- /dev/null +++ b/src/rm_tinympc/src/codegen.cpp @@ -0,0 +1,56 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rm_tinympc/codegen.hpp" +#include + +namespace rm_tinympc { + +Eigen::Vector2d CodeGen::computeGimbalCommand( + const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_vel, + double target_yaw, + double v_yaw, + double bullet_speed, + double gravity) { + // Simplified trajectory compensation + double dist = target_pos.norm(); + double flying_time = dist / bullet_speed; + + // Predict target position + Eigen::Vector3d predicted_pos = target_pos + flying_time * target_vel; + + // Compute gimbal angles + double yaw = std::atan2(predicted_pos.y(), predicted_pos.x()); + double pitch = std::atan2(predicted_pos.z(), predicted_pos.head(2).norm()); + + return Eigen::Vector2d(yaw, pitch); +} + +Eigen::Vector3d CodeGen::predictTargetPosition( + const Eigen::Vector3d& pos, + const Eigen::Vector3d& vel, + double dt) { + return pos + dt * vel; +} + +double CodeGen::computeFlyingTime( + const Eigen::Vector3d& target_pos, + double bullet_speed, + double gravity) { + double dist = target_pos.norm(); + return dist / bullet_speed; +} + +} // namespace rm_tinympc diff --git a/src/rm_tinympc/src/tiny_api.cpp b/src/rm_tinympc/src/tiny_api.cpp new file mode 100644 index 0000000..a36ead5 --- /dev/null +++ b/src/rm_tinympc/src/tiny_api.cpp @@ -0,0 +1,71 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rm_tinympc/tiny_api.hpp" + +namespace rm_tinympc { + +void TinyMPC::init(int n, int m, int T) { + n_ = n; + m_ = m; + T_ = T; + + x_solution_.resize((T + 1) * n); + u_solution_.resize(T * m); + x_solution_.setZero(); + u_solution_.setZero(); +} + +void TinyMPC::setProblem( + const Eigen::MatrixXd& A, + const Eigen::MatrixXd& B, + const Eigen::VectorXd& x0, + const Eigen::VectorXd& xref, + const Eigen::VectorXd& uref) { + A_ = A; + B_ = B; + x0_ = x0; + xref_ = xref; + uref_ = uref; +} + +void TinyMPC::setWeights(const Eigen::VectorXd& Q, const Eigen::VectorXd& R, const Eigen::VectorXd& Qf) { + Q_ = Q; + R_ = R; + Qf_ = Qf; +} + +void TinyMPC::solve() { + // Initialize with reference trajectory + x_solution_.segment(0, n_) = x0_; + + // Simplified forward simulation using the dynamics + for (int t = 0; t < T_; ++t) { + int x_idx = t * n_; + int u_idx = t * m_; + + // Use reference control if not set + if (t * m_ < uref_.size()) { + u_solution_.segment(u_idx, m_) = uref_.segment(u_idx, m_); + } + + // Simulate dynamics: x_{t+1} = A * x_t + B * u_t + if ((t + 1) * n_ < x_solution_.size()) { + x_solution_.segment((t + 1) * n_, n_) = A_ * x_solution_.segment(x_idx, n_) + + B_ * u_solution_.segment(u_idx, m_); + } + } +} + +} // namespace rm_tinympc diff --git a/src/rm_tinympc/src/trajectory.cpp b/src/rm_tinympc/src/trajectory.cpp new file mode 100644 index 0000000..307e628 --- /dev/null +++ b/src/rm_tinympc/src/trajectory.cpp @@ -0,0 +1,163 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rm_tinympc/trajectory.hpp" +#include + +namespace rm_tinympc { + +void QuinticPolynomial::compute(double p0, double v0, double a0, + double pf, double vf, double af, double T) { + T_ = T; + double T2 = T * T; + double T3 = T2 * T; + double T4 = T3 * T; + double T5 = T4 * T; + + // Solve for coefficients using boundary conditions + // p(0) = p0, p(T) = pf + // v(0) = v0, v(T) = vf + // a(0) = a0, a(T) = af + Eigen::Matrix3d A; + A << T3, T4, T5, + 3 * T2, 4 * T3, 5 * T4, + 6 * T, 12 * T2, 20 * T3; + + Eigen::Vector3d b; + b << pf - p0 - v0 * T - 0.5 * a0 * T2, + vf - v0 - a0 * T, + af - a0; + + Eigen::Vector3d x = A.colPivHouseholderQr().solve(b); + + coef_ << p0, v0, 0.5 * a0, x[0], x[1], x[2]; +} + +double QuinticPolynomial::evaluate(double t) const { + double t2 = t * t; + double t3 = t2 * t; + double t4 = t3 * t; + double t5 = t4 * t; + return coef_[0] + coef_[1] * t + coef_[2] * t2 + coef_[3] * t3 + coef_[4] * t4 + coef_[5] * t5; +} + +double QuinticPolynomial::evaluateVelocity(double t) const { + double t2 = t * t; + double t3 = t2 * t; + double t4 = t3 * t; + return coef_[1] + 2 * coef_[2] * t + 3 * coef_[3] * t2 + 4 * coef_[4] * t3 + 5 * coef_[5] * t4; +} + +double QuinticPolynomial::evaluateAcceleration(double t) const { + double t2 = t * t; + double t3 = t2 * t; + return 2 * coef_[2] + 6 * coef_[3] * t + 12 * coef_[4] * t2 + 20 * coef_[5] * t3; +} + +std::vector TrajectoryGenerator1D::generate( + double p0, double pf, double max_v, double max_a, double T) { + std::vector trajectory; + + // Compute quintic polynomial + poly_.compute(p0, 0, 0, pf, 0, 0, T); + + // Sample trajectory + const int num_samples = static_cast(T * 100); // 100 Hz + for (int i = 0; i <= num_samples; ++i) { + double t = static_cast(i) / num_samples * T; + TrajectoryPoint1D pt; + pt.t = t; + pt.p = poly_.evaluate(t); + pt.v = poly_.evaluateVelocity(t); + pt.a = poly_.evaluateAcceleration(t); + trajectory.push_back(pt); + } + + return trajectory; +} + +double TrajectoryGenerator1D::generateMinTime( + double p0, double pf, double max_v, double max_a) { + double d = std::abs(pf - p0); + + // Minimum time for bang-bang acceleration profile + // t_acc = max_v / max_a + // d_acc = 0.5 * max_a * t_acc^2 = max_v^2 / (2 * max_a) + // If distance < 2 * d_acc, we're limited by acceleration + // If distance >= 2 * d_acc, we can reach max_v + + double t_acc = max_v / max_a; + double d_acc = 0.5 * max_a * t_acc * t_acc; // Distance for acceleration + double d_total = 2 * d_acc; // Distance for accel + decel + + if (d <= d_total) { + // Triangle profile: can't reach max_v + // d = 0.5 * max_a * t^2 for triangle + T_ = std::sqrt(4 * d / max_a); + } else { + // Trapezoidal profile: reach max_v + double d_cruise = d - d_total; + double t_cruise = d_cruise / max_v; + T_ = 2 * t_acc + t_cruise; + } + + return T_; +} + +std::vector GimbalTrajectoryGenerator::generate( + double yaw0, double pitch0, double yawf, double pitchf, + double max_yaw_rate, double max_pitch_rate, + double max_yaw_acc, double max_pitch_acc, double T) { + + std::vector trajectory; + + // Compute quintic polynomials for yaw and pitch + poly_yaw_.compute(yaw0, 0, 0, yawf, 0, 0, T); + poly_pitch_.compute(pitch0, 0, 0, pitchf, 0, 0, T); + + // Sample trajectory + const int num_samples = static_cast(T * 100); // 100 Hz + for (int i = 0; i <= num_samples; ++i) { + double t = static_cast(i) / num_samples * T; + TrajectoryPoint2D pt; + pt.t = t; + pt.yaw = poly_yaw_.evaluate(t); + pt.yaw_vel = poly_yaw_.evaluateVelocity(t); + pt.yaw_acc = poly_yaw_.evaluateAcceleration(t); + pt.pitch = poly_pitch_.evaluate(t); + pt.pitch_vel = poly_pitch_.evaluateVelocity(t); + pt.pitch_acc = poly_pitch_.evaluateAcceleration(t); + trajectory.push_back(pt); + } + + return trajectory; +} + +double GimbalTrajectoryGenerator::generateMinTime( + double yaw0, double pitch0, double yawf, double pitchf, + double max_yaw_rate, double max_pitch_rate, + double max_yaw_acc, double max_pitch_acc) { + + TrajectoryGenerator1D yaw_gen, pitch_gen; + + double T_yaw = yaw_gen.generateMinTime(yaw0, yawf, max_yaw_rate, max_yaw_acc); + double T_pitch = pitch_gen.generateMinTime(pitch0, pitchf, max_pitch_rate, max_pitch_acc); + + // Use the maximum of the two times to ensure both constraints are satisfied + T_ = std::max(T_yaw, T_pitch); + + return T_; +} + +} // namespace rm_tinympc diff --git a/src/rm_utils/include/rm_utils/math/error_state_kalman_filter.hpp b/src/rm_utils/include/rm_utils/math/error_state_kalman_filter.hpp new file mode 100644 index 0000000..c0b6297 --- /dev/null +++ b/src/rm_utils/include/rm_utils/math/error_state_kalman_filter.hpp @@ -0,0 +1,179 @@ +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ERROR_STATE_KALMAN_FILTER_HPP_ +#define ERROR_STATE_KALMAN_FILTER_HPP_ + +#include +#include + +namespace fyt { + +template +class ErrorStateKalmanFilter { +public: + using StateVector = Eigen::Vector; + using MeasurementVector = Eigen::Vector; + using StateMatrix = Eigen::Matrix; + using MeasurementMatrix = Eigen::Matrix; + using CrossCovarianceMatrix = Eigen::Matrix; + using JacobianMatrix = Eigen::Matrix; + + ErrorStateKalmanFilter() = default; + + void init(const StateVector& initial_state, + const StateMatrix& initial_covariance, + const MeasurementMatrix& measurement_noise, + const StateMatrix& process_noise) { + state_ = initial_state; + covariance_ = initial_covariance; + R_ = measurement_noise; + Q_ = process_noise; + } + + void setNoiseMatrices(const MeasurementMatrix& R, const StateMatrix& Q) { + R_ = R; + Q_ = Q; + } + + void setResidualFunc(std::function func) { + residual_func_ = func; + } + + void setInjectFunc(std::function func) { + inject_func_ = func; + } + + StateVector predict(double dt) { + PredictModel predict_model(dt); + + // State prediction using the motion model + StateVector predicted_state = state_; + predict_model(state_.data(), predicted_state.data()); + + // Jacobian of the motion model w.r.t. state + JacobianMatrix F = JacobianMatrix::Identity(); + // This is a simplified Jacobian - for complex models, numerical differentiation + // or analytical Jacobians should be used + // For now, we use identity as the motion is linear in velocity terms + + // Covariance prediction + covariance_ = F * covariance_ * F.transpose() + Q_; + + state_ = predicted_state; + return state_; + } + + StateVector update(const MeasurementVector& measurement) { + MeasurementVector z_predicted; + MeasureModel measure_model; + measure_model(state_.data(), z_predicted.data()); + + // Innovation (measurement residual) + MeasurementVector y; + if (residual_func_) { + y = residual_func_(measurement, z_predicted); + } else { + y = measurement - z_predicted; + } + + // Innovation covariance + MeasurementMatrix S = measurement_jacobian_ * covariance_ * measurement_jacobian_.transpose() + R_; + + // Kalman gain + CrossCovarianceMatrix Pxz = covariance_ * measurement_jacobian_.transpose(); + Eigen::Matrix K = Pxz * S.inverse(); + + // State update + StateVector delta_x = K * y; + state_ = state_ + delta_x; + + // Covariance update (Joseph form for numerical stability) + Eigen::Matrix I_KH = JacobianMatrix::Identity() - K * measurement_jacobian_; + covariance_ = I_KH * covariance_ * I_KH.transpose() + K * R_ * K.transpose(); + + return state_; + } + + StateVector updateIterative(const MeasurementVector& measurement, int max_iterations = 10) { + StateVector updated_state = state_; + MeasurementMatrix H = measurement_jacobian_; + + for (int i = 0; i < max_iterations; ++i) { + MeasurementVector z_predicted; + MeasureModel measure_model; + measure_model(updated_state.data(), z_predicted.data()); + + MeasurementVector y; + if (residual_func_) { + y = residual_func_(measurement, z_predicted); + } else { + y = measurement - z_predicted; + } + + MeasurementMatrix S = H * covariance_ * H.transpose() + R_; + Eigen::Matrix K = covariance_ * H.transpose() * S.inverse(); + + StateVector delta_x = K * y; + updated_state = updated_state + delta_x; + + if (delta_x.norm() < 1e-6) { + break; + } + } + + if (inject_func_) { + inject_func_(state_, updated_state - state_); + } + state_ = updated_state; + + return state_; + } + + void setState(const StateVector& state) { + state_ = state; + } + + StateVector getState() const { + return state_; + } + + StateMatrix getCovariance() const { + return covariance_; + } + + double getNIS() const { + return last_nis_; + } + + void setMeasurementJacobian(const MeasurementMatrix& H) { + measurement_jacobian_ = H; + } + +private: + StateVector state_; + StateMatrix covariance_; + MeasurementMatrix R_; // Measurement noise covariance + StateMatrix Q_; // Process noise covariance + MeasurementMatrix measurement_jacobian_; + double last_nis_ = 0; + + std::function residual_func_; + std::function inject_func_; +}; + +} // namespace fyt + +#endif // ERROR_STATE_KALMAN_FILTER_HPP_ diff --git a/src/rm_utils/include/rm_utils/math/extended_kalman_filter.hpp b/src/rm_utils/include/rm_utils/math/extended_kalman_filter.hpp index 5f2bb3f..3751762 100644 --- a/src/rm_utils/include/rm_utils/math/extended_kalman_filter.hpp +++ b/src/rm_utils/include/rm_utils/math/extended_kalman_filter.hpp @@ -48,6 +48,8 @@ public: using UpdateQFunc = std::function; using UpdateRFunc = std::function; + // Residual function for handling angle wraparound (e.g., -pi~pi) + using ResidualFunc = std::function; explicit ExtendedKalmanFilter(const PredicFunc &f, const MeasureFunc &h, @@ -66,6 +68,14 @@ public: void setMeasureFunc(const MeasureFunc &h) noexcept { this->h = h; } + // Set residual function for handling angle wraparound + void setResidualFunc(const ResidualFunc &residual_func) noexcept { + residual_func_ = residual_func; + } + + // Check if residual function is set + bool hasResidualFunc() const noexcept { return residual_func_ != nullptr; } + // Compute a predicted state MatrixX1 predict() noexcept { ceres::Jet x_e_jet[N_X]; @@ -106,14 +116,66 @@ public: } R = update_R(z); + + // Compute innovation (measurement residual) + // Use residual_func if set to handle angle wraparound + if (residual_func_ != nullptr) { + innovation_ = residual_func_(z, z_pri_); + } else { + innovation_ = z - z_pri_; + } + S_ = H * P_pri * H.transpose() + R; K = P_pri * H.transpose() * S_.inverse(); - innovation_ = z - z_pri_; + x_post = x_post + K * innovation_; P_post = (MatrixXX::Identity() - K * H) * P_pri; return x_post; } + // Update with iterative measurement update for better angle handling + MatrixX1 updateIterative(const MatrixZ1 &z, int max_iterations = 3) noexcept { + MatrixX1 updated_state = x_post; + + for (int iter = 0; iter < max_iterations; ++iter) { + ceres::Jet x_jet[N_X]; + for (int i = 0; i < N_X; i++) { + x_jet[i].a = updated_state[i]; + x_jet[i].v[i] = 1; + } + ceres::Jet z_jet[N_Z]; + h(x_jet, z_jet); + + MatrixZ1 z_pred; + for (int i = 0; i < N_Z; i++) { + z_pred[i] = z_jet[i].a; + } + + R = update_R(z); + + MatrixZ1 innovation; + if (residual_func_ != nullptr) { + innovation = residual_func_(z, z_pred); + } else { + innovation = z - z_pred; + } + + MatrixZZ S = H * P_pri * H.transpose() + R; + MatrixXZ K = P_pri * H.transpose() * S.inverse(); + + MatrixX1 delta_x = K * innovation; + updated_state = updated_state + delta_x; + + // Check convergence + if (delta_x.norm() < 1e-6) { + break; + } + } + + x_post = updated_state; + return x_post; + } + // Get innovation (z - z_pri) const MatrixZ1 & getInnovation() const { return innovation_; } @@ -126,6 +188,15 @@ public: return innovation_.transpose() * S_.inverse() * innovation_; } + // Get Kalman gain + const MatrixXZ & getKalmanGain() const { return K; } + + // Get posterior state + const MatrixX1 & getState() const { return x_post; } + + // Get posterior covariance + const MatrixXX & getCovariance() const { return P_post; } + private: // Process nonlinear vector function PredicFunc f; @@ -139,6 +210,8 @@ private: // Measurement noise covariance matrix UpdateRFunc update_R; MatrixZZ R; + // Residual function for angle wraparound handling + ResidualFunc residual_func_; // Priori error estimate covariance matrix MatrixXX P_pri;