This commit is contained in:
cyy_mac
2026-03-28 07:56:28 +08:00
parent b2507afea9
commit f2e26c3e1c
29 changed files with 3024 additions and 51 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@@ -0,0 +1,87 @@
cmake_minimum_required(VERSION 3.8)
project(armor_mpc_solver)
if(CMAKE_CXX_STANDARD GREATER_RANGE 17)
set(CMAKE_CXX_STANDARD 17)
else()
set(CMAKE_CXX_STANDARD 17)
endif()
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Find dependencies
find_package(ament_cmake REQUIRED)
find_package(rclcpp REQUIRED)
find_package(rm_interfaces REQUIRED)
find_package(rm_utils REQUIRED)
find_package(rm_tinympc REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(tf2 REQUIRED)
find_package(tf2_ros REQUIRED)
find_package(visualization_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(std_msgs REQUIRED)
include_directories(
include
${EIGEN3_INCLUDE_DIRS}
)
set(${PROJECT_NAME}_HEADERS
include/armor_mpc_solver/solver.hpp
include/armor_mpc_solver/solver_comparer.hpp
include/armor_mpc_solver/solver_comparison_node.hpp
)
add_library(${PROJECT_NAME} SHARED
src/solver.cpp
src/solver_comparer.cpp
)
add_executable(${PROJECT_NAME}_node src/solver_comparison_node.cpp)
ament_target_dependencies(${PROJECT_NAME}_node
rclcpp
rm_interfaces
rm_utils
rm_tinympc
Eigen3
tf2
tf2_ros
visualization_msgs
geometry_msgs
std_msgs
${PROJECT_NAME}
)
ament_target_dependencies(${PROJECT_NAME}
rclcpp
rm_interfaces
rm_utils
rm_tinympc
Eigen3
tf2
tf2_ros
visualization_msgs
geometry_msgs
std_msgs
)
target_include_directories(${PROJECT_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
install(TARGETS ${PROJECT_NAME} ${PROJECT_NAME}_node
ARCHIVE DESTINATION lib
LIBRARY DESTINATION lib
RUNTIME DESTINATION bin
)
install(DIRECTORY include/
DESTINATION include
)
ament_export_include_directories(include)
ament_export_libraries(${PROJECT_NAME})
ament_package()

View File

@@ -0,0 +1,311 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ARMOR_MPC_SOLVER_SOLVER_HPP_
#define ARMOR_MPC_SOLVER_SOLVER_HPP_
#include <memory>
#include <vector>
#include <Eigen/Eigen>
#include <rclcpp/rclcpp.hpp>
#include <tf2_ros/buffer.h>
#include <rm_interfaces/msg/target.hpp>
#include <rm_interfaces/msg/gimbal_cmd.hpp>
#include "rm_tinympc/types.hpp"
#include "rm_tinympc/admm.hpp"
#include "rm_tinympc/trajectory.hpp"
#include "rm_utils/math/trajectory_compensator.hpp"
#include "rm_utils/math/manual_compensator.hpp"
namespace fyt::auto_aim {
/**
* Trajectory point for visualization and debugging
*/
struct TrajPoint {
double t;
double yaw;
double pitch;
double yaw_vel;
double pitch_vel;
};
/**
* Quintic segment for polynomial trajectory
* Provides smooth position/velocity/acceleration trajectories
*/
struct QuinticSegment {
double a0, a1, a2, a3, a4, a5; // Coefficients: p(t) = a0 + a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
double position(double t) const;
double velocity(double t) const;
double acceleration(double t) const;
static QuinticSegment fromBoundaryConditions(
double p0, double v0, double a0,
double p1, double v1, double a1,
double duration);
};
/**
* LimitTrajectory - Acceleration-constrained trajectory generation
* Ensures gimbal acceleration stays within limits
*/
class LimitTrajectory {
public:
struct TrajectoryPoint {
double t;
double yaw;
double pitch;
double v_yaw;
double v_pitch;
double a_yaw;
double a_pitch;
};
LimitTrajectory(double max_yaw_acc, double max_pitch_acc, double dt);
std::vector<TrajectoryPoint> generate(
double start_yaw, double start_pitch,
double target_yaw, double target_pitch,
double max_vel_yaw, double max_vel_pitch);
private:
double max_yaw_acc_;
double max_pitch_acc_;
double dt_;
double trapezoidalTime(double dist, double max_vel, double max_acc);
QuinticSegment quinticPlan(double p0, double v0, double a0, double p1, double v1, double a1, double T);
};
/**
* Armor selection result
*/
struct ArmorSelectResult {
int armor_id; // Selected armor index (0-3)
double coming_angle; // Angle when approaching
double leaving_angle; // Angle when leaving
bool is_switching; // Whether switching armor
};
/**
* Gimbal MPC Solver using TinyMPC ADMM
*
* State: [yaw, pitch, yaw_rate, pitch_rate]
* Input: [yaw_acc, pitch_acc]
*
* Cost: sum ||state - ref||_Q + ||input||_R
* Constraints: yaw_rate, pitch_rate limits
*/
class MpcSolver {
public:
MpcSolver(std::weak_ptr<rclcpp::Node> n);
rm_interfaces::msg::GimbalCmd solve(
const rm_interfaces::msg::Target& target,
const rclcpp::Time& current_time,
std::shared_ptr<tf2_ros::Buffer> tf2_buffer);
// Trajectory getters for visualization
std::vector<TrajPoint> getMpcTrajectory() const { return mpc_trajectory_; }
std::vector<TrajPoint> getReferenceTrajectory() const { return reference_trajectory_; }
std::vector<TrajPoint> getOptimalTrajectory() const { return optimal_trajectory_; }
void setBulletSpeed(double bullet_speed) { bullet_speed_ = bullet_speed; }
// Limit trajectory for acceleration-constrained paths
std::vector<LimitTrajectory::TrajectoryPoint> getLimitTrajectory() const { return limit_trajectory_; }
private:
// State and input dimensions
static constexpr int N_X = 4; // [yaw, pitch, yaw_rate, pitch_rate]
static constexpr int N_U = 2; // [yaw_acc, pitch_acc]
// MPC parameters
double bullet_speed_;
double gravity_;
double prediction_horizon_;
double dt_;
// Gimbal rate limits (rad/s)
double max_yaw_rate_;
double max_pitch_rate_;
// Gimbal acceleration limits (rad/s^2)
double max_yaw_acc_;
double max_pitch_acc_;
// Cost weights - separated for yaw and pitch
// Q = [position_cost, velocity_cost]
Eigen::Vector2d Q_yaw_;
Eigen::Vector2d Q_pitch_;
Eigen::Vector2d R_yaw_;
Eigen::Vector2d R_pitch_;
Eigen::Vector2d Qf_yaw_;
Eigen::Vector2d Qf_pitch_;
// Legacy combined cost weights (for compatibility)
Eigen::Vector4d Q_; // State cost [yaw, pitch, yaw_rate, pitch_rate]
Eigen::Vector2d R_; // Input cost [yaw_acc, pitch_acc]
Eigen::Vector4d Qf_; // Terminal cost
// Fire decision parameters
double delay_enable_fire_error_;
double yaw_limit_deg_;
double shooting_range_h_;
double shooting_range_small_w_;
double shooting_range_big_w_;
double min_enable_pitch_deg_;
double min_enable_yaw_deg_;
double comming_angle_deg_;
double leaving_angle_deg_;
// Armor selection
double coming_angle_thresh_;
double leaving_angle_thresh_;
// ADMM MPC solver - separated for yaw and pitch
std::unique_ptr<rm_tinympc::ADMM> admm_solver_yaw_;
std::unique_ptr<rm_tinympc::ADMM> admm_solver_pitch_;
// Trajectory generators
rm_tinympc::GimbalTrajectoryGenerator trajectory_generator_;
// Trajectories for visualization
std::vector<TrajPoint> mpc_trajectory_;
std::vector<TrajPoint> reference_trajectory_;
std::vector<TrajPoint> optimal_trajectory_;
std::vector<LimitTrajectory::TrajectoryPoint> limit_trajectory_;
rclcpp::Time last_time_;
// Current gimbal state [yaw, pitch]
Eigen::Vector2d current_gimbal_state_;
// Current gimbal velocity [yaw_rate, pitch_rate]
Eigen::Vector2d current_gimbal_velocity_;
// Control delay compensation (seconds)
double control_delay_;
// Limit trajectory generator
std::unique_ptr<LimitTrajectory> limit_trajectory_generator_;
// Trajectory compensator for bullet arc compensation
std::unique_ptr<fyt::TrajectoryCompensator> trajectory_compensator_;
// Manual compensator for angle offset correction
std::unique_ptr<fyt::ManualCompensator> manual_compensator_;
/**
* Initialize ADMM solver with problem matrices
*/
void initADMM();
/**
* Initialize trajectory compensator
*/
void initTrajectoryCompensator(const std::string& compensator_type);
/**
* Compute reference trajectory (target states over horizon)
*/
Eigen::MatrixXd computeReferenceTrajectory(
const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_vel,
double target_yaw,
double v_yaw,
double flying_time);
/**
* Solve MPC and get optimal control
*/
Eigen::Vector2d solveMPC(const Eigen::Vector4d& x0, const Eigen::MatrixXd& xref);
/**
* Compute gimbal command from optimal control
*/
Eigen::Vector2d computeGimbalCommandFromControl(const Eigen::Vector2d& u);
/**
* Compute flying time with trajectory compensation
*/
double computeFlyingTime(const Eigen::Vector3d& target_pos);
/**
* Predict target position at flying_time
*/
Eigen::Vector3d predictTargetPosition(
const Eigen::Vector3d& pos,
const Eigen::Vector3d& vel,
double dt);
/**
* Convert 2D gimbal state to gimbal command message
*/
rm_interfaces::msg::GimbalCmd toGimbalCmd(
const Eigen::Vector2d& gimbal_angles,
const Eigen::Vector3d& target_pos,
const rm_interfaces::msg::Target& target);
/**
* Build MPC trajectory for visualization (separated yaw/pitch)
*/
void buildMpcVisualizationTrajectory(
const Eigen::Vector4d& x0,
const Eigen::MatrixXd& x_traj_yaw,
const Eigen::MatrixXd& x_traj_pitch,
const Eigen::MatrixXd& u_traj_yaw,
const Eigen::MatrixXd& u_traj_pitch);
/**
* Build reference trajectory for visualization
*/
void buildReferenceVisualizationTrajectory(
const Eigen::MatrixXd& xref);
/**
* Select best armor based on coming/leaving angle
*/
ArmorSelectResult selectArmor(
const rm_interfaces::msg::Target& target,
const rclcpp::Time& current_time);
/**
* Check if can fire at given time considering control delay
*/
bool canFireAtTime(
const rm_interfaces::msg::Target& target,
double time_to_target,
const rclcpp::Time& current_time);
/**
* Compute trajectory with limit constraints
*/
std::vector<LimitTrajectory::TrajectoryPoint> computeLimitTrajectory(
double target_yaw, double target_pitch);
/**
* Compute time-optimal trajectory with acceleration constraints
*/
void computeAccelConstrainedTrajectory(
const Eigen::Vector2d& start,
const Eigen::Vector2d& target,
const Eigen::Vector2d& max_vel,
const Eigen::Vector2d& max_acc);
};
} // namespace fyt::auto_aim
#endif // ARMOR_MPC_SOLVER_SOLVER_HPP_

View File

@@ -0,0 +1,70 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_
#define ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_
#include <memory>
#include <vector>
#include <Eigen/Eigen>
#include <rclcpp/rclcpp.hpp>
#include <tf2_ros/buffer.h>
#include <tf2_ros/transform_listener.h>
#include <visualization_msgs/msg/marker.hpp>
#include <visualization_msgs/msg/marker_array.hpp>
#include <rm_interfaces/msg/target.hpp>
#include <rm_interfaces/msg/gimbal_cmd.hpp>
#include <geometry_msgs/msg/point.hpp>
#include "armor_solver/armor_solver.hpp"
#include "armor_mpc_solver/solver.hpp"
namespace fyt::auto_aim {
class SolverComparer {
public:
SolverComparer(std::weak_ptr<rclcpp::Node> n);
void init();
void update(const rm_interfaces::msg::Target::SharedPtr target_msg);
void publishComparisonMarkers();
private:
std::weak_ptr<rclcpp::Node> node_;
std::unique_ptr<armor_solver::Solver> original_solver_;
std::unique_ptr<MpcSolver> mpc_solver_;
rclcpp::Publisher<visualization_msgs::msg::MarkerArray>::SharedPtr trajectory_pub_;
rclcpp::Publisher<rm_interfaces::msg::GimbalCmd>::SharedPtr mpc_gimbal_pub_;
std::shared_ptr<tf2_ros::Buffer> tf2_buffer_;
std::shared_ptr<tf2_ros::TransformListener> tf2_listener_;
std::vector<TrajPoint> last_mpc_trajectory_;
std::vector<TrajPoint> last_reference_trajectory_;
std::vector<LimitTrajectory::TrajectoryPoint> last_limit_trajectory_;
int marker_id_;
std::string frame_id_;
rm_interfaces::msg::GimbalCmd last_original_cmd_;
rm_interfaces::msg::GimbalCmd last_mpc_cmd_;
bool use_mpc_;
};
} // namespace fyt::auto_aim
#endif // ARMOR_MPC_SOLVER_SOLVER_COMPARER_HPP_

View File

@@ -0,0 +1,43 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_
#define ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_
#include <rclcpp/rclcpp.hpp>
#include <rm_interfaces/msg/target.hpp>
#include <std_msgs/msg/bool.hpp>
#include "armor_mpc_solver/solver_comparer.hpp"
namespace fyt::auto_aim {
class SolverComparisonNode : public rclcpp::Node {
public:
SolverComparisonNode();
private:
void targetCallback(const rm_interfaces::msg::Target::SharedPtr msg);
void toggleCallback(const std_msgs::msg::Bool::SharedPtr msg);
std::unique_ptr<SolverComparer> comparer_;
rclcpp::Subscription<rm_interfaces::msg::Target>::SharedPtr target_sub_;
rclcpp::Subscription<std_msgs::msg::Bool>::SharedPtr toggle_sub_;
rclcpp::TimerBase::SharedPtr timer_;
};
} // namespace fyt::auto_aim
#endif // ARMOR_MPC_SOLVER_SOLVER_COMPARISON_NODE_HPP_

View File

@@ -0,0 +1,26 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>armor_mpc_solver</name>
<version>0.1.0</version>
<description>MPC-based armor solver with TinyMPC</description>
<maintainer email="chenyouyuan@foxmail.com">Chen Youyuan</maintainer>
<license>Apache-2.0</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<depend>rclcpp</depend>
<depend>rm_interfaces</depend>
<depend>rm_utils</depend>
<depend>rm_tinympc</depend>
<depend>Eigen3</depend>
<depend>tf2</depend>
<depend>tf2_ros</depend>
<depend>visualization_msgs</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@@ -0,0 +1,667 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "armor_mpc_solver/solver.hpp"
#include "rm_utils/logger/log.hpp"
#include <angles/angles.h>
#include <cmath>
#include <string>
namespace fyt::auto_aim {
// ============== QuinticSegment Implementation ==============
double QuinticSegment::position(double t) const {
return a0 + a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
}
double QuinticSegment::velocity(double t) const {
return a1 + 2 * a2 * t + 3 * a3 * t * t + 4 * a4 * t * t * t + 5 * a5 * t * t * t * t;
}
double QuinticSegment::acceleration(double t) const {
return 2 * a2 + 6 * a3 * t + 12 * a4 * t * t + 20 * a5 * t * t * t;
}
QuinticSegment QuinticSegment::fromBoundaryConditions(
double p0, double v0, double a0,
double p1, double v1, double a1,
double T) {
QuinticSegment seg;
double T2 = T * T;
double T3 = T2 * T;
double T4 = T3 * T;
double T5 = T4 * T;
// Solve for coefficients using boundary conditions
seg.a0 = p0;
seg.a1 = v0;
seg.a2 = a0 / 2.0;
seg.a3 = (20 * (p1 - p0) - (8 * v1 + 12 * v0) * T - (3 * a0 - a1) * T2) / (2 * T3);
seg.a4 = (30 * (p0 - p1) + (14 * v1 + 16 * v0) * T + (3 * a0 - 2 * a1) * T2) / (2 * T4);
seg.a5 = (12 * (p1 - p0) - (6 * v1 + 6 * v0) * T - (a0 - a1) * T2) / (2 * T5);
return seg;
}
// ============== LimitTrajectory Implementation ==============
LimitTrajectory::LimitTrajectory(double max_yaw_acc, double max_pitch_acc, double dt)
: max_yaw_acc_(max_yaw_acc), max_pitch_acc_(max_pitch_acc), dt_(dt) {}
double LimitTrajectory::trapezoidalTime(double dist, double max_vel, double max_acc) {
// Time for trapezoidal velocity profile
// v_max = sqrt(dist * acc) if dist > v_max^2 / acc
double t_acc = max_vel / max_acc;
double dist_at_t_acc = max_acc * t_acc * t_acc;
if (dist <= dist_at_t_acc) {
// Triangle profile: accelerate then decelerate
return 2.0 * std::sqrt(dist / max_acc);
} else {
// Trapezoidal: accel + coast + decel
double t_coast = (dist - dist_at_t_acc) / max_vel;
return 2.0 * t_acc + t_coast;
}
}
QuinticSegment LimitTrajectory::quinticPlan(
double p0, double v0, double a0,
double p1, double v1, double a1,
double T) {
return QuinticSegment::fromBoundaryConditions(p0, v0, a0, p1, v1, a1, T);
}
std::vector<LimitTrajectory::TrajectoryPoint> LimitTrajectory::generate(
double start_yaw, double start_pitch,
double target_yaw, double target_pitch,
double max_vel_yaw, double max_vel_pitch) {
std::vector<TrajectoryPoint> trajectory;
// Compute angle differences
double dyaw = angles::shortest_angular_distance(start_yaw, target_yaw);
double dpitch = target_pitch - start_pitch;
// Time for each axis using trapezoidal profile
double time_yaw = trapezoidalTime(std::abs(dyaw), max_vel_yaw, max_yaw_acc_);
double time_pitch = trapezoidalTime(std::abs(dpitch), max_vel_pitch, max_pitch_acc_);
// Use longer time and synchronize
double T = std::max(time_yaw, time_pitch);
T = std::max(T, 0.1); // Minimum time
// Generate quintic trajectories
QuinticSegment yaw_traj = quinticPlan(start_yaw, 0, 0, target_yaw, 0, 0, T);
QuinticSegment pitch_traj = quinticPlan(start_pitch, 0, 0, target_pitch, 0, 0, T);
// Sample trajectory
int num_steps = static_cast<int>(T / dt_) + 1;
for (int i = 0; i <= num_steps; ++i) {
double t = i * dt_;
if (t > T) t = T;
TrajectoryPoint pt;
pt.t = t;
pt.yaw = yaw_traj.position(t);
pt.pitch = pitch_traj.position(t);
pt.v_yaw = yaw_traj.velocity(t);
pt.v_pitch = pitch_traj.velocity(t);
pt.a_yaw = yaw_traj.acceleration(t);
pt.a_pitch = pitch_traj.acceleration(t);
trajectory.push_back(pt);
}
return trajectory;
}
// ============== MpcSolver Implementation ==============
MpcSolver::MpcSolver(std::weak_ptr<rclcpp::Node> n)
: node_(n), bullet_speed_(20.0), gravity_(9.8), prediction_horizon_(2.0), dt_(0.004),
max_yaw_rate_(6.0), max_pitch_rate_(6.0), max_yaw_acc_(40.0), max_pitch_acc_(25.0),
control_delay_(0.2), delay_enable_fire_error_(0.0035),
yaw_limit_deg_(60.0), shooting_range_h_(0.12), shooting_range_small_w_(0.12),
shooting_range_big_w_(0.24), min_enable_pitch_deg_(0.25), min_enable_yaw_deg_(0.25),
comming_angle_deg_(60.0), leaving_angle_deg_(20.0) {
auto node = node_.lock();
if (node) {
// Basic parameters (matching wust_vision)
bullet_speed_ = node->declare_parameter("mpc.bullet_speed", 20.0);
gravity_ = node->declare_parameter("mpc.gravity", 9.8);
prediction_horizon_ = node->declare_parameter("mpc.sample_total_time", 2.0);
int sample_horizon = node->declare_parameter("mpc.sample_horizon", 500);
dt_ = prediction_horizon_ / sample_horizon;
max_yaw_acc_ = node->declare_parameter("mpc.max_yaw_acc", 40.0);
max_pitch_acc_ = node->declare_parameter("mpc.max_pitch_acc", 25.0);
control_delay_ = node->declare_parameter("mpc.control_delay", 0.2);
delay_enable_fire_error_ = node->declare_parameter("mpc.delay_enable_fire_error", 0.0035);
// Fire decision parameters
yaw_limit_deg_ = node->declare_parameter("mpc.yaw_limit_deg", 60.0);
shooting_range_h_ = node->declare_parameter("mpc.shooting_range_h", 0.12);
shooting_range_small_w_ = node->declare_parameter("mpc.shooting_range_small_w", 0.12);
shooting_range_big_w_ = node->declare_parameter("mpc.shooting_range_big_w", 0.24);
min_enable_pitch_deg_ = node->declare_parameter("mpc.min_enable_pitch_deg", 0.25);
min_enable_yaw_deg_ = node->declare_parameter("mpc.min_enable_yaw_deg", 0.25);
comming_angle_deg_ = node->declare_parameter("mpc.comming_angle", 60.0);
leaving_angle_deg_ = node->declare_parameter("mpc.leaving_angle", 20.0);
// Cost weights - separated for yaw and pitch (matching wust_vision)
// Default: Q_yaw = [7e6, 0], R_yaw = [3.0]
auto q_yaw_vec = node->declare_parameter("mpc.Q_yaw", std::vector<double>{7e6, 0.0});
auto r_yaw_vec = node->declare_parameter("mpc.R_yaw", std::vector<double>{3.0});
auto q_pitch_vec = node->declare_parameter("mpc.Q_pitch", std::vector<double>{7e6, 0.0});
auto r_pitch_vec = node->declare_parameter("mpc.R_pitch", std::vector<double>{3.0});
Q_yaw_ << q_yaw_vec[0], q_yaw_vec[1];
R_yaw_ << r_yaw_vec[0];
Q_pitch_ << q_pitch_vec[0], q_pitch_vec[1];
R_pitch_ << r_pitch_vec[0];
// Terminal cost is 2x stage cost
Qf_yaw_ = 2.0 * Q_yaw_;
Qf_pitch_ = 2.0 * Q_pitch_;
FYT_INFO("armor_mpc_solver", "MPC Solver initialized (wust_vision params):");
FYT_INFO("armor_mpc_solver", " prediction_horizon={:.3f}s, dt={:.4f}s, horizon={}",
prediction_horizon_, dt_, sample_horizon);
FYT_INFO("armor_mpc_solver", " max_yaw_acc={:.2f} rad/s^2, max_pitch_acc={:.2f} rad/s^2",
max_yaw_acc_, max_pitch_acc_);
FYT_INFO("armor_mpc_solver", " control_delay={:.3f}s, fire_error={:.4f}",
control_delay_, delay_enable_fire_error_);
FYT_INFO("armor_mpc_solver", " Q_yaw=[{:.1e}, {:.1e}], R_yaw=[{:.1e}]",
Q_yaw_(0), Q_yaw_(1), R_yaw_(0));
FYT_INFO("armor_mpc_solver", " Q_pitch=[{:.1e}, {:.1e}], R_pitch=[{:.1e}]",
Q_pitch_(0), Q_pitch_(1), R_pitch_(0));
// Initialize trajectory compensator
std::string compensator_type = node->declare_parameter("mpc.trajectory_compensator", "resistance");
initTrajectoryCompensator(compensator_type);
// Initialize manual compensator for angle offset
manual_compensator_ = std::make_unique<fyt::ManualCompensator>();
auto angle_offset = node->declare_parameter("mpc.trajectory_offset", std::vector<std::string>{});
if (!manual_compensator_->updateMapFlow(angle_offset)) {
FYT_DEBUG("armor_mpc_solver", "Manual compensator update skipped (empty config)");
}
}
node.reset();
initADMM();
current_gimbal_state_ << 0.0, 0.0;
current_gimbal_velocity_ << 0.0, 0.0;
// Initialize limit trajectory generator
limit_trajectory_generator_ = std::make_unique<LimitTrajectory>(
max_yaw_acc_, max_pitch_acc_, dt_);
}
void MpcSolver::initADMM() {
int horizon = static_cast<int>(prediction_horizon_ / dt_);
// Separate MPC for yaw: State [yaw, yaw_rate], Input [yaw_acc]
// Separate MPC for pitch: State [pitch, pitch_rate], Input [pitch_acc]
static constexpr int N_X_SEP = 2; // [angle, angular_rate]
static constexpr int N_U_SEP = 1; // [angular_acc]
admm_solver_yaw_ = std::make_unique<rm_tinympc::ADMM>();
admm_solver_pitch_ = std::make_unique<rm_tinympc::ADMM>();
admm_solver_yaw_->init(N_X_SEP, N_U_SEP, horizon);
admm_solver_pitch_->init(N_X_SEP, N_U_SEP, horizon);
// State transition for single axis: x = [angle, angular_rate]
// angle_{k+1} = angle_k + angular_rate_k * dt
// angular_rate_{k+1} = angular_rate_k + angular_acc_k * dt
Eigen::MatrixXd A_sep = Eigen::MatrixXd::Identity(N_X_SEP, N_X_SEP);
A_sep(0, 1) = dt_; // angle += angular_rate * dt
Eigen::MatrixXd B_sep = Eigen::MatrixXd::Zero(N_X_SEP, N_U_SEP);
B_sep(0, 0) = 0.5 * dt_ * dt_; // angle += 0.5 * angular_acc * dt^2
B_sep(1, 0) = dt_; // angular_rate += angular_acc * dt
// Set problem matrices for both solvers using separated cost weights
admm_solver_yaw_->setProblem(A_sep, B_sep, Q_yaw_, R_yaw_, Qf_yaw_);
admm_solver_pitch_->setProblem(A_sep, B_sep, Q_pitch_, R_pitch_, Qf_pitch_);
// Set constraints for yaw
Eigen::Vector2d x_min_yaw, x_max_yaw;
Eigen::VectorXd u_min_yaw(1), u_max_yaw(1);
x_min_yaw << -M_PI, -max_yaw_rate_;
x_max_yaw << M_PI, max_yaw_rate_;
u_min_yaw(0) = -max_yaw_acc_;
u_max_yaw(0) = max_yaw_acc_;
admm_solver_yaw_->setConstraints(x_min_yaw, x_max_yaw, u_min_yaw, u_max_yaw);
// Set constraints for pitch
Eigen::Vector2d x_min_pitch, x_max_pitch;
Eigen::VectorXd u_min_pitch(1), u_max_pitch(1);
x_min_pitch << -M_PI_2, -max_pitch_rate_;
x_max_pitch << M_PI_2, max_pitch_rate_;
u_min_pitch(0) = -max_pitch_acc_;
u_max_pitch(0) = max_pitch_acc_;
admm_solver_pitch_->setConstraints(x_min_pitch, x_max_pitch, u_min_pitch, u_max_pitch);
}
admm_solver_->setConstraints(x_min, x_max, u_min, u_max);
}
void MpcSolver::initTrajectoryCompensator(const std::string& compensator_type) {
trajectory_compensator_ = fyt::CompensatorFactory::createCompensator(compensator_type);
if (trajectory_compensator_) {
trajectory_compensator_->velocity = bullet_speed_;
trajectory_compensator_->gravity = gravity_;
FYT_INFO("armor_mpc_solver", "Trajectory compensator initialized: {}", compensator_type);
} else {
FYT_WARN("armor_mpc_solver", "Failed to create trajectory compensator, using default");
trajectory_compensator_ = std::make_unique<fyt::IdealCompensator>();
trajectory_compensator_->velocity = bullet_speed_;
trajectory_compensator_->gravity = gravity_;
}
}
rm_interfaces::msg::GimbalCmd MpcSolver::solve(
const rm_interfaces::msg::Target& target,
const rclcpp::Time& current_time,
std::shared_ptr<tf2_ros::Buffer> tf2_buffer) {
Eigen::Vector3d target_pos(target.position.x, target.position.y, target.position.z);
Eigen::Vector3d target_vel(target.velocity.x, target.velocity.y, target.velocity.z);
double target_yaw = target.yaw;
double v_yaw = target.v_yaw;
// Select armor based on coming/leaving angle
auto armor_result = selectArmor(target, current_time);
// Compute flying time with compensation
double flying_time = computeFlyingTime(target_pos);
// Predict target position at flying_time
Eigen::Vector3d predicted_pos = predictTargetPosition(target_pos, target_vel, flying_time);
double predicted_yaw = target_yaw + flying_time * v_yaw;
// Convert target to gimbal angles (yaw, pitch)
Eigen::Vector2d target_gimbal;
target_gimbal(0) = std::atan2(predicted_pos.y(), predicted_pos.x()); // yaw
// Calculate pitch with trajectory compensation
double raw_pitch = std::atan2(predicted_pos.z(), predicted_pos.head(2).norm());
if (trajectory_compensator_) {
trajectory_compensator_->compensate(predicted_pos, raw_pitch);
}
target_gimbal(1) = raw_pitch;
// Apply manual compensator angle offset
if (manual_compensator_) {
double dist = predicted_pos.head(2).norm();
auto offsets = manual_compensator_->angleHardCorrect(dist, predicted_pos.z());
if (offsets.size() >= 2) {
target_gimbal(0) += offsets[1] * M_PI / 180.0; // yaw offset
target_gimbal(1) += offsets[0] * M_PI / 180.0; // pitch offset
}
}
// Current state: [yaw, pitch, yaw_rate, pitch_rate]
Eigen::Vector4d x0 = Eigen::Vector4d::Zero();
x0(0) = current_gimbal_state_(0);
x0(1) = current_gimbal_state_(1);
x0(2) = current_gimbal_velocity_(0); // Use filtered velocity
x0(3) = current_gimbal_velocity_(1);
// Compute reference trajectory over horizon
Eigen::MatrixXd xref = computeReferenceTrajectory(
predicted_pos, target_vel, predicted_yaw, v_yaw, flying_time);
// Solve separated MPC for yaw and pitch
Eigen::Vector2d u_optimal = solveMPC(x0, xref);
// Build trajectories for visualization
Eigen::MatrixXd x_traj_yaw = admm_solver_yaw_->getStateTrajectory();
Eigen::MatrixXd x_traj_pitch = admm_solver_pitch_->getStateTrajectory();
Eigen::MatrixXd u_traj_yaw = admm_solver_yaw_->getControlSequence();
Eigen::MatrixXd u_traj_pitch = admm_solver_pitch_->getControlSequence();
buildMpcVisualizationTrajectory(x0, x_traj_yaw, x_traj_pitch, u_traj_yaw, u_traj_pitch);
buildReferenceVisualizationTrajectory(xref);
// Compute limit trajectory for comparison
computeAccelConstrainedTrajectory(current_gimbal_state_, target_gimbal,
Eigen::Vector2d(max_yaw_rate_, max_pitch_rate_),
Eigen::Vector2d(max_yaw_acc_, max_pitch_acc_));
// Update current gimbal velocity first (acceleration * dt = delta_velocity)
current_gimbal_velocity_(0) += u_optimal(0) * dt_;
current_gimbal_velocity_(1) += u_optimal(1) * dt_;
// Apply first control input using updated velocity
Eigen::Vector2d cmd_angles = computeGimbalCommandFromControl(u_optimal);
// Update current gimbal state
current_gimbal_state_ = cmd_angles;
// Convert to gimbal command message
auto gimbal_cmd = toGimbalCmd(cmd_angles, predicted_pos, target);
// Check if can fire considering control delay
gimbal_cmd.fire_advice = canFireAtTime(target, flying_time, current_time);
return gimbal_cmd;
}
Eigen::MatrixXd MpcSolver::computeReferenceTrajectory(
const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_vel,
double target_yaw,
double v_yaw,
double flying_time) {
int horizon = admm_solver_yaw_->getStateTrajectory().cols();
Eigen::MatrixXd xref(N_X, horizon + 1);
int num_steps = static_cast<int>(horizon);
double t = 0.0;
for (int i = 0; i <= num_steps; ++i) {
double dt_i = i * dt_;
// Predict target position at dt_i
Eigen::Vector3d pred_pos = predictTargetPosition(target_pos, target_vel, dt_i);
double pred_yaw = target_yaw + dt_i * v_yaw;
// Convert to gimbal angles
xref(0, i) = std::atan2(pred_pos.y(), pred_pos.x()); // yaw
xref(1, i) = std::atan2(pred_pos.z(), pred_pos.head(2).norm()); // pitch
xref(2, i) = 0.0; // yaw_rate (reference is stationary in gimbal frame)
xref(3, i) = 0.0; // pitch_rate
}
return xref;
}
Eigen::Vector2d MpcSolver::solveMPC(const Eigen::Vector4d& x0, const Eigen::MatrixXd& xref) {
// Solve separated MPC for yaw and pitch
// x0 = [yaw, pitch, yaw_rate, pitch_rate]
// xref columns: [yaw_ref, pitch_ref, yaw_rate_ref=0, pitch_rate_ref=0]
static constexpr int N_X_SEP = 2;
// Initial state for yaw MPC: [yaw, yaw_rate]
Eigen::VectorXd x0_yaw(N_X_SEP);
x0_yaw(0) = x0(0);
x0_yaw(1) = x0(2);
// Initial state for pitch MPC: [pitch, pitch_rate]
Eigen::VectorXd x0_pitch(N_X_SEP);
x0_pitch(0) = x0(1);
x0_pitch(1) = x0(3);
// Get horizon from solver
int horizon = admm_solver_yaw_->getStateTrajectory().cols();
// Reference trajectories for yaw and pitch
Eigen::MatrixXd xref_yaw(N_X_SEP, horizon + 1);
Eigen::MatrixXd xref_pitch(N_X_SEP, horizon + 1);
for (int i = 0; i <= horizon; ++i) {
xref_yaw(0, i) = xref(0, i); // yaw
xref_yaw(1, i) = xref(2, i); // yaw_rate
xref_pitch(0, i) = xref(1, i); // pitch
xref_pitch(1, i) = xref(3, i); // pitch_rate
}
// Solve yaw MPC
admm_solver_yaw_->setInitialState(x0_yaw);
admm_solver_yaw_->setReference(xref_yaw);
bool yaw_converged = admm_solver_yaw_->solve();
if (!yaw_converged) {
FYT_WARN("armor_mpc_solver", "Yaw MPC solver did not converge!");
}
// Solve pitch MPC
admm_solver_pitch_->setInitialState(x0_pitch);
admm_solver_pitch_->setReference(xref_pitch);
bool pitch_converged = admm_solver_pitch_->solve();
if (!pitch_converged) {
FYT_WARN("armor_mpc_solver", "Pitch MPC solver did not converge!");
}
// Get first optimal control [yaw_acc, pitch_acc]
Eigen::VectorXd u_yaw = admm_solver_yaw_->getFirstInput();
Eigen::VectorXd u_pitch = admm_solver_pitch_->getFirstInput();
Eigen::Vector2d u_optimal;
u_optimal(0) = u_yaw(0);
u_optimal(1) = u_pitch(0);
return u_optimal;
}
Eigen::Vector2d MpcSolver::computeGimbalCommandFromControl(const Eigen::Vector2d& u) {
// Control input u = [yaw_acc, pitch_acc]
// Semi-implicit Euler integration (symplectic):
// velocity_{k+1} = velocity_k + accel * dt
// angle_{k+1} = angle_k + velocity_{k+1} * dt
Eigen::Vector2d new_angles = current_gimbal_state_;
// First compute new velocity (after acceleration)
Eigen::Vector2d new_velocity = current_gimbal_velocity_ + u * dt_;
// Then compute angle using new velocity
new_angles(0) += new_velocity(0) * dt_;
new_angles(1) += new_velocity(1) * dt_;
// Clamp to gimbal limits
new_angles(0) = std::max(-M_PI, std::min(M_PI, new_angles(0)));
new_angles(1) = std::max(-M_PI_2, std::min(M_PI_2, new_angles(1)));
return new_angles;
}
double MpcSolver::computeFlyingTime(const Eigen::Vector3d& target_pos) {
if (trajectory_compensator_) {
return trajectory_compensator_->getFlyingTime(target_pos);
}
// Fallback: simple calculation
double dist = target_pos.norm();
return dist / bullet_speed_;
}
Eigen::Vector3d MpcSolver::predictTargetPosition(
const Eigen::Vector3d& pos,
const Eigen::Vector3d& vel,
double dt) {
return pos + dt * vel;
}
rm_interfaces::msg::GimbalCmd MpcSolver::toGimbalCmd(
const Eigen::Vector2d& gimbal_angles,
const Eigen::Vector3d& target_pos,
const rm_interfaces::msg::Target& target) {
rm_interfaces::msg::GimbalCmd gimbal_cmd;
gimbal_cmd.header = target.header;
gimbal_cmd.header.stamp = rclcpp::Clock().now();
gimbal_cmd.yaw = gimbal_angles(0) * 180.0 / M_PI;
gimbal_cmd.pitch = gimbal_angles(1) * 180.0 / M_PI;
gimbal_cmd.distance = target_pos.norm();
// Compute yaw_diff and pitch_diff (simplified - assumes gimbal feedback)
// In practice, these would come from actual gimbal feedback
gimbal_cmd.yaw_diff = 0.0;
gimbal_cmd.pitch_diff = 0.0;
// Simplified fire_advice: fire if target is being tracked
gimbal_cmd.fire_advice = true;
// shoot_rate will be set by the node that uses this solver
return gimbal_cmd;
}
void MpcSolver::buildMpcVisualizationTrajectory(
const Eigen::Vector4d& x0,
const Eigen::MatrixXd& x_traj_yaw,
const Eigen::MatrixXd& x_traj_pitch,
const Eigen::MatrixXd& u_traj_yaw,
const Eigen::MatrixXd& u_traj_pitch) {
mpc_trajectory_.clear();
int horizon = x_traj_yaw.cols() - 1;
for (int i = 0; i <= horizon; ++i) {
TrajPoint pt;
pt.t = i * dt_;
pt.yaw = x_traj_yaw(0, i) * 180.0 / M_PI; // yaw in degrees
pt.pitch = x_traj_pitch(0, i) * 180.0 / M_PI; // pitch in degrees
pt.yaw_vel = x_traj_yaw(1, i) * 180.0 / M_PI; // yaw_rate in deg/s
pt.pitch_vel = x_traj_pitch(1, i) * 180.0 / M_PI; // pitch_rate in deg/s
mpc_trajectory_.push_back(pt);
}
}
void MpcSolver::buildReferenceVisualizationTrajectory(
const Eigen::MatrixXd& xref) {
reference_trajectory_.clear();
int horizon = xref.cols() - 1;
for (int i = 0; i <= horizon; ++i) {
TrajPoint pt;
pt.t = i * dt_;
pt.yaw = xref(0, i) * 180.0 / M_PI; // yaw in degrees
pt.pitch = xref(1, i) * 180.0 / M_PI; // pitch in degrees
pt.yaw_vel = xref(2, i) * 180.0 / M_PI; // yaw_rate in deg/s
pt.pitch_vel = xref(3, i) * 180.0 / M_PI; // pitch_rate in deg/s
reference_trajectory_.push_back(pt);
}
}
ArmorSelectResult MpcSolver::selectArmor(
const rm_interfaces::msg::Target& target,
const rclcpp::Time& current_time) {
ArmorSelectResult result;
result.armor_id = 0;
result.coming_angle = 0.0;
result.leaving_angle = 0.0;
result.is_switching = false;
// Target provides position and velocity in Cartesian coordinates
Eigen::Vector3d target_pos(target.position.x, target.position.y, target.position.z);
Eigen::Vector3d target_vel(target.velocity.x, target.velocity.y, target.velocity.z);
// Compute distance and approach angle
double dist = target_pos.norm();
double approach_angle = std::atan2(target_pos.y(), target_pos.x());
// Estimate angular velocity from target velocity
// Angular velocity = (r × v) / |r|^2 where r is position vector
double omega = (target_pos.x() * target_vel.y() - target_pos.y() * target_vel.x()) / (dist * dist + 1e-6);
// Flying time
double flying_time = computeFlyingTime(target_pos);
// Coming angle: angle at which target is approaching
// Leaving angle: angle at which target will be when bullet arrives
double angle_at_flying_time = approach_angle + omega * flying_time;
double coming_angle = angles::shortest_angular_distance(approach_angle, angle_at_flying_time);
double leaving_angle = angles::shortest_angular_distance(angle_at_flying_time, approach_angle);
result.coming_angle = coming_angle;
result.leaving_angle = leaving_angle;
// Check if target is moving toward or away from robot
// Using configurable thresholds from wust_vision
double coming_thresh = comming_angle_deg_ * M_PI / 180.0;
double leaving_thresh = leaving_angle_deg_ * M_PI / 180.0;
// is_switching: target is moving away fast or angle exceeds threshold
result.is_switching = (std::abs(coming_angle) > coming_thresh) ||
(target_vel.norm() > 0.5 && std::abs(leaving_angle) < leaving_thresh);
FYT_DEBUG("armor_mpc_solver",
"selectArmor: dist={:.2f}, coming={:.2f}deg, leaving={:.2f}deg, switching={}",
dist, coming_angle * 180 / M_PI, leaving_angle * 180 / M_PI, result.is_switching);
return result;
}
bool MpcSolver::canFireAtTime(
const rm_interfaces::msg::Target& target,
double time_to_target,
const rclcpp::Time& current_time) {
// Get armor selection result
auto armor_result = selectArmor(target, current_time);
// Account for control delay - add extra time buffer
double total_delay = control_delay_ + 0.05; // 50ms system latency buffer
// Time when bullet will arrive
double fire_time = time_to_target;
// Check if gimbal can reach target within fire_time considering acceleration limits
double dyaw = angles::shortest_angular_distance(current_gimbal_state_(0),
std::atan2(target.position.y, target.position.x));
double dpitch = std::atan2(target.position.z, std::sqrt(target.position.x * target.position.x +
target.position.y * target.position.y)) - current_gimbal_state_(1);
// Minimum time needed to reach target given acceleration constraints
double min_time_yaw = 2.0 * std::sqrt(std::abs(dyaw) / max_yaw_acc_);
double min_time_pitch = 2.0 * std::sqrt(std::abs(dpitch) / max_pitch_acc_);
double min_time_needed = std::max(min_time_yaw, min_time_pitch);
// Can fire if we have enough time and target is not switching rapidly
bool can_fire = (fire_time >= min_time_needed + total_delay) && !armor_result.is_switching;
FYT_DEBUG("armor_mpc_solver",
"canFireAtTime: time_to_target={:.3f}, min_time={:.3f}, can_fire={}",
fire_time, min_time_needed, can_fire);
return can_fire;
}
std::vector<LimitTrajectory::TrajectoryPoint> MpcSolver::computeLimitTrajectory(
double target_yaw, double target_pitch) {
if (!limit_trajectory_generator_) {
return {};
}
return limit_trajectory_generator_->generate(
current_gimbal_state_(0), current_gimbal_state_(1),
target_yaw, target_pitch,
max_yaw_rate_, max_pitch_rate_);
}
void MpcSolver::computeAccelConstrainedTrajectory(
const Eigen::Vector2d& start,
const Eigen::Vector2d& target,
const Eigen::Vector2d& max_vel,
const Eigen::Vector2d& max_acc) {
// Generate limit trajectory for visualization
limit_trajectory_ = computeLimitTrajectory(target(0), target(1));
}
} // namespace fyt::auto_aim

View File

@@ -0,0 +1,154 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "armor_mpc_solver/solver_comparer.hpp"
#include "rm_utils/logger/log.hpp"
#include <visualization_msgs/msg/color.hpp>
#include <cmath>
namespace fyt::auto_aim {
SolverComparer::SolverComparer(std::weak_ptr<rclcpp::Node> n)
: node_(n), marker_id_(0), frame_id_("odom"), use_mpc_(false) {
auto node = node_.lock();
if (node) {
use_mpc_ = node->declare_parameter("comparer.use_mpc", false);
}
node.reset();
}
void SolverComparer::init() {
auto node = node_.lock();
if (!node) return;
original_solver_ = std::make_unique<armor_solver::Solver>(node_);
mpc_solver_ = std::make_unique<MpcSolver>(node_);
tf2_buffer_ = std::make_shared<tf2_ros::Buffer>(node->get_clock());
tf2_listener_ = std::make_shared<tf2_ros::TransformListener>(*tf2_buffer_);
trajectory_pub_ = node->create_publisher<visualization_msgs::msg::MarkerArray>(
"/armor_solver/comparison_trajectory", 10);
mpc_gimbal_pub_ = node->create_publisher<rm_interfaces::msg::GimbalCmd>(
"/armor_solver/mpc_gimbal_cmd", 10);
node.reset();
}
void SolverComparer::update(const rm_interfaces::msg::Target::SharedPtr target_msg) {
if (!target_msg) return;
try {
auto current_time = rclcpp::Time(target_msg->header.stamp);
if (original_solver_) {
last_original_cmd_ = original_solver_->solve(*target_msg, current_time, tf2_buffer_);
}
if (mpc_solver_) {
last_mpc_cmd_ = mpc_solver_->solve(*target_msg, current_time, tf2_buffer_);
mpc_gimbal_pub_->publish(last_mpc_cmd_);
last_mpc_trajectory_ = mpc_solver_->getMpcTrajectory();
last_reference_trajectory_ = mpc_solver_->getReferenceTrajectory();
last_limit_trajectory_ = mpc_solver_->getLimitTrajectory();
}
publishComparisonMarkers();
} catch (const std::exception& e) {
FYT_ERROR("armor_mpc_solver", "Solver comparison error: {}", e.what());
}
}
void SolverComparer::publishComparisonMarkers() {
visualization_msgs::msg::MarkerArray marker_array;
if (!last_mpc_trajectory_.empty()) {
visualization_msgs::msg::Marker mpc_line;
mpc_line.header.frame_id = frame_id_;
mpc_line.header.stamp = rclcpp::Clock().now();
mpc_line.ns = "mpc_trajectory";
mpc_line.type = visualization_msgs::msg::Marker::LINE_STRIP;
mpc_line.action = visualization_msgs::msg::Marker::ADD;
mpc_line.scale.x = 0.02;
mpc_line.color.r = 1.0;
mpc_line.color.g = 0.0;
mpc_line.color.b = 0.0;
mpc_line.color.a = 1.0;
for (const auto& pt : last_mpc_trajectory_) {
geometry_msgs::msg::Point p;
p.x = pt.t;
p.y = pt.yaw * 180.0 / M_PI;
p.z = pt.pitch * 180.0 / M_PI;
mpc_line.points.push_back(p);
}
mpc_line.id = marker_id_++;
marker_array.markers.push_back(mpc_line);
}
if (!last_reference_trajectory_.empty()) {
visualization_msgs::msg::Marker ref_line;
ref_line.header.frame_id = frame_id_;
ref_line.header.stamp = rclcpp::Clock().now();
ref_line.ns = "reference_trajectory";
ref_line.type = visualization_msgs::msg::Marker::LINE_STRIP;
ref_line.action = visualization_msgs::msg::Marker::ADD;
ref_line.scale.x = 0.02;
ref_line.color.r = 0.0;
ref_line.color.g = 1.0;
ref_line.color.b = 0.0;
ref_line.color.a = 1.0;
for (const auto& pt : last_reference_trajectory_) {
geometry_msgs::msg::Point p;
p.x = pt.t;
p.y = pt.yaw * 180.0 / M_PI;
p.z = pt.pitch * 180.0 / M_PI;
ref_line.points.push_back(p);
}
ref_line.id = marker_id_++;
marker_array.markers.push_back(ref_line);
}
// Publish limit trajectory (acceleration-constrained) in blue
if (!last_limit_trajectory_.empty()) {
visualization_msgs::msg::Marker limit_line;
limit_line.header.frame_id = frame_id_;
limit_line.header.stamp = rclcpp::Clock().now();
limit_line.ns = "limit_trajectory";
limit_line.type = visualization_msgs::msg::Marker::LINE_STRIP;
limit_line.action = visualization_msgs::msg::Marker::ADD;
limit_line.scale.x = 0.02;
limit_line.color.r = 0.0;
limit_line.color.g = 0.0;
limit_line.color.b = 1.0;
limit_line.color.a = 1.0;
for (const auto& pt : last_limit_trajectory_) {
geometry_msgs::msg::Point p;
p.x = pt.t;
p.y = pt.yaw * 180.0 / M_PI;
p.z = pt.pitch * 180.0 / M_PI;
limit_line.points.push_back(p);
}
limit_line.id = marker_id_++;
marker_array.markers.push_back(limit_line);
}
trajectory_pub_->publish(marker_array);
}
} // namespace fyt::auto_aim

View File

@@ -0,0 +1,50 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "armor_mpc_solver/solver_comparison_node.hpp"
namespace fyt::auto_aim {
SolverComparisonNode::SolverComparisonNode()
: rclcpp::Node("solver_comparison_node") {
comparer_ = std::make_unique<SolverComparer>(shared_from_this());
comparer_->init();
target_sub_ = this->create_subscription<rm_interfaces::msg::Target>(
"/armor_solver/target",
10,
std::bind(&SolverComparisonNode::targetCallback, this, std::placeholders::_1));
toggle_sub_ = this->create_subscription<std_msgs::msg::Bool>(
"/armor_solver/toggle_mpc",
10,
std::bind(&SolverComparisonNode::toggleCallback, this, std::placeholders::_1));
}
void SolverComparisonNode::targetCallback(const rm_interfaces::msg::Target::SharedPtr msg) {
comparer_->update(msg);
}
void SolverComparisonNode::toggleCallback(const std_msgs::msg::Bool::SharedPtr msg) {
// Toggle between original solver and MPC solver
}
} // namespace fyt::auto_aim
int main(int argc, char** argv) {
rclcpp::init(argc, argv);
rclcpp::spin(std::make_shared<fyt::auto_aim::SolverComparisonNode>());
rclcpp::shutdown();
return 0;
}

View File

@@ -84,9 +84,11 @@ private:
// Adaptive Q matrix parameters
double s2qxyz_max_, s2qxyz_min_; // Position noise range
double s2qyaw_max_, s2qyaw_min_; // Yaw noise range
double s2qr_, s2qd_zc_; // Radius and height offset noise
// R matrix parameters
double r_x_, r_y_, r_z_, r_yaw_;
double s2qr_, s2qd_zc_, s2qd_za_; // Radius and height offset noise
// R matrix parameters (Spherical coordinates: yaw, pitch, dist, ori_yaw)
double r_yaw_, r_pitch_, r_dist_, r_ori_yaw_;
// Adaptive R scaling for visibility (front vs side armor)
double r_front_scale_, r_side_scale_;
double lost_time_thres_;
std::unique_ptr<Tracker> tracker_;

View File

@@ -70,7 +70,8 @@ public:
Armor tracked_armor;
std::string tracked_id;
ArmorsNum tracked_armors_num;
Eigen::VectorXd measurement;
Eigen::VectorXd measurement; // Cartesian [x, y, z, yaw] for debug publishing
Eigen::VectorXd spherical_measurement_; // Spherical [yaw, pitch, dist, ori_yaw] for EKF
Eigen::VectorXd target_state;
// To store another pair of armors message
@@ -79,6 +80,9 @@ public:
// To store offset relative to the reference plane
double d_zc;
// Last armor type for adaptive noise
std::string last_armor_type_;
private:
void initEKF(const Armor &a) noexcept;
@@ -88,6 +92,12 @@ private:
static Eigen::Vector3d getArmorPositionFromState(const Eigen::VectorXd &x) noexcept;
// Convert Cartesian measurement to spherical
static Eigen::Vector4d cartesianToSpherical(double x, double y, double z, double yaw);
// Adaptive noise based on armor visibility
void updateAdaptiveNoise(const Armor &armor) noexcept;
double max_match_distance_;
double max_match_yaw_diff_;

View File

@@ -28,8 +28,10 @@ enum class MotionModel {
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
};
// X_N: state dimension, Z_N: measurement dimension
constexpr int X_N = 10, Z_N = 4;
// X_N: state dimension (11), Z_N: measurement dimension (4)
// State: [xc, vxc, yc, vyc, zc, vzc, yaw, vyaw, r, d_zc, d_za]
// Measurement: [yaw, pitch, distance, ori_yaw] (spherical coordinates)
constexpr int X_N = 11, Z_N = 4;
struct Predict {
explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT)
@@ -44,9 +46,9 @@ struct Predict {
// v_xyz
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
// linear velocity
x1[0] += x0[1] * dt;
x1[2] += x0[3] * dt;
x1[4] += x0[5] * dt;
x1[0] += x0[1] * dt; // xc += vxc * dt
x1[2] += x0[3] * dt; // yc += vyc * dt
x1[4] += x0[5] * dt; // zc += vzc * dt
} else {
// no velocity
x1[1] *= 0.;
@@ -57,7 +59,7 @@ struct Predict {
// v_yaw
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
// angular velocity
x1[6] += x0[7] * dt;
x1[6] += x0[7] * dt; // yaw += vyaw * dt
} else {
// no rotation
x1[7] *= 0.;
@@ -68,12 +70,31 @@ struct Predict {
MotionModel model;
};
// Spherical measurement model
// z = [yaw, pitch, distance, ori_yaw]
struct Measure {
template <typename T>
void operator()(const T x[Z_N], T z[Z_N]) {
// x[0]: xc, x[2]: yc, x[4]: zc, x[6]: yaw, x[8]: r, x[9]: d_zc, x[10]: d_za
T xa = x[0] - ceres::cos(x[6]) * x[8]; // armor x
T ya = x[2] - ceres::sin(x[6]) * x[8]; // armor y
T za = x[4] + x[9] + x[10]; // armor z
// Convert Cartesian to spherical
z[0] = ceres::atan2(ya, xa); // yaw (azimuth angle)
z[1] = ceres::atan2(za, ceres::sqrt(xa * xa + ya * ya)); // pitch (elevation angle)
z[2] = ceres::sqrt(xa * xa + ya * ya + za * za); // distance
z[3] = x[6]; // ori_yaw (same as yaw for armor)
}
};
// Cartesian measurement model for comparison
struct MeasureCartesian {
template <typename T>
void operator()(const T x[Z_N], T z[Z_N]) {
z[0] = x[0] - ceres::cos(x[6]) * x[8];
z[1] = x[2] - ceres::sin(x[6]) * x[8];
z[2] = x[4] + x[9];
z[2] = x[4] + x[9] + x[10];
z[3] = x[6];
}
};
@@ -81,4 +102,4 @@ struct Measure {
using RobotStateEKF = ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
} // namespace fyt::auto_aim
#endif
#endif // ARMOR_SOLVER_MOTION_MODEL_HPP_

View File

@@ -21,6 +21,8 @@
// std
#include <memory>
#include <vector>
// third party
#include <angles/angles.h>
// project
#include "armor_solver/motion_model.hpp"
#include "rm_utils/common.hpp"
@@ -70,6 +72,7 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options)
s2qyaw_min_ = declare_parameter("ekf.sigma2_q_yaw_min", 50.0);
s2qr_ = declare_parameter("ekf.sigma2_q_r", 800.0);
s2qd_zc_ = declare_parameter("ekf.sigma2_q_d_zc", 800.0);
s2qd_za_ = declare_parameter("ekf.sigma2_q_d_za", 800.0);
nis_window_size_ = declare_parameter("ekf.nis_window_size", 20);
nis_adapt_range_ = declare_parameter("ekf.nis_adapt_range", 2.0);
@@ -99,7 +102,7 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options)
s2q_xyz *= q_scale;
s2q_yaw *= q_scale;
double r = s2qr_ * q_scale, d_zc = s2qd_zc_ * q_scale;
double r = s2qr_ * q_scale, d_zc = s2qd_zc_ * q_scale, d_za = s2qd_za_ * q_scale;
// White noise integral model for position-velocity state
double q_x_x = pow(t, 4) / 4 * s2q_xyz, q_x_vx = pow(t, 3) / 2 * s2q_xyz, q_vx_vx = pow(t, 2) * s2q_xyz;
@@ -109,39 +112,65 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options)
q_vyaw_vyaw = pow(t, 2) * s2q_yaw;
double q_r = pow(t, 4) / 4 * r;
double q_d_zc = pow(t, 4) / 4 * d_zc;
double q_d_za = pow(t, 4) / 4 * d_za;
// clang-format off
// xc v_xc yc v_yc zc v_zc yaw v_yaw r d_zc
q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0,
q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0,
0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0,
0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0,
0, 0, 0, 0, 0, 0, 0, 0, q_r, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_zc;
// xc v_xc yc v_yc zc v_zc yaw v_yaw r d_zc d_za
q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0,
q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_zc, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_d_za;
// clang-format on
return q;
};
// update_R - measurement noise covariance matrix
// R scales with distance: farther target -> larger measurement noise
r_x_ = declare_parameter("ekf.r_x", 0.05);
r_y_ = declare_parameter("ekf.r_y", 0.05);
r_z_ = declare_parameter("ekf.r_z", 0.05);
// update_R - measurement noise covariance matrix (Spherical coordinates)
// z = [yaw, pitch, distance, ori_yaw]
// R scales with distance: farther target -> larger angular noise
r_yaw_ = declare_parameter("ekf.r_yaw", 0.02);
r_pitch_ = declare_parameter("ekf.r_pitch", 0.02);
r_dist_ = declare_parameter("ekf.r_dist", 0.05);
r_ori_yaw_ = declare_parameter("ekf.r_ori_yaw", 0.02);
// Adaptive R scaling factor based on armor visibility
// Front armor: smaller noise (more trust)
// Side armor: larger noise (less trust)
r_front_scale_ = declare_parameter("ekf.r_front_scale", 0.5);
r_side_scale_ = declare_parameter("ekf.r_side_scale", 2.0);
auto u_r = [this](const Eigen::Matrix<double, Z_N, 1> &z) {
Eigen::Matrix<double, Z_N, Z_N> r;
// Calculate distance for better noise scaling
double dist = std::sqrt(z[0] * z[0] + z[1] * z[1] + z[2] * z[2]);
// z[0] = yaw, z[1] = pitch, z[2] = distance, z[3] = ori_yaw
double dist = z[2]; // distance is the 3rd element in spherical measurement
// Minimum distance to prevent numerical issues when target is very close
dist = std::max(dist, 1.0);
// Angular noise scales with distance (smaller at close range)
double r_yaw_scaled = r_yaw_ * dist * dist; // rad^2 * m^2
double r_pitch_scaled = r_pitch_ * dist * dist;
double r_dist_scaled = r_dist_ * dist; // m^2
double r_ori_yaw_scaled = r_ori_yaw_ * dist * dist;
// Apply visibility-based scaling if armor visibility info is available
// This is updated in tracker_->updateAdaptiveNoise()
double visibility_scale = 1.0;
if (tracker_->last_armor_type_ == "front") {
visibility_scale = r_front_scale_;
} else if (tracker_->last_armor_type_ == "side") {
visibility_scale = r_side_scale_;
}
// clang-format off
r << r_x_ * dist, 0, 0, 0,
0, r_y_ * dist, 0, 0,
0, 0, r_z_ * dist, 0,
0, 0, 0, r_yaw_;
r << r_yaw_scaled * visibility_scale, 0, 0, 0,
0, r_pitch_scaled * visibility_scale, 0, 0,
0, 0, r_dist_scaled * visibility_scale, 0,
0, 0, 0, r_ori_yaw_scaled * visibility_scale;
// clang-format on
return r;
};
@@ -150,6 +179,20 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options)
p0.setIdentity();
tracker_->ekf = std::make_unique<RobotStateEKF>(f, h, u_q, u_r, p0);
// Set residual function for handling angle wraparound (yaw, pitch)
// The measurement z = [yaw, pitch, distance, ori_yaw]
// Use angles::shortest_angular_distance to handle -pi~pi discontinuity
auto residual_func = [](const Eigen::Matrix<double, Z_N, 1> &z_meas,
const Eigen::Matrix<double, Z_N, 1> &z_pred) {
Eigen::Matrix<double, Z_N, 1> residual;
residual(0) = angles::shortest_angular_distance(z_pred(0), z_meas(0)); // yaw
residual(1) = angles::shortest_angular_distance(z_pred(1), z_meas(1)); // pitch
residual(2) = z_meas(2) - z_pred(2); // distance (no wraparound)
residual(3) = angles::shortest_angular_distance(z_pred(3), z_meas(3)); // ori_yaw
return residual;
};
tracker_->ekf->setResidualFunc(residual_func);
// Subscriber with tf2 message_filter
// tf2 relevant
tf2_buffer_ = std::make_shared<tf2_ros::Buffer>(this->get_clock());

View File

@@ -19,6 +19,7 @@
#include "armor_solver/armor_tracker.hpp"
// std
#include <cfloat>
#include <cmath>
#include <memory>
#include <string>
// ros2
@@ -33,16 +34,19 @@
#include "rm_utils/logger/log.hpp"
namespace fyt::auto_aim {
Tracker::Tracker(double max_match_distance, double max_match_yaw_diff)
: tracker_state(LOST)
, tracked_id(std::string(""))
, measurement(Eigen::VectorXd::Zero(4))
, target_state(Eigen::VectorXd::Zero(9))
, spherical_measurement_(Eigen::VectorXd::Zero(4))
, target_state(Eigen::VectorXd::Zero(X_N))
, max_match_distance_(max_match_distance)
, max_match_yaw_diff_(max_match_yaw_diff)
, detect_count_(0)
, lost_count_(0)
, last_yaw_(0) {}
, last_yaw_(0)
, last_armor_type_("") {}
void Tracker::setMatchThreshold(double max_match_distance, double max_match_yaw_diff) noexcept {
max_match_distance_ = max_match_distance;
@@ -69,6 +73,7 @@ void Tracker::init(const Armors::SharedPtr &armors_msg) noexcept {
tracked_id = tracked_armor.number;
tracker_state = DETECTING;
last_armor_type_ = tracked_armor.type;
if (tracked_armor.type == "large" &&
(tracked_id == "3" || tracked_id == "4" || tracked_id == "5")) {
@@ -111,6 +116,7 @@ void Tracker::update(const Armors::SharedPtr &armors_msg) noexcept {
yaw_diff = abs(orientationToYaw(armor.pose.orientation) - ekf_prediction(6));
tracked_armor = armor;
// Update tracked armor type
last_armor_type_ = tracked_armor.type;
if (tracked_armor.type == "large" &&
(tracked_id == "3" || tracked_id == "4" || tracked_id == "5")) {
tracked_armors_num = ArmorsNum::BALANCE_2;
@@ -129,10 +135,17 @@ void Tracker::update(const Armors::SharedPtr &armors_msg) noexcept {
// Matched armor found
matched = true;
auto p = tracked_armor.pose.position;
// Update EKF
// Update adaptive noise based on armor visibility
updateAdaptiveNoise(tracked_armor);
// Store Cartesian measurement for debug publishing
double measured_yaw = orientationToYaw(tracked_armor.pose.orientation);
measurement = Eigen::Vector4d(p.x, p.y, p.z, measured_yaw);
target_state = ekf->update(measurement);
// Convert to spherical for EKF update
spherical_measurement_ = cartesianToSpherical(p.x, p.y, p.z, measured_yaw);
target_state = ekf->update(spherical_measurement_);
} else if (same_id_armors_count == 1 && yaw_diff > max_match_yaw_diff_) {
// Matched armor not found, but there is only one armor with the same id
// and yaw has jumped, take this case as the target is spinning and armor
@@ -203,7 +216,7 @@ void Tracker::initEKF(const Armor &a) noexcept {
double yc = ya + r * sin(yaw);
double zc = za;
d_za = 0, d_zc = 0, another_r = r;
target_state << xc, 0, yc, 0, zc, 0, yaw, 0, r, d_zc;
target_state << xc, 0, yc, 0, zc, 0, yaw, 0, r, d_zc, d_za;
ekf->setState(target_state);
}
@@ -217,10 +230,11 @@ void Tracker::handleArmorJump(const Armor &current_armor) noexcept {
target_state(6) = yaw;
// Only 4 armors has 2 radius and height
if (tracked_armors_num == ArmorsNum::NORMAL_4) {
d_za = target_state(4) + target_state(9) - current_armor.pose.position.z;
d_za = target_state(4) + target_state(9) + target_state(10) - current_armor.pose.position.z;
std::swap(target_state(8), another_r);
d_zc = d_zc == 0 ? -d_za : 0;
target_state(9) = d_zc;
target_state(10) = d_za;
}
FYT_DEBUG("armor_solver", "Armor Jump!");
}
@@ -234,6 +248,7 @@ void Tracker::handleArmorJump(const Armor &current_armor) noexcept {
// large, the state is wrong, reset center position and velocity in the
// state
d_zc = 0;
d_za = 0;
double r = target_state(8);
target_state(0) = p.x + r * cos(yaw); // xc
target_state(1) = 0; // vxc
@@ -242,6 +257,7 @@ void Tracker::handleArmorJump(const Armor &current_armor) noexcept {
target_state(4) = p.z; // zc
target_state(5) = 0; // vzc
target_state(9) = d_zc; // d_zc
target_state(10) = d_za; // d_za
FYT_WARN("armor_solver", "State wrong!");
}
@@ -262,11 +278,54 @@ double Tracker::orientationToYaw(const geometry_msgs::msg::Quaternion &q) noexce
Eigen::Vector3d Tracker::getArmorPositionFromState(const Eigen::VectorXd &x) noexcept {
// Calculate predicted position of the current armor
double xc = x(0), yc = x(2), za = x(4) + x(9);
// x[0]: xc, x[2]: yc, x[4]: zc, x[6]: yaw, x[8]: r, x[9]: d_zc, x[10]: d_za
double xc = x(0), yc = x(2), za = x(4) + x(9) + x(10);
double yaw = x(6), r = x(8);
double xa = xc - r * cos(yaw);
double ya = yc - r * sin(yaw);
return Eigen::Vector3d(xa, ya, za);
}
} // namespace fyt::auto_aim
Eigen::Vector4d Tracker::cartesianToSpherical(double x, double y, double z, double yaw) noexcept {
Eigen::Vector4d spherical;
double dist_xy = std::sqrt(x * x + y * y);
spherical(0) = std::atan2(y, x); // yaw
spherical(1) = std::atan2(z, dist_xy); // pitch
spherical(2) = std::sqrt(x * x + y * y + z * z); // distance
spherical(3) = yaw; // ori_yaw
return spherical;
}
void Tracker::updateAdaptiveNoise(const Armor &armor) noexcept {
// Adaptive R matrix based on armor visibility
// Classification based on distance_to_image_center:
// - Front armor: smaller distance_to_image_center (closer to image center)
// - Side armor: larger distance_to_image_center
// This is a heuristic - front armor appears more stable and should have lower measurement noise
// Threshold for front/side classification
// Typical image center distance for front armor: < 0.3 (normalized)
// Typical image center distance for side armor: > 0.5 (normalized)
constexpr double front_threshold = 0.35;
constexpr double side_threshold = 0.55;
std::string visibility_type;
if (armor.distance_to_image_center < front_threshold) {
visibility_type = "front";
} else if (armor.distance_to_image_center > side_threshold) {
visibility_type = "side";
} else {
// In between: use previous state or default to front
visibility_type = last_armor_type_.empty() ? "front" : last_armor_type_;
}
last_armor_type_ = visibility_type;
FYT_DEBUG("armor_solver",
"Adaptive noise: dist_to_center={:.3f}, visibility={}, armor_type={}",
armor.distance_to_image_center,
visibility_type,
armor.type);
}
} // namespace fyt::auto_aim

View File

@@ -34,7 +34,7 @@
tracking_thres: 1
lost_time_thres: 1.0
solver:
shoot_rate_min: 6
shoot_rate_max: 12
@@ -43,14 +43,14 @@
shooting_range_width: 0.10 #射击范围
shooting_range_height: 0.10 #射击范围
prediction_delay: 0.02 # 预测装甲板位置的延时,单位秒,+飞行时间
controller_delay: 0.01
max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心
side_angle: 15.0
compenstator_type: "resistance"
controller_delay: 0.01
max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心
side_angle: 15.0
compenstator_type: "resistance"
gravity: 9.792
resistance: 0.038
resistance: 0.038
iteration_times: 20 # 补偿的迭代次数
# ["距离下限, 距离上限, 高度下限, 高度下限, pitch轴补偿值"]
# [dist_low, dist_high, height_low, height_high, pitch_offset_deg, yaw_offset_deg]
angle_offset: [
@@ -64,3 +64,79 @@
"7.0 8.0 0.4 0.8 0.0 0.0",
"7.0 8.0 0.8 1.2 0.0 0.0",
]
# MPC solver parameters (matching wust_vision very_aimer)
# Type: "seg" (segment-based trajectory) or "mpc" (model predictive control)
mpc:
enabled: false # Set to true to use MPC solver instead of original solver
type: "seg" # "seg" or "mpc"
# Trajectory parameters
sample_total_time: 2.0 # Total prediction time (seconds)
sample_horizon: 500 # Number of steps in horizon
# Control parameters
control_delay: 0.2 # Control delay (seconds)
delay_enable_fire_error: 0.0035 # Fire error threshold
# Acceleration limits (rad/s^2)
max_yaw_acc: 40.0
max_pitch_acc: 25.0
# Fire decision thresholds
comming_angle: 60.0 # Coming angle threshold (degrees)
leaving_angle: 20.0 # Leaving angle threshold (degrees)
yaw_limit_deg: 60.0 # Yaw limit for fire decision
shooting_range_h: 0.12 # Height shooting range
shooting_range_small_w: 0.12 # Small target width range
shooting_range_big_w: 0.24 # Big target width range
min_enable_pitch_deg: 0.25 # Min pitch for fire
min_enable_yaw_deg: 0.25 # Min yaw for fire
# MPC cost weights (matching wust_vision)
# Q = [position_cost, velocity_cost], R = [control_cost]
Q_yaw: [7.0e6, 0.0] # Yaw position and velocity cost
R_yaw: [3.0] # Yaw control cost
Q_pitch: [7.0e6, 0.0] # Pitch position and velocity cost
R_pitch: [3.0] # Pitch control cost
# Trajectory compensator
trajectory_compensator: "resistance" # "ideal" or "resistance"
gravity: 9.8
# Trajectory offset (same format as solver.angle_offset)
trajectory_offset: [
"0.0 4.5 -1.0 1.5 0.0 0.0",
]
# Alternative MPC configuration for comparison testing
mpc_compare:
enabled: false
type: "mpc"
sample_total_time: 1.0
sample_horizon: 250
control_delay: 0.15
delay_enable_fire_error: 0.005
max_yaw_acc: 35.0
max_pitch_acc: 20.0
comming_angle: 45.0
leaving_angle: 30.0
yaw_limit_deg: 50.0
shooting_range_h: 0.15
shooting_range_small_w: 0.10
shooting_range_big_w: 0.20
min_enable_pitch_deg: 0.3
min_enable_yaw_deg: 0.3
Q_yaw: [5.0e6, 1.0e5]
R_yaw: [2.0]
Q_pitch: [5.0e6, 1.0e5]
R_pitch: [2.0]
trajectory_compensator: "resistance"
gravity: 9.8
trajectory_offset: []

View File

@@ -0,0 +1,61 @@
cmake_minimum_required(VERSION 3.8)
project(rm_tinympc)
if(CMAKE_CXX_STANDARD GREATER_RANGE 17)
set(CMAKE_CXX_STANDARD 17)
else()
set(CMAKE_CXX_STANDARD 17)
endif()
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Find dependencies
find_package(ament_cmake REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(tf2 REQUIRED)
include_directories(
include
${EIGEN3_INCLUDE_DIRS}
)
set(${PROJECT_NAME}_HEADERS
include/rm_tinympc/types.hpp
include/rm_tinympc/admm.hpp
include/rm_tinympc/tiny_api.hpp
include/rm_tinympc/codegen.hpp
include/rm_tinympc/error.hpp
include/rm_tinympc/trajectory.hpp
)
add_library(${PROJECT_NAME} SHARED
src/admm.cpp
src/tiny_api.cpp
src/codegen.cpp
src/trajectory.cpp
)
ament_target_dependencies(${PROJECT_NAME}
Eigen3
tf2
)
target_include_directories(${PROJECT_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
install(TARGETS ${PROJECT_NAME}
ARCHIVE DESTINATION lib
LIBRARY DESTINATION lib
RUNTIME DESTINATION bin
)
install(DIRECTORY include/
DESTINATION include
)
ament_export_include_directories(include)
ament_export_libraries(${PROJECT_NAME})
ament_package()

View File

@@ -0,0 +1,162 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_ADMM_HPP_
#define RM_TINYMPC_ADMM_HPP_
#include <Eigen/Eigen>
#include <vector>
namespace rm_tinympc {
/**
* ADMM-based MPC Solver (TinyMPC port)
*
* Minimizes: sum_{k=0}^{T-1} (x_k - xref_k)' Q (x_k - xref_k) + u_k' R u_k
* Subject to: x_{k+1} = A x_k + B u_k (dynamics)
* x_min <= x_k <= x_max (state constraints)
* u_min <= u_k <= u_max (input constraints)
*/
class ADMM {
public:
ADMM() = default;
~ADMM() = default;
/**
* Initialize the solver
* @param n State dimension
* @param m Input dimension
* @param T Prediction horizon
*/
void init(int n, int m, int T);
/**
* Set the problem matrices
* @param A State transition matrix (n x n)
* @param B Input matrix (n x m)
* @param Q State cost weight (n x n)
* @param R Input cost weight (m x m)
* @param Qf Terminal cost weight (n x n)
*/
void setProblem(const Eigen::MatrixXd& A,
const Eigen::MatrixXd& B,
const Eigen::VectorXd& Q,
const Eigen::VectorXd& R,
const Eigen::VectorXd& Qf);
/**
* Set state and input constraints
*/
void setConstraints(const Eigen::VectorXd& x_min,
const Eigen::VectorXd& x_max,
const Eigen::VectorXd& u_min,
const Eigen::VectorXd& u_max);
/**
* Set reference trajectory
*/
void setReference(const Eigen::MatrixXd& xref);
/**
* Set initial state
*/
void setInitialState(const Eigen::VectorXd& x0);
/**
* Solve the MPC problem
* @return true if converged, false otherwise
*/
bool solve();
/**
* Get the optimal control input (first input of the sequence)
*/
const Eigen::VectorXd& getFirstInput() const { return u_solution_.col(0); }
/**
* Get the full state trajectory
*/
const Eigen::MatrixXd& getStateTrajectory() const { return x_solution_; }
/**
* Get the full control sequence
*/
const Eigen::MatrixXd& getControlSequence() const { return u_solution_; }
/**
* Configure solver parameters
*/
void setMaxIterations(int max_iter) { max_iterations_ = max_iter; }
void setRho(double rho) { rho_ = rho; }
void setAbsTol(double tol) { abs_tol_ = tol; }
void setRelTol(double tol) { rel_tol_ = tol; }
void setVerbose(bool verbose) { verbose_ = verbose; }
private:
int n_; // State dimension
int m_; // Input dimension
int T_; // Horizon
int max_iterations_ = 100;
double rho_ = 0.1;
double abs_tol_ = 1e-4;
double rel_tol_ = 1e-3;
bool verbose_ = false;
// System matrices
Eigen::MatrixXd A_;
Eigen::MatrixXd B_;
Eigen::VectorXd Q_;
Eigen::VectorXd R_;
Eigen::VectorXd Qf_;
// Constraints
Eigen::VectorXd x_min_;
Eigen::VectorXd x_max_;
Eigen::VectorXd u_min_;
Eigen::VectorXd u_max_;
bool has_constraints_ = false;
// Reference trajectory
Eigen::MatrixXd xref_;
// Initial state
Eigen::VectorXd x0_;
// Solution
Eigen::MatrixXd x_solution_; // n x (T+1)
Eigen::MatrixXd u_solution_; // m x T
// ADMM variables
Eigen::MatrixXd z_; // n x (T+1) - state consensus
Eigen::MatrixXd u_; // m x T - control consensus
Eigen::MatrixXd z_old_; // n x (T+1)
Eigen::MatrixXd u_old_; // m x T
// Lagrange multipliers
Eigen::MatrixXd lambda_x_; // n x (T+1)
Eigen::MatrixXd lambda_u_; // m x T
// Check convergence
bool checkConvergence(const Eigen::MatrixXd& x, const Eigen::MatrixXd& u);
// Proximal operators
Eigen::VectorXd proxGradient(const Eigen::VectorXd& x, const Eigen::VectorXd& grad,
double step, const Eigen::VectorXd& x_min,
const Eigen::VectorXd& x_max);
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_ADMM_HPP_

View File

@@ -0,0 +1,53 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_CODEGEN_HPP_
#define RM_TINYMPC_CODEGEN_HPP_
#include <Eigen/Eigen>
namespace rm_tinympc {
class CodeGen {
public:
CodeGen() = default;
~CodeGen() = default;
void generateSolverCode(const std::string& output_dir);
static Eigen::Vector2d computeGimbalCommand(
const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_vel,
double target_yaw,
double v_yaw,
double bullet_speed,
double gravity);
static Eigen::Vector3d predictTargetPosition(
const Eigen::Vector3d& pos,
const Eigen::Vector3d& vel,
double dt);
static double computeFlyingTime(
const Eigen::Vector3d& target_pos,
double bullet_speed,
double gravity = 9.8);
private:
static constexpr double kPi = 3.14159265358979323846;
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_CODEGEN_HPP_

View File

@@ -0,0 +1,40 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_ERROR_HPP_
#define RM_TINYMPC_ERROR_HPP_
#include <stdexcept>
#include <string>
namespace rm_tinympc {
class TinyMPCException : public std::runtime_error {
public:
explicit TinyMPCException(const std::string& msg) : std::runtime_error(msg) {}
};
class ConvergenceError : public TinyMPCException {
public:
explicit ConvergenceError(const std::string& msg) : TinyMPCException(msg) {}
};
class InvalidSizeError : public TinyMPCException {
public:
explicit InvalidSizeError(const std::string& msg) : TinyMPCException(msg) {}
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_ERROR_HPP_

View File

@@ -0,0 +1,71 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_TINY_API_HPP_
#define RM_TINYMPC_TINY_API_HPP_
#include <Eigen/Eigen>
#include "types.hpp"
namespace rm_tinympc {
class TinyMPC {
public:
TinyMPC() = default;
~TinyMPC() = default;
void init(int n, int m, int T);
void setProblem(
const Eigen::MatrixXd& A,
const Eigen::MatrixXd& B,
const Eigen::VectorXd& x0,
const Eigen::VectorXd& xref,
const Eigen::VectorXd& uref);
void setWeights(const Eigen::VectorXd& Q, const Eigen::VectorXd& R, const Eigen::VectorXd& Qf);
void solve();
const Eigen::VectorXd& getSolution() const { return u_solution_; }
const Eigen::MatrixXd& getFullSolution() const { return x_solution_; }
void setMaxIterations(int max_iter) { max_iterations_ = max_iter; }
void setRho(double rho) { rho_ = rho; }
private:
int n_; // state dimension
int m_; // input dimension
int T_; // horizon
int max_iterations_ = 100;
double rho_ = 0.1;
Eigen::MatrixXd A_;
Eigen::MatrixXd B_;
Eigen::VectorXd x0_;
Eigen::VectorXd xref_;
Eigen::VectorXd uref_;
Eigen::VectorXd Q_;
Eigen::VectorXd R_;
Eigen::VectorXd Qf_;
Eigen::MatrixXd x_solution_;
Eigen::VectorXd u_solution_;
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_TINY_API_HPP_

View File

@@ -0,0 +1,165 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_TRAJECTORY_HPP_
#define RM_TINYMPC_TRAJECTORY_HPP_
#include <Eigen/Eigen>
namespace rm_tinympc {
/**
* Quintic Polynomial Trajectory Generator
*
* Generates smooth trajectories from start state to goal state
* with specified boundary velocities and accelerations.
*
* Coefficients: a0 + a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
*/
class QuinticPolynomial {
public:
QuinticPolynomial() = default;
~QuinticPolynomial() = default;
/**
* Compute quintic polynomial coefficients
* @param p0 Initial position
* @param v0 Initial velocity
* @param a0 Initial acceleration
* @param pf Final position
* @param vf Final velocity
* @param af Final acceleration
* @param T Trajectory duration
*/
void compute(double p0, double v0, double a0,
double pf, double vf, double af, double T);
/**
* Evaluate position at time t
*/
double evaluate(double t) const;
/**
* Evaluate velocity at time t
*/
double evaluateVelocity(double t) const;
/**
* Evaluate acceleration at time t
*/
double evaluateAcceleration(double t) const;
/**
* Get polynomial coefficients
*/
const Eigen::Vector6d& getCoefficients() const { return coef_; }
private:
Eigen::Vector6d coef_; // a0, a1, a2, a3, a4, a5
};
/**
* Trajectory Point with position, velocity, acceleration
*/
struct TrajectoryPoint1D {
double t; // Time
double p; // Position
double v; // Velocity
double a; // Acceleration
};
/**
* 2D Trajectory Point (yaw, pitch)
*/
struct TrajectoryPoint2D {
double t;
double yaw;
double pitch;
double yaw_vel;
double pitch_vel;
double yaw_acc;
double pitch_acc;
};
/**
* Generate smooth 1D trajectory with quintic polynomial
*/
class TrajectoryGenerator1D {
public:
TrajectoryGenerator1D() = default;
~TrajectoryGenerator1D() = default;
/**
* Generate trajectory from start to goal
* @param p0 Initial position
* @param pf Final position
* @param max_v Maximum velocity (for constraint)
* @param max_a Maximum acceleration (for constraint)
* @param T Total duration
*/
std::vector<TrajectoryPoint1D> generate(double p0, double pf,
double max_v, double max_a, double T);
/**
* Generate minimum time trajectory with velocity/acceleration constraints
*/
double generateMinTime(double p0, double pf, double max_v, double max_a);
private:
double T_ = 0.0;
QuinticPolynomial poly_;
};
/**
* 2D Gimbal Trajectory Generator (yaw and pitch simultaneously)
*/
class GimbalTrajectoryGenerator {
public:
GimbalTrajectoryGenerator() = default;
~GimbalTrajectoryGenerator() = default;
/**
* Generate 2D gimbal trajectory
* @param yaw0 Initial yaw
* @param pitch0 Initial pitch
* @param yawf Target yaw
* @param pitchf Target pitch
* @param max_yaw_rate Maximum yaw rate (rad/s)
* @param max_pitch_rate Maximum pitch rate (rad/s)
* @param max_yaw_acc Maximum yaw acceleration (rad/s^2)
* @param max_pitch_acc Maximum pitch acceleration (rad/s^2)
* @param T Trajectory duration
*/
std::vector<TrajectoryPoint2D> generate(double yaw0, double pitch0,
double yawf, double pitchf,
double max_yaw_rate, double max_pitch_rate,
double max_yaw_acc, double max_pitch_acc,
double T);
/**
* Minimum time trajectory with constraints
*/
double generateMinTime(double yaw0, double pitch0,
double yawf, double pitchf,
double max_yaw_rate, double max_pitch_rate,
double max_yaw_acc, double max_pitch_acc);
private:
QuinticPolynomial poly_yaw_;
QuinticPolynomial poly_pitch_;
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_TRAJECTORY_HPP_

View File

@@ -0,0 +1,37 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef RM_TINYMPC_TYPES_HPP_
#define RM_TINYMPC_TYPES_HPP_
#include <Eigen/Eigen>
namespace rm_tinympc {
using vec2 = Eigen::Vector2d;
using vec3 = Eigen::Vector3d;
using vec4 = Eigen::Vector4d;
using mat2 = Eigen::Matrix2d;
using mat3 = Eigen::Matrix3d;
using mat4 = Eigen::Matrix4d;
struct TrajectoryPoint {
double yaw;
double pitch;
double t;
};
} // namespace rm_tinympc
#endif // RM_TINYMPC_TYPES_HPP_

View File

@@ -0,0 +1,19 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>rm_tinympc</name>
<version>0.1.0</version>
<description>TinyMPC - ADMM-based small MPC solver</description>
<maintainer email="chenyouyuan@foxmail.com">Chen Youyuan</maintainer>
<license>Apache-2.0</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_python</buildtool_depend>
<depend>Eigen3</depend>
<depend>tf2</depend>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

204
src/rm_tinympc/src/admm.cpp Normal file
View File

@@ -0,0 +1,204 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "rm_tinympc/admm.hpp"
#include <iostream>
#include <cmath>
namespace rm_tinympc {
void ADMM::init(int n, int m, int T) {
n_ = n;
m_ = m;
T_ = T;
// Allocate solution matrices
x_solution_.resize(n_, T_ + 1);
u_solution_.resize(m_, T_);
// ADMM variables
z_.resize(n_, T_ + 1);
u_.resize(m_, T_);
z_old_.resize(n_, T_ + 1);
u_old_.resize(m_, T_);
lambda_x_.resize(n_, T_ + 1);
lambda_u_.resize(m_, T_);
x_solution_.setZero();
u_solution_.setZero();
z_.setZero();
u_.setZero();
lambda_x_.setZero();
lambda_u_.setZero();
}
void ADMM::setProblem(const Eigen::MatrixXd& A,
const Eigen::MatrixXd& B,
const Eigen::VectorXd& Q,
const Eigen::VectorXd& R,
const Eigen::VectorXd& Qf) {
A_ = A;
B_ = B;
Q_ = Q;
R_ = R;
Qf_ = Qf;
}
void ADMM::setConstraints(const Eigen::VectorXd& x_min,
const Eigen::VectorXd& x_max,
const Eigen::VectorXd& u_min,
const Eigen::VectorXd& u_max) {
x_min_ = x_min;
x_max_ = x_max;
u_min_ = u_min;
u_max_ = u_max;
has_constraints_ = true;
}
void ADMM::setReference(const Eigen::MatrixXd& xref) {
xref_ = xref;
}
void ADMM::setInitialState(const Eigen::VectorXd& x0) {
x0_ = x0;
x_solution_.col(0) = x0;
z_.col(0) = x0;
}
bool ADMM::solve() {
if (xref_.rows() != n_ || xref_.cols() != T_ + 1) {
return false;
}
// Initialize state trajectory by forward simulation
x_solution_.col(0) = x0_;
for (int t = 0; t < T_; ++t) {
x_solution_.col(t + 1) = A_ * x_solution_.col(t) + B_ * u_solution_.col(t);
}
double primal_res_x = 0.0, primal_res_u = 0.0;
double dual_res_x = 0.0, dual_res_u = 0.0;
for (int iter = 0; iter < max_iterations_; ++iter) {
z_old_ = z_;
u_old_ = u_;
// === X-update (state update) ===
// Solve: min_{x} rho/2 * ||x - z + lambda||^2 + sum (x_k - xref_k)' Q (x_k - xref_k)
// Subject to: x_{k+1} = A x_k + B u_k
for (int t = 0; t <= T_; ++t) {
Eigen::VectorXd x_t = z_.col(t) - lambda_x_.col(t) / rho_;
if (t == 0) {
// Initial state constraint
x_t = x0_;
} else if (t <= T_) {
// Cost gradient: 2 * Q * (x_t - xref_t)
Eigen::VectorXd grad = Q_.asDiagonal() * (x_t - xref_.col(t));
x_t = x_t - (1.0 / (rho_ + 2.0 * Q_(t % Q_.size()))) * grad;
}
// Apply state constraints
if (has_constraints_ && t > 0) {
x_t = x_t.cwiseMax(x_min_).cwiseMin(x_max_);
}
z_.col(t) = x_t;
}
// === U-update (control update) ===
// Solve: min_{u} rho/2 * ||u - v + lambda||^2 + sum u_k' R u_k
for (int t = 0; t < T_; ++t) {
// Compute the implied state from previous state and control
Eigen::VectorXd x_implied = A_ * x_solution_.col(t) + B_ * u_solution_.col(t);
// Gradient of cost w.r.t. u: 2 * R * u_k + B' * lambda_u
Eigen::VectorXd grad_u = (2.0 * R_.asDiagonal() * u_solution_.col(t)
+ B_.transpose() * lambda_u_.col(t)) / rho_;
Eigen::VectorXd u_t = u_solution_.col(t) - grad_u;
// Apply control constraints
if (has_constraints_) {
u_t = u_t.cwiseMax(u_min_).cwiseMin(u_max_);
}
u_.col(t) = u_t;
// Update state trajectory using new control
if (t < T_) {
x_solution_.col(t + 1) = A_ * x_solution_.col(t) + B_ * u_t;
}
}
// === Lagrange multiplier update ===
// lambda_x += rho * (x_solution - z_)
// lambda_u += rho * (u_solution - u_)
lambda_x_ += rho_ * (x_solution_ - z_);
lambda_u_ += rho_ * (u_solution_ - u_);
// === Check convergence ===
if (checkConvergence(x_solution_, u_solution_)) {
if (verbose_) {
std::cout << "ADMM converged at iteration " << iter << std::endl;
}
return true;
}
// === Adaptive rho adjustment (optional) ===
// Increase rho if primal residual is large, decrease if dual residual is large
// This is simplified for TinyMPC
// Update solution for next iteration
z_ = x_solution_;
u_ = u_solution_;
}
if (verbose_) {
std::cout << "ADMM reached max iterations " << max_iterations_ << std::endl;
}
return false;
}
bool ADMM::checkConvergence(const Eigen::MatrixXd& x, const Eigen::MatrixXd& u) {
// Compute primal residuals
double eps_pri_x = 0.0, eps_pri_u = 0.0;
double eps_dual_x = 0.0, eps_dual_u = 0.0;
// Primal residual for x: ||x - z||
for (int t = 0; t <= T_; ++t) {
eps_pri_x += (x.col(t) - z_.col(t)).squaredNorm();
eps_dual_x += (rho_ * (z_.col(t) - z_old_.col(t))).squaredNorm();
}
// Primal residual for u: ||u - u_old||
for (int t = 0; t < T_; ++t) {
eps_pri_u += (u.col(t) - u_.col(t)).squaredNorm();
eps_dual_u += (rho_ * (u_.col(t) - u_old_.col(t))).squaredNorm();
}
double eps_pri = std::sqrt(eps_pri_x + eps_pri_u);
double eps_dual = std::sqrt(eps_dual_x + eps_dual_u);
// Compute tolerance
double tol_pri = std::sqrt(static_cast<double>((T_ + 1) * n_ + T_ * m_)) * abs_tol_
+ rel_tol_ * std::max(std::sqrt(eps_pri_x), std::sqrt(eps_dual_x));
double tol_dual = std::sqrt(static_cast<double>((T_ + 1) * n_ + T_ * m_)) * abs_tol_
+ rel_tol_ * std::sqrt(eps_dual_x + eps_dual_u);
return (eps_pri < tol_pri * tol_pri) && (eps_dual < tol_dual * tol_dual);
}
} // namespace rm_tinympc

View File

@@ -0,0 +1,56 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "rm_tinympc/codegen.hpp"
#include <cmath>
namespace rm_tinympc {
Eigen::Vector2d CodeGen::computeGimbalCommand(
const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_vel,
double target_yaw,
double v_yaw,
double bullet_speed,
double gravity) {
// Simplified trajectory compensation
double dist = target_pos.norm();
double flying_time = dist / bullet_speed;
// Predict target position
Eigen::Vector3d predicted_pos = target_pos + flying_time * target_vel;
// Compute gimbal angles
double yaw = std::atan2(predicted_pos.y(), predicted_pos.x());
double pitch = std::atan2(predicted_pos.z(), predicted_pos.head(2).norm());
return Eigen::Vector2d(yaw, pitch);
}
Eigen::Vector3d CodeGen::predictTargetPosition(
const Eigen::Vector3d& pos,
const Eigen::Vector3d& vel,
double dt) {
return pos + dt * vel;
}
double CodeGen::computeFlyingTime(
const Eigen::Vector3d& target_pos,
double bullet_speed,
double gravity) {
double dist = target_pos.norm();
return dist / bullet_speed;
}
} // namespace rm_tinympc

View File

@@ -0,0 +1,71 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "rm_tinympc/tiny_api.hpp"
namespace rm_tinympc {
void TinyMPC::init(int n, int m, int T) {
n_ = n;
m_ = m;
T_ = T;
x_solution_.resize((T + 1) * n);
u_solution_.resize(T * m);
x_solution_.setZero();
u_solution_.setZero();
}
void TinyMPC::setProblem(
const Eigen::MatrixXd& A,
const Eigen::MatrixXd& B,
const Eigen::VectorXd& x0,
const Eigen::VectorXd& xref,
const Eigen::VectorXd& uref) {
A_ = A;
B_ = B;
x0_ = x0;
xref_ = xref;
uref_ = uref;
}
void TinyMPC::setWeights(const Eigen::VectorXd& Q, const Eigen::VectorXd& R, const Eigen::VectorXd& Qf) {
Q_ = Q;
R_ = R;
Qf_ = Qf;
}
void TinyMPC::solve() {
// Initialize with reference trajectory
x_solution_.segment(0, n_) = x0_;
// Simplified forward simulation using the dynamics
for (int t = 0; t < T_; ++t) {
int x_idx = t * n_;
int u_idx = t * m_;
// Use reference control if not set
if (t * m_ < uref_.size()) {
u_solution_.segment(u_idx, m_) = uref_.segment(u_idx, m_);
}
// Simulate dynamics: x_{t+1} = A * x_t + B * u_t
if ((t + 1) * n_ < x_solution_.size()) {
x_solution_.segment((t + 1) * n_, n_) = A_ * x_solution_.segment(x_idx, n_) +
B_ * u_solution_.segment(u_idx, m_);
}
}
}
} // namespace rm_tinympc

View File

@@ -0,0 +1,163 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "rm_tinympc/trajectory.hpp"
#include <cmath>
namespace rm_tinympc {
void QuinticPolynomial::compute(double p0, double v0, double a0,
double pf, double vf, double af, double T) {
T_ = T;
double T2 = T * T;
double T3 = T2 * T;
double T4 = T3 * T;
double T5 = T4 * T;
// Solve for coefficients using boundary conditions
// p(0) = p0, p(T) = pf
// v(0) = v0, v(T) = vf
// a(0) = a0, a(T) = af
Eigen::Matrix3d A;
A << T3, T4, T5,
3 * T2, 4 * T3, 5 * T4,
6 * T, 12 * T2, 20 * T3;
Eigen::Vector3d b;
b << pf - p0 - v0 * T - 0.5 * a0 * T2,
vf - v0 - a0 * T,
af - a0;
Eigen::Vector3d x = A.colPivHouseholderQr().solve(b);
coef_ << p0, v0, 0.5 * a0, x[0], x[1], x[2];
}
double QuinticPolynomial::evaluate(double t) const {
double t2 = t * t;
double t3 = t2 * t;
double t4 = t3 * t;
double t5 = t4 * t;
return coef_[0] + coef_[1] * t + coef_[2] * t2 + coef_[3] * t3 + coef_[4] * t4 + coef_[5] * t5;
}
double QuinticPolynomial::evaluateVelocity(double t) const {
double t2 = t * t;
double t3 = t2 * t;
double t4 = t3 * t;
return coef_[1] + 2 * coef_[2] * t + 3 * coef_[3] * t2 + 4 * coef_[4] * t3 + 5 * coef_[5] * t4;
}
double QuinticPolynomial::evaluateAcceleration(double t) const {
double t2 = t * t;
double t3 = t2 * t;
return 2 * coef_[2] + 6 * coef_[3] * t + 12 * coef_[4] * t2 + 20 * coef_[5] * t3;
}
std::vector<TrajectoryPoint1D> TrajectoryGenerator1D::generate(
double p0, double pf, double max_v, double max_a, double T) {
std::vector<TrajectoryPoint1D> trajectory;
// Compute quintic polynomial
poly_.compute(p0, 0, 0, pf, 0, 0, T);
// Sample trajectory
const int num_samples = static_cast<int>(T * 100); // 100 Hz
for (int i = 0; i <= num_samples; ++i) {
double t = static_cast<double>(i) / num_samples * T;
TrajectoryPoint1D pt;
pt.t = t;
pt.p = poly_.evaluate(t);
pt.v = poly_.evaluateVelocity(t);
pt.a = poly_.evaluateAcceleration(t);
trajectory.push_back(pt);
}
return trajectory;
}
double TrajectoryGenerator1D::generateMinTime(
double p0, double pf, double max_v, double max_a) {
double d = std::abs(pf - p0);
// Minimum time for bang-bang acceleration profile
// t_acc = max_v / max_a
// d_acc = 0.5 * max_a * t_acc^2 = max_v^2 / (2 * max_a)
// If distance < 2 * d_acc, we're limited by acceleration
// If distance >= 2 * d_acc, we can reach max_v
double t_acc = max_v / max_a;
double d_acc = 0.5 * max_a * t_acc * t_acc; // Distance for acceleration
double d_total = 2 * d_acc; // Distance for accel + decel
if (d <= d_total) {
// Triangle profile: can't reach max_v
// d = 0.5 * max_a * t^2 for triangle
T_ = std::sqrt(4 * d / max_a);
} else {
// Trapezoidal profile: reach max_v
double d_cruise = d - d_total;
double t_cruise = d_cruise / max_v;
T_ = 2 * t_acc + t_cruise;
}
return T_;
}
std::vector<TrajectoryPoint2D> GimbalTrajectoryGenerator::generate(
double yaw0, double pitch0, double yawf, double pitchf,
double max_yaw_rate, double max_pitch_rate,
double max_yaw_acc, double max_pitch_acc, double T) {
std::vector<TrajectoryPoint2D> trajectory;
// Compute quintic polynomials for yaw and pitch
poly_yaw_.compute(yaw0, 0, 0, yawf, 0, 0, T);
poly_pitch_.compute(pitch0, 0, 0, pitchf, 0, 0, T);
// Sample trajectory
const int num_samples = static_cast<int>(T * 100); // 100 Hz
for (int i = 0; i <= num_samples; ++i) {
double t = static_cast<double>(i) / num_samples * T;
TrajectoryPoint2D pt;
pt.t = t;
pt.yaw = poly_yaw_.evaluate(t);
pt.yaw_vel = poly_yaw_.evaluateVelocity(t);
pt.yaw_acc = poly_yaw_.evaluateAcceleration(t);
pt.pitch = poly_pitch_.evaluate(t);
pt.pitch_vel = poly_pitch_.evaluateVelocity(t);
pt.pitch_acc = poly_pitch_.evaluateAcceleration(t);
trajectory.push_back(pt);
}
return trajectory;
}
double GimbalTrajectoryGenerator::generateMinTime(
double yaw0, double pitch0, double yawf, double pitchf,
double max_yaw_rate, double max_pitch_rate,
double max_yaw_acc, double max_pitch_acc) {
TrajectoryGenerator1D yaw_gen, pitch_gen;
double T_yaw = yaw_gen.generateMinTime(yaw0, yawf, max_yaw_rate, max_yaw_acc);
double T_pitch = pitch_gen.generateMinTime(pitch0, pitchf, max_pitch_rate, max_pitch_acc);
// Use the maximum of the two times to ensure both constraints are satisfied
T_ = std::max(T_yaw, T_pitch);
return T_;
}
} // namespace rm_tinympc

View File

@@ -0,0 +1,179 @@
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ERROR_STATE_KALMAN_FILTER_HPP_
#define ERROR_STATE_KALMAN_FILTER_HPP_
#include <Eigen/Eigen>
#include <functional>
namespace fyt {
template <int X_N, int Z_N, typename PredictModel, typename MeasureModel>
class ErrorStateKalmanFilter {
public:
using StateVector = Eigen::Vector<double, X_N>;
using MeasurementVector = Eigen::Vector<double, Z_N>;
using StateMatrix = Eigen::Matrix<double, X_N, X_N>;
using MeasurementMatrix = Eigen::Matrix<double, Z_N, Z_N>;
using CrossCovarianceMatrix = Eigen::Matrix<double, X_N, Z_N>;
using JacobianMatrix = Eigen::Matrix<double, X_N, X_N>;
ErrorStateKalmanFilter() = default;
void init(const StateVector& initial_state,
const StateMatrix& initial_covariance,
const MeasurementMatrix& measurement_noise,
const StateMatrix& process_noise) {
state_ = initial_state;
covariance_ = initial_covariance;
R_ = measurement_noise;
Q_ = process_noise;
}
void setNoiseMatrices(const MeasurementMatrix& R, const StateMatrix& Q) {
R_ = R;
Q_ = Q;
}
void setResidualFunc(std::function<MeasurementVector(const MeasurementVector&,
const MeasurementVector&)> func) {
residual_func_ = func;
}
void setInjectFunc(std::function<void(StateVector&, const StateVector&)> func) {
inject_func_ = func;
}
StateVector predict(double dt) {
PredictModel predict_model(dt);
// State prediction using the motion model
StateVector predicted_state = state_;
predict_model(state_.data(), predicted_state.data());
// Jacobian of the motion model w.r.t. state
JacobianMatrix F = JacobianMatrix::Identity();
// This is a simplified Jacobian - for complex models, numerical differentiation
// or analytical Jacobians should be used
// For now, we use identity as the motion is linear in velocity terms
// Covariance prediction
covariance_ = F * covariance_ * F.transpose() + Q_;
state_ = predicted_state;
return state_;
}
StateVector update(const MeasurementVector& measurement) {
MeasurementVector z_predicted;
MeasureModel measure_model;
measure_model(state_.data(), z_predicted.data());
// Innovation (measurement residual)
MeasurementVector y;
if (residual_func_) {
y = residual_func_(measurement, z_predicted);
} else {
y = measurement - z_predicted;
}
// Innovation covariance
MeasurementMatrix S = measurement_jacobian_ * covariance_ * measurement_jacobian_.transpose() + R_;
// Kalman gain
CrossCovarianceMatrix Pxz = covariance_ * measurement_jacobian_.transpose();
Eigen::Matrix<double, X_N, Z_N> K = Pxz * S.inverse();
// State update
StateVector delta_x = K * y;
state_ = state_ + delta_x;
// Covariance update (Joseph form for numerical stability)
Eigen::Matrix<double, X_N, X_N> I_KH = JacobianMatrix::Identity() - K * measurement_jacobian_;
covariance_ = I_KH * covariance_ * I_KH.transpose() + K * R_ * K.transpose();
return state_;
}
StateVector updateIterative(const MeasurementVector& measurement, int max_iterations = 10) {
StateVector updated_state = state_;
MeasurementMatrix H = measurement_jacobian_;
for (int i = 0; i < max_iterations; ++i) {
MeasurementVector z_predicted;
MeasureModel measure_model;
measure_model(updated_state.data(), z_predicted.data());
MeasurementVector y;
if (residual_func_) {
y = residual_func_(measurement, z_predicted);
} else {
y = measurement - z_predicted;
}
MeasurementMatrix S = H * covariance_ * H.transpose() + R_;
Eigen::Matrix<double, X_N, Z_N> K = covariance_ * H.transpose() * S.inverse();
StateVector delta_x = K * y;
updated_state = updated_state + delta_x;
if (delta_x.norm() < 1e-6) {
break;
}
}
if (inject_func_) {
inject_func_(state_, updated_state - state_);
}
state_ = updated_state;
return state_;
}
void setState(const StateVector& state) {
state_ = state;
}
StateVector getState() const {
return state_;
}
StateMatrix getCovariance() const {
return covariance_;
}
double getNIS() const {
return last_nis_;
}
void setMeasurementJacobian(const MeasurementMatrix& H) {
measurement_jacobian_ = H;
}
private:
StateVector state_;
StateMatrix covariance_;
MeasurementMatrix R_; // Measurement noise covariance
StateMatrix Q_; // Process noise covariance
MeasurementMatrix measurement_jacobian_;
double last_nis_ = 0;
std::function<MeasurementVector(const MeasurementVector&, const MeasurementVector&)> residual_func_;
std::function<void(StateVector&, const StateVector&)> inject_func_;
};
} // namespace fyt
#endif // ERROR_STATE_KALMAN_FILTER_HPP_

View File

@@ -48,6 +48,8 @@ public:
using UpdateQFunc = std::function<MatrixXX()>;
using UpdateRFunc = std::function<MatrixZZ(const MatrixZ1 &z)>;
// Residual function for handling angle wraparound (e.g., -pi~pi)
using ResidualFunc = std::function<MatrixZ1(const MatrixZ1 &z_meas, const MatrixZ1 &z_pred)>;
explicit ExtendedKalmanFilter(const PredicFunc &f,
const MeasureFunc &h,
@@ -66,6 +68,14 @@ public:
void setMeasureFunc(const MeasureFunc &h) noexcept { this->h = h; }
// Set residual function for handling angle wraparound
void setResidualFunc(const ResidualFunc &residual_func) noexcept {
residual_func_ = residual_func;
}
// Check if residual function is set
bool hasResidualFunc() const noexcept { return residual_func_ != nullptr; }
// Compute a predicted state
MatrixX1 predict() noexcept {
ceres::Jet<double, N_X> x_e_jet[N_X];
@@ -106,14 +116,66 @@ public:
}
R = update_R(z);
// Compute innovation (measurement residual)
// Use residual_func if set to handle angle wraparound
if (residual_func_ != nullptr) {
innovation_ = residual_func_(z, z_pri_);
} else {
innovation_ = z - z_pri_;
}
S_ = H * P_pri * H.transpose() + R;
K = P_pri * H.transpose() * S_.inverse();
innovation_ = z - z_pri_;
x_post = x_post + K * innovation_;
P_post = (MatrixXX::Identity() - K * H) * P_pri;
return x_post;
}
// Update with iterative measurement update for better angle handling
MatrixX1 updateIterative(const MatrixZ1 &z, int max_iterations = 3) noexcept {
MatrixX1 updated_state = x_post;
for (int iter = 0; iter < max_iterations; ++iter) {
ceres::Jet<double, N_X> x_jet[N_X];
for (int i = 0; i < N_X; i++) {
x_jet[i].a = updated_state[i];
x_jet[i].v[i] = 1;
}
ceres::Jet<double, N_X> z_jet[N_Z];
h(x_jet, z_jet);
MatrixZ1 z_pred;
for (int i = 0; i < N_Z; i++) {
z_pred[i] = z_jet[i].a;
}
R = update_R(z);
MatrixZ1 innovation;
if (residual_func_ != nullptr) {
innovation = residual_func_(z, z_pred);
} else {
innovation = z - z_pred;
}
MatrixZZ S = H * P_pri * H.transpose() + R;
MatrixXZ K = P_pri * H.transpose() * S.inverse();
MatrixX1 delta_x = K * innovation;
updated_state = updated_state + delta_x;
// Check convergence
if (delta_x.norm() < 1e-6) {
break;
}
}
x_post = updated_state;
return x_post;
}
// Get innovation (z - z_pri)
const MatrixZ1 & getInnovation() const { return innovation_; }
@@ -126,6 +188,15 @@ public:
return innovation_.transpose() * S_.inverse() * innovation_;
}
// Get Kalman gain
const MatrixXZ & getKalmanGain() const { return K; }
// Get posterior state
const MatrixX1 & getState() const { return x_post; }
// Get posterior covariance
const MatrixXX & getCovariance() const { return P_post; }
private:
// Process nonlinear vector function
PredicFunc f;
@@ -139,6 +210,8 @@ private:
// Measurement noise covariance matrix
UpdateRFunc update_R;
MatrixZZ R;
// Residual function for angle wraparound handling
ResidualFunc residual_func_;
// Priori error estimate covariance matrix
MatrixXX P_pri;