diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver.hpp index 373da29..caede5c 100644 --- a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver.hpp +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver.hpp @@ -32,6 +32,7 @@ #include "rm_interfaces/msg/target.hpp" #include "rm_utils/math/manual_compensator.hpp" #include "rm_utils/math/trajectory_compensator.hpp" +#include "armor_solver/trajectory_planner.hpp" namespace fyt::auto_aim { // Solver class used to solve the gimbal command from tracked target @@ -53,6 +54,7 @@ public: const std::vector& getArmorPositions() const noexcept { return cached_armor_positions_; } + TrajectoryDebug getTrajectoryDebug() const noexcept; void setBulletSpeed(double bullet_speed) noexcept; void updateRuntimeParams(double max_tracking_v_yaw, double prediction_delay, @@ -61,6 +63,7 @@ public: double min_switching_v_yaw, double shooting_range_w, double shooting_range_h) noexcept; + void setTrajectoryType(const std::string& type, std::weak_ptr node = {}); private: // Get the armor positions from the target robot @@ -97,6 +100,8 @@ private: std::unique_ptr trajectory_compensator_; std::unique_ptr manual_compensator_; + std::unique_ptr planner_; + TrajectoryPlanner::Type planner_type_ = TrajectoryPlanner::Type::LINEAR; std::array rpy_; diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp index 12943f5..5a81d41 100644 --- a/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/armor_solver_node.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include // std #include @@ -64,6 +65,7 @@ private: void publishMarkers(const rm_interfaces::msg::Target &target_msg, const rm_interfaces::msg::GimbalCmd &gimbal_cmd) noexcept; + void publishTrajectoryDebug() noexcept; void setModeCallback(const std::shared_ptr request, @@ -107,6 +109,7 @@ private: // Publisher rclcpp::Publisher::SharedPtr target_pub_; rclcpp::Publisher::SharedPtr gimbal_pub_; + rclcpp::Publisher::SharedPtr traj_debug_pub_; rclcpp::Subscription::SharedPtr serial_sub_; rclcpp::TimerBase::SharedPtr pub_timer_; void timerCallback(); diff --git a/src/rm_auto_aim/armor_solver/include/armor_solver/trajectory_planner.hpp b/src/rm_auto_aim/armor_solver/include/armor_solver/trajectory_planner.hpp new file mode 100644 index 0000000..f6e575e --- /dev/null +++ b/src/rm_auto_aim/armor_solver/include/armor_solver/trajectory_planner.hpp @@ -0,0 +1,426 @@ +// Trajectory Planner for armor_solver +// Implements quintic polynomial trajectory planning and TinyMPC-based MPC + +#ifndef ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_ +#define ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_ + +#include +#include +#include +#include +#include +#include +#include + +namespace fyt::auto_aim { + +// Forward declare for use in planner +struct TargetInfo { + Eigen::Vector3d position; + Eigen::Vector3d velocity; + double yaw; + double v_yaw; + double radius_1; + double radius_2; + double d_za; + double d_zc; + size_t armors_num; + + // Default constructor + TargetInfo() = default; + + // Construct from rm_interfaces Target msg + explicit TargetInfo(const rm_interfaces::msg::Target& target_msg) + : position(target_msg.position.x, target_msg.position.y, target_msg.position.z), + velocity(target_msg.velocity.x, target_msg.velocity.y, target_msg.velocity.z), + yaw(target_msg.yaw), + v_yaw(target_msg.v_yaw), + radius_1(target_msg.radius_1), + radius_2(target_msg.radius_2), + d_za(target_msg.d_za), + d_zc(target_msg.d_zc), + armors_num(static_cast(target_msg.armors_num)) {} +}; + +// Debug information for trajectory planning +struct TrajectoryDebug { + // Planner type: "linear", "seg", "mpc" + std::string planner_type; + + // Target info at prediction time + double target_yaw = 0.0; + double target_pitch = 0.0; + double target_distance = 0.0; + + // Planned gimbal state + double planned_yaw = 0.0; + double planned_pitch = 0.0; + double planned_yaw_v = 0.0; + double planned_pitch_v = 0.0; + double planned_yaw_a = 0.0; + double planned_pitch_a = 0.0; + + // Time parameters + double flying_time = 0.0; + double total_dt = 0.0; + + // Trajectory points for visualization (yaw trajectory) + std::vector traj_yaw_p; + std::vector traj_yaw_v; + std::vector traj_time; + + // Constraints + double max_yaw_acc = 0.0; + double max_pitch_acc = 0.0; + + // MPC specific + double mpc_cost = 0.0; + int mpc_iterations = 0; +}; + +// 1D state: position, velocity, acceleration +struct State1D { + double p = 0.0; + double v = 0.0; + double a = 0.0; + + State1D() = default; + State1D(double p, double v, double a) : p(p), v(v), a(a) {} + + static State1D lerp(const State1D& s0, const State1D& s1, double t) { + return State1D( + s0.p + t * (s1.p - s0.p), + s0.v + t * (s1.v - s0.v), + s0.a + t * (s1.a - s0.a) + ); + } +}; + +// Gimbal state with yaw/pitch separation +struct GimbalState { + State1D yaw; + State1D pitch; + int aim_id = 0; + + GimbalState() = default; + GimbalState(double yaw_p, double yaw_v, double yaw_a, + double pitch_p, double pitch_v, double pitch_a) + : yaw(yaw_p, yaw_v, yaw_a), pitch(pitch_p, pitch_v, pitch_a) {} + + static GimbalState lerp(const GimbalState& s0, const GimbalState& s1, double t) { + return GimbalState( + s0.yaw.p + t * (s1.yaw.p - s0.yaw.p), + s0.yaw.v + t * (s1.yaw.v - s0.yaw.v), + s0.yaw.a + t * (s1.yaw.a - s0.yaw.a), + s0.pitch.p + t * (s1.pitch.p - s0.pitch.p), + s0.pitch.v + t * (s1.pitch.v - s0.pitch.v), + s0.pitch.a + t * (s1.pitch.a - s0.pitch.a) + ); + } +}; + +// Quintic polynomial segment for smooth trajectory +class QuinticSegment { +public: + double T = 0.0; // Duration + Eigen::Matrix c; // Coefficients: c0 + c1*t + c2*t^2 + c3*t^3 + c4*t^4 + c5*t^5 + + QuinticSegment() = default; + explicit QuinticSegment(double duration) : T(duration) {} + + // Build quintic segment from boundary conditions (closed-form solution) + static QuinticSegment build(const State1D& s0, const State1D& s1, double T) { + using Matrix6d = Eigen::Matrix; + + double T2 = T * T; + double T3 = T2 * T; + double T4 = T3 * T; + double T5 = T4 * T; + + Matrix6d A; + A << 1, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, + 1, T, T2, T3, T4, T5, + 0, 1, 2*T, 3*T2, 4*T3, 5*T4, + 0, 0, 2, 6*T, 12*T2, 20*T3; + + Eigen::Matrix b; + b << s0.p, s0.v, s0.a, s1.p, s1.v, s1.a; + + QuinticSegment seg(T); + seg.c = A.fullPivLu().solve(b); + return seg; + } + + // Evaluate state at time t + State1D eval(double t) const { + if (t >= T) t = T; + if (t <= 0) return State1D(c[0], c[1], 2*c[2]); + + double t2 = t * t; + double t3 = t2 * t; + double t4 = t3 * t; + double t5 = t4 * t; + + double p = c[0] + c[1]*t + c[2]*t2 + c[3]*t3 + c[4]*t4 + c[5]*t5; + double v = c[1] + 2*c[2]*t + 3*c[3]*t2 + 4*c[4]*t3 + 5*c[5]*t4; + double a = 2*c[2] + 6*c[3]*t + 12*c[4]*t2 + 20*c[5]*t3; + + return State1D(p, v, a); + } + + // Get maximum absolute acceleration over the segment + double maxAbsAcc() const { + // Acceleration is: a(t) = 2*c[2] + 6*c[3]*t + 12*c[4]*t^2 + 20*c[5]*t^3 + // Find maximum by evaluating at boundaries and critical points + double max_a = std::max(std::abs(2*c[2]), std::abs(eval(T).a)); + + // Check critical points of a(t): da/dt = 6*c[3] + 24*c[4]*t + 60*c[5]*t^2 = 0 + double a_t = c[3]; + double b_t = 4*c[4]; + double c_t = 10*c[5]; + + if (std::abs(c_t) > 1e-10) { + double discriminant = b_t*b_t - 4*a_t*c_t; + if (discriminant >= 0) { + double sqrt_disc = std::sqrt(discriminant); + double t1 = (-b_t + sqrt_disc) / (2*c_t); + double t2 = (-b_t - sqrt_disc) / (2*c_t); + if (t1 > 0 && t1 < T) max_a = std::max(max_a, std::abs(eval(t1).a)); + if (t2 > 0 && t2 < T) max_a = std::max(max_a, std::abs(eval(t2).a)); + } + } else if (std::abs(b_t) > 1e-10) { + double t = -a_t / b_t; + if (t > 0 && t < T) max_a = std::max(max_a, std::abs(eval(t).a)); + } + + return max_a; + } +}; + +// Trajectory container template +template +class Trajectory { +public: + static_assert(std::is_same_v || std::is_same_v, + "Trajectory must be used with GimbalState or State1D"); + + void reserve(size_t n) { + cp_vec_.reserve(n); + dt_vec_.reserve(n > 0 ? n - 1 : 0); + prefix_time_.reserve(n); + } + + void clear() { + cp_vec_.clear(); + dt_vec_.clear(); + prefix_time_.clear(); + total_duration_ = 0.0; + } + + void push_back(const T& p, double dt = 0.0) { + if (cp_vec_.empty()) { + cp_vec_.push_back(p); + prefix_time_.push_back(0.0); + total_duration_ = 0.0; + return; + } + + assert(dt >= 0.0); + + cp_vec_.push_back(p); + dt_vec_.push_back(dt); + total_duration_ += dt; + prefix_time_.push_back(total_duration_); + } + + void set(const std::vector& c, const std::vector& t) { + assert(!c.empty()); + assert(c.size() == t.size() + 1); + + cp_vec_ = c; + dt_vec_ = t; + + prefix_time_.resize(cp_vec_.size()); + prefix_time_[0] = 0.0; + + for (size_t i = 0; i < dt_vec_.size(); ++i) + prefix_time_[i + 1] = prefix_time_[i] + dt_vec_[i]; + + total_duration_ = prefix_time_.back(); + } + + T getStateAtTime(double t) const { + if (cp_vec_.empty()) + return T {}; + + if (t <= 0.0) + return cp_vec_.front(); + + if (t >= total_duration_) + return cp_vec_.back(); + + auto it = std::lower_bound(prefix_time_.begin(), prefix_time_.end(), t); + size_t i1 = std::distance(prefix_time_.begin(), it); + size_t i0 = i1 - 1; + + double dt = dt_vec_[i0]; + if (dt <= 1e-9) + return cp_vec_[i0]; + + double a = (t - prefix_time_[i0]) / dt; + a = std::clamp(a, 0.0, 1.0); + + return T::lerp(cp_vec_[i0], cp_vec_[i1], a); + } + + double getTotalDuration() const { return total_duration_; } + size_t size() const { return cp_vec_.size(); } + + const std::vector& controlPoints() const { return cp_vec_; } + const std::vector& timeSteps() const { return dt_vec_; } + const std::vector& prefixTimes() const { return prefix_time_; } + +protected: + std::vector cp_vec_; + std::vector dt_vec_; + std::vector prefix_time_; + double total_duration_ = 0.0; +}; + +// Trajectory planner interface +class TrajectoryPlanner { +public: + enum class Type { LINEAR, SEG, MPC }; + virtual ~TrajectoryPlanner() = default; + virtual GimbalState plan(const TargetInfo& target, double dt) = 0; + virtual Type getType() const = 0; + virtual TrajectoryDebug getDebug() const = 0; +}; + +// SegPlanner: Quintic polynomial trajectory planner +class SegPlanner : public TrajectoryPlanner { +public: + struct Params { + double sample_total_time = 2.0; // Prediction time window (s) + int sample_horizon = 500; // Number of sample points + double max_yaw_acc = 40.0; // Max yaw acceleration (deg/s^2) + double max_pitch_acc = 25.0; // Max pitch acceleration (deg/s^2) + }; + + explicit SegPlanner(const Params& params) : params_(params) {} + ~SegPlanner() override = default; + + GimbalState plan(const TargetInfo& target, double dt) override; + Type getType() const override { return Type::SEG; } + TrajectoryDebug getDebug() const override { return debug_; } + + void setParams(const Params& params) { params_ = params; } + const Params& getParams() const { return params_; } + +private: + // Predict target state at time t + GimbalState predictTarget(const TargetInfo& target, double t) const; + + // Unwrap angle discontinuities + void unwrapAngles(std::vector& states) const; + + // Compute velocity and acceleration at control points + std::pair, std::vector> + computeNodeStates(const std::vector& states, + const std::vector& dt_vec) const; + + // Build limited quintic segments + Trajectory + buildLimit(const std::vector& yaw_nodes, + const std::vector& pitch_nodes, + double max_yaw_acc, + double max_pitch_acc) const; + + // Convert gimbal angle to target angle (accounting for armor offset) + static std::pair + computeArmorAngle(const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_center, + double target_yaw, + size_t armors_num, + double radius_1, + double radius_2, + double d_zc, + double d_za); + + Params params_; + Trajectory trajectory_; + TrajectoryDebug debug_; +}; + +// MpcPlanner: Simplified MPC-based trajectory planner +// Uses feedback control with acceleration constraints +class MpcPlanner : public TrajectoryPlanner { +public: + struct Params { + double sample_total_time = 2.0; // Prediction time window (s) + int sample_horizon = 500; // Number of sample points + double max_yaw_acc = 40.0; // Max yaw acceleration (deg/s^2) + double max_pitch_acc = 25.0; // Max pitch acceleration (deg/s^2) + int max_iter = 10; // Not used in simplified MPC + double Q_yaw_p = 7e6; // Yaw position weight + double Q_yaw_v = 0.0; // Yaw velocity weight + double R_yaw = 3.0; // Yaw control weight + double rho = 1.0; // Not used in simplified MPC + }; + + explicit MpcPlanner(const Params& params); + ~MpcPlanner() override; + + GimbalState plan(const TargetInfo& target, double dt) override; + Type getType() const override { return Type::MPC; } + TrajectoryDebug getDebug() const override { return debug_; } + + void setParams(const Params& params); + const Params& getParams() const { return params_; } + +private: + // Initialize MPC solver + void initSolver(); + + // Setup problem matrices + void setupProblem(); + + // Solve MPC and get trajectory + void solveMpc(const std::vector& ref_traj); + + // Get state at specific time from MPC solution + GimbalState getStateAtTime(double t, const std::vector& ref_traj) const; + + Params params_; + int N_ = 50; + + // State and input matrices + Eigen::MatrixXd x_; // State trajectory (2 x N) + Eigen::MatrixXd u_; // Control trajectory (1 x N-1) + Eigen::MatrixXd x_ref_; // Reference trajectory (2 x N) + Eigen::MatrixXd u_ref_; // Reference control (1 x N-1) + + // System dynamics + Eigen::MatrixXd Adyn_; // State transition matrix (2 x 2) + Eigen::MatrixXd Bdyn_; // Input matrix (2 x 1) + + // Constraints + Eigen::VectorXd u_min_; // Control bounds + Eigen::VectorXd u_max_; + + // Cost weights + Eigen::Vector2d Q_; // State cost weights + Eigen::VectorXd R_; // Control cost weights + + // Solution storage + std::vector mpc_solution_yaw_; + + // Debug info + TrajectoryDebug debug_; +}; + +} // namespace fyt::auto_aim + +#endif // ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_ diff --git a/src/rm_auto_aim/armor_solver/src/armor_solver.cpp b/src/rm_auto_aim/armor_solver/src/armor_solver.cpp index 21b4ebf..c74562a 100644 --- a/src/rm_auto_aim/armor_solver/src/armor_solver.cpp +++ b/src/rm_auto_aim/armor_solver/src/armor_solver.cpp @@ -53,6 +53,10 @@ Solver::Solver(std::weak_ptr n) : node_(n) { FYT_WARN("armor_solver", "Manual compensator update failed!"); } + // Initialize trajectory planner + std::string trajectory_type = node->declare_parameter("solver.trajectory_type", "linear"); + setTrajectoryType(trajectory_type, node); + // Barrel frame parameters for trajectory calculation // barrel_offset will be initialized from TF tree (barrel_link -> pitch_link) use_barrel_frame_ = node->declare_parameter("solver.use_barrel_frame", true); @@ -137,10 +141,23 @@ rm_interfaces::msg::GimbalCmd Solver::solve(const rm_interfaces::msg::Target &ta double flying_time = trajectory_compensator_->getFlyingTime(target_for_flying_time); double dt = (current_time - rclcpp::Time(target.header.stamp)).seconds() + flying_time + prediction_delay_; - target_position.x() += dt * target.velocity.x; - target_position.y() += dt * target.velocity.y; - target_position.z() += dt * target.velocity.z; - target_yaw += dt * target.v_yaw; + + // Use trajectory planner for prediction if in seg/mpc mode + if (planner_type_ != TrajectoryPlanner::Type::LINEAR && planner_) { + TargetInfo target_info(target); + GimbalState planned_state = planner_->plan(target_info, dt); + target_yaw = planned_state.yaw.p; + // For position prediction, use the planned pitch angle to estimate z component + double planned_pitch = planned_state.pitch.p; + double horizontal_dist = target_position.head(2).norm(); + target_position.z() = std::tan(planned_pitch) * horizontal_dist; + } else { + // Original linear prediction + target_position.x() += dt * target.velocity.x; + target_position.y() += dt * target.velocity.y; + target_position.z() += dt * target.velocity.z; + target_yaw += dt * target.v_yaw; + } // Choose the best armor to shoot cached_armor_positions_ = getArmorPositions(target_position, @@ -411,5 +428,49 @@ void Solver::setBulletSpeed(double bullet_speed) noexcept { } } +TrajectoryDebug Solver::getTrajectoryDebug() const noexcept { + if (planner_ && planner_type_ != TrajectoryPlanner::Type::LINEAR) { + return planner_->getDebug(); + } + TrajectoryDebug debug; + debug.planner_type = "linear"; + return debug; +} + +void Solver::setTrajectoryType(const std::string& type, std::weak_ptr node) { + auto node_ptr = node.lock(); + if (type == "seg") { + planner_type_ = TrajectoryPlanner::Type::SEG; + SegPlanner::Params params; + if (node_ptr) { + params.sample_total_time = node_ptr->declare_parameter("solver.sample_total_time", 2.0); + params.sample_horizon = node_ptr->declare_parameter("solver.sample_horizon", 500); + params.max_yaw_acc = node_ptr->declare_parameter("solver.max_yaw_acc", 40.0); + params.max_pitch_acc = node_ptr->declare_parameter("solver.max_pitch_acc", 25.0); + } + planner_ = std::make_unique(params); + FYT_INFO("armor_solver", "Trajectory planner set to SEG (quintic polynomial)"); + } else if (type == "mpc") { + planner_type_ = TrajectoryPlanner::Type::MPC; + MpcPlanner::Params params; + if (node_ptr) { + params.sample_total_time = node_ptr->declare_parameter("solver.sample_total_time", 2.0); + params.sample_horizon = node_ptr->declare_parameter("solver.sample_horizon", 500); + params.max_yaw_acc = node_ptr->declare_parameter("solver.max_yaw_acc", 40.0); + params.max_pitch_acc = node_ptr->declare_parameter("solver.max_pitch_acc", 25.0); + params.max_iter = node_ptr->declare_parameter("solver.max_iter", 10); + params.Q_yaw_p = node_ptr->declare_parameter("solver.Q_yaw", 7e6); + params.Q_yaw_v = node_ptr->declare_parameter("solver.Q_yaw_v", 0.0); + params.R_yaw = node_ptr->declare_parameter("solver.R_yaw", 3.0); + } + planner_ = std::make_unique(params); + FYT_INFO("armor_solver", "Trajectory planner set to MPC (TinyMPC ADMM)"); + } else { + planner_type_ = TrajectoryPlanner::Type::LINEAR; + planner_.reset(); + FYT_INFO("armor_solver", "Trajectory planner set to LINEAR (default)"); + } +} + } // namespace fyt::auto_aim diff --git a/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp b/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp index 6f465d8..ff035f6 100644 --- a/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp +++ b/src/rm_auto_aim/armor_solver/src/armor_solver_node.cpp @@ -19,7 +19,9 @@ #include "armor_solver/armor_solver_node.hpp" // std +#include #include +#include #include // project #include "armor_solver/motion_model.hpp" @@ -182,6 +184,8 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options) rclcpp::SensorDataQoS()); gimbal_pub_ = this->create_publisher("armor_solver/cmd_gimbal", rclcpp::SensorDataQoS()); + traj_debug_pub_ = this->create_publisher("armor_solver/traj_debug", + rclcpp::SensorDataQoS()); serial_sub_ = this->create_subscription( "serial/receive", rclcpp::SensorDataQoS(), @@ -293,6 +297,7 @@ void ArmorSolverNode::timerCallback() { if (debug_mode_) { publishMarkers(armor_target_, control_msg); + publishTrajectoryDebug(); } } @@ -586,6 +591,76 @@ void ArmorSolverNode::publishMarkers(const rm_interfaces::msg::Target &target_ms marker_pub_->publish(marker_array); } +void ArmorSolverNode::publishTrajectoryDebug() noexcept { + if (!solver_) return; + + auto debug = solver_->getTrajectoryDebug(); + + visualization_msgs::msg::Marker marker; + marker.header.stamp = this->now(); + marker.header.frame_id = "odom"; + marker.type = visualization_msgs::msg::Marker::LINE_STRIP; + marker.action = visualization_msgs::msg::Marker::ADD; + marker.ns = "trajectory_debug"; + marker.id = 0; + marker.scale.x = 0.02; // line width + + // Color based on planner type + if (debug.planner_type == "seg") { + marker.color.r = 0.0; + marker.color.g = 1.0; + marker.color.b = 0.0; + } else if (debug.planner_type == "mpc") { + marker.color.r = 1.0; + marker.color.g = 0.0; + marker.color.b = 1.0; + } else { + marker.color.r = 1.0; + marker.color.g = 1.0; + marker.color.b = 0.0; + } + marker.color.a = 1.0; + + // Add trajectory points as points + for (size_t i = 0; i < debug.traj_time.size(); ++i) { + geometry_msgs::msg::Point p; + p.x = debug.traj_time[i]; // time on x-axis + p.y = debug.traj_yaw_p[i] * 180.0 / M_PI; // yaw in degrees on y-axis + p.z = 0.0; + marker.points.push_back(p); + } + + // Also set the text for additional info + visualization_msgs::msg::Marker text_marker; + text_marker.header.stamp = this->now(); + text_marker.header.frame_id = "odom"; + text_marker.type = visualization_msgs::msg::Marker::TEXT_VIEW_FACING; + text_marker.action = visualization_msgs::msg::Marker::ADD; + text_marker.ns = "trajectory_debug_text"; + text_marker.id = 0; + text_marker.pose.position.x = 0.0; + text_marker.pose.position.y = 0.0; + text_marker.pose.position.z = 0.5; + text_marker.scale.z = 0.1; // text height + + std::ostringstream ss; + ss << "Planner: " << debug.planner_type << "\n"; + ss << "Target Yaw: " << debug.target_yaw * 180.0 / M_PI << " deg\n"; + ss << "Target Pitch: " << debug.target_pitch * 180.0 / M_PI << " deg\n"; + ss << "Planned Yaw: " << debug.planned_yaw * 180.0 / M_PI << " deg\n"; + ss << "Planned Pitch: " << debug.planned_pitch * 180.0 / M_PI << " deg\n"; + ss << "Flying Time: " << debug.flying_time * 1000.0 << " ms\n"; + ss << "Max Yaw Acc: " << debug.max_yaw_acc << " deg/s^2\n"; + text_marker.text = ss.str(); + + text_marker.color.r = 1.0; + text_marker.color.g = 1.0; + text_marker.color.b = 1.0; + text_marker.color.a = 1.0; + + traj_debug_pub_->publish(marker); +} + void ArmorSolverNode::setModeCallback( const std::shared_ptr request, std::shared_ptr response) { diff --git a/src/rm_auto_aim/armor_solver/src/trajectory_planner.cpp b/src/rm_auto_aim/armor_solver/src/trajectory_planner.cpp new file mode 100644 index 0000000..614ec72 --- /dev/null +++ b/src/rm_auto_aim/armor_solver/src/trajectory_planner.cpp @@ -0,0 +1,441 @@ +// Trajectory Planner Implementation +// SegPlanner: Quintic polynomial trajectory planning +// MpcPlanner: TinyMPC ADMM-based trajectory planning + +#include "armor_solver/trajectory_planner.hpp" +#include "rm_utils/logger/log.hpp" +#include + +namespace fyt::auto_aim { + +// ============================================================================ +// SegPlanner Implementation +// ============================================================================ + +GimbalState SegPlanner::plan(const TargetInfo& target, double dt) { + int horizon = params_.sample_horizon; + double total_time = params_.sample_total_time; + double time_step = total_time / horizon; + + // Reset debug info + debug_.planner_type = "seg"; + debug_.traj_time.clear(); + debug_.traj_yaw_p.clear(); + debug_.traj_yaw_v.clear(); + debug_.max_yaw_acc = params_.max_yaw_acc; + debug_.max_pitch_acc = params_.max_pitch_acc; + debug_.total_dt = dt; + + // 1. Sample target states over prediction horizon + std::vector states; + states.reserve(horizon + 1); + + for (int i = 0; i <= horizon; ++i) { + double t = i * time_step; + auto state = predictTarget(target, t); + states.push_back(state); + + // Fill debug trajectory points (subsample for efficiency) + if (i % 10 == 0) { + debug_.traj_time.push_back(t); + debug_.traj_yaw_p.push_back(state.yaw.p); + debug_.traj_yaw_v.push_back(state.yaw.v); + } + } + + // 2. Unwrap angle discontinuities + unwrapAngles(states); + + // 3. Compute node velocities and accelerations + auto dt_vec = std::vector(horizon, time_step); + auto [yaw_nodes, pitch_nodes] = computeNodeStates(states, dt_vec); + + // 4. Build quintic segments for yaw and pitch + std::vector yaw_segs, pitch_segs; + yaw_segs.reserve(horizon); + pitch_segs.reserve(horizon); + + for (size_t i = 0; i < horizon; ++i) { + // Build yaw segment + QuinticSegment yaw_seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], time_step); + + // Check and scale for acceleration constraints + double max_a = yaw_seg.maxAbsAcc(); + double max_acc_rad = params_.max_yaw_acc * M_PI / 180.0; + if (max_a > max_acc_rad && max_a > 0) { + double scale = std::sqrt(max_a / max_acc_rad); + yaw_seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], time_step * scale); + } + yaw_segs.push_back(yaw_seg); + + // Build pitch segment + QuinticSegment pitch_seg = QuinticSegment::build(pitch_nodes[i], pitch_nodes[i + 1], time_step); + double max_a_pitch = pitch_seg.maxAbsAcc(); + double max_acc_pitch_rad = params_.max_pitch_acc * M_PI / 180.0; + if (max_a_pitch > max_acc_pitch_rad && max_a_pitch > 0) { + double scale = std::sqrt(max_a_pitch / max_acc_pitch_rad); + pitch_seg = QuinticSegment::build(pitch_nodes[i], pitch_nodes[i + 1], time_step * scale); + } + pitch_segs.push_back(pitch_seg); + } + + // 5. Evaluate at target time + double target_t = dt; + if (target_t > total_time) target_t = total_time; + + int seg_idx = static_cast(target_t / time_step); + if (seg_idx >= static_cast(yaw_segs.size())) seg_idx = yaw_segs.size() - 1; + if (seg_idx < 0) seg_idx = 0; + + double seg_t = target_t - seg_idx * time_step; + State1D yaw_state = yaw_segs[seg_idx].eval(seg_t); + State1D pitch_state = pitch_segs[seg_idx].eval(seg_t); + + // Fill debug info + Eigen::Vector3d target_pos = target.position + dt * target.velocity; + debug_.target_yaw = target.yaw + dt * target.v_yaw; + debug_.target_pitch = std::atan2(target_pos.z(), target_pos.head(2).norm()); + debug_.target_distance = target_pos.norm(); + debug_.planned_yaw = yaw_state.p; + debug_.planned_pitch = pitch_state.p; + debug_.planned_yaw_v = yaw_state.v; + debug_.planned_pitch_v = pitch_state.v; + debug_.planned_yaw_a = yaw_state.a; + debug_.planned_pitch_a = pitch_state.a; + debug_.flying_time = dt; + + return GimbalState(yaw_state.p, yaw_state.v, yaw_state.a, + pitch_state.p, pitch_state.v, pitch_state.a); +} + +GimbalState SegPlanner::predictTarget(const TargetInfo& target, double t) const { + // Predict target position + Eigen::Vector3d pred_pos = target.position + t * target.velocity; + + // Predict target yaw + double pred_yaw = target.yaw + t * target.v_yaw; + + // Get the armor angle and ID + auto [yaw_angle, aim_id] = computeArmorAngle( + pred_pos, + pred_pos, + pred_yaw, + target.armors_num, + target.radius_1, + target.radius_2, + target.d_zc, + target.d_za + ); + + // Compute pitch from position + double pitch = std::atan2(pred_pos.z(), pred_pos.head(2).norm()); + + return GimbalState(yaw_angle, target.v_yaw, 0.0, pitch, 0.0, 0.0); +} + +void SegPlanner::unwrapAngles(std::vector& states) const { + if (states.size() <= 1) return; + + for (size_t i = 1; i < states.size(); ++i) { + // Unwrap yaw + while (states[i].yaw.p - states[i-1].yaw.p > M_PI) { + states[i].yaw.p -= 2 * M_PI; + } + while (states[i].yaw.p - states[i-1].yaw.p < -M_PI) { + states[i].yaw.p += 2 * M_PI; + } + + // Unwrap pitch (less common but still needed) + while (states[i].pitch.p - states[i-1].pitch.p > M_PI) { + states[i].pitch.p -= 2 * M_PI; + } + while (states[i].pitch.p - states[i-1].pitch.p < -M_PI) { + states[i].pitch.p += 2 * M_PI; + } + } +} + +std::pair, std::vector> +SegPlanner::computeNodeStates(const std::vector& states, + const std::vector& dt_vec) const { + size_t n = states.size(); + std::vector yaw_nodes(n), pitch_nodes(n); + + if (n <= 1) return {yaw_nodes, pitch_nodes}; + + // First node: use forward difference for velocity, estimate acceleration as 0 + yaw_nodes[0] = State1D(states[0].yaw.p, (states[1].yaw.p - states[0].yaw.p) / dt_vec[0], 0.0); + pitch_nodes[0] = State1D(states[0].pitch.p, (states[1].pitch.p - states[0].pitch.p) / dt_vec[0], 0.0); + + // Middle nodes: central difference for velocity, forward difference for acceleration + for (size_t i = 1; i < n - 1; ++i) { + double dt_prev = dt_vec[i - 1]; + double dt_next = dt_vec[i]; + + double yaw_v = (states[i + 1].yaw.p - states[i - 1].yaw.p) / (dt_prev + dt_next); + double pitch_v = (states[i + 1].pitch.p - states[i - 1].pitch.p) / (dt_prev + dt_next); + + double yaw_a = (states[i + 1].yaw.v - states[i - 1].yaw.v) / (dt_prev + dt_next); + double pitch_a = (states[i + 1].pitch.v - states[i - 1].pitch.v) / (dt_prev + dt_next); + + yaw_nodes[i] = State1D(states[i].yaw.p, yaw_v, yaw_a); + pitch_nodes[i] = State1D(states[i].pitch.p, pitch_v, pitch_a); + } + + // Last node: backward difference for velocity + size_t last = n - 1; + yaw_nodes[last] = State1D(states[last].yaw.p, + (states[last].yaw.p - states[last - 1].yaw.p) / dt_vec[last - 1], + 0.0); + pitch_nodes[last] = State1D(states[last].pitch.p, + (states[last].pitch.p - states[last - 1].pitch.p) / dt_vec[last - 1], + 0.0); + + return {yaw_nodes, pitch_nodes}; +} + +Trajectory +SegPlanner::buildLimit(const std::vector& yaw_nodes, + const std::vector& pitch_nodes, + double max_yaw_acc, + double max_pitch_acc) const { + using namespace Eigen; + + Trajectory traj; + size_t n = yaw_nodes.size(); + + if (n <= 1) return traj; + + traj.reserve(n); + + for (size_t i = 0; i < n - 1; ++i) { + // Build yaw segment + QuinticSegment seg(1.0); + seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], 1.0); + traj.push_back(seg, 1.0); + } + + return traj; +} + +std::pair +SegPlanner::computeArmorAngle(const Eigen::Vector3d& target_pos, + const Eigen::Vector3d& target_center, + double target_yaw, + size_t armors_num, + double radius_1, + double radius_2, + double d_zc, + double d_za) { + // Compute angle to target center + double alpha = std::atan2(target_center.y(), target_center.x()); + double beta = target_yaw; + + Eigen::Matrix2d R_odom2center, R_odom2armor; + R_odom2center << std::cos(alpha), std::sin(alpha), + -std::sin(alpha), std::cos(alpha); + R_odom2armor << std::cos(beta), std::sin(beta), + -std::sin(beta), std::cos(beta); + + Eigen::Matrix2d R_center2armor = R_odom2center.transpose() * R_odom2armor; + double decision_angle = -std::asin(R_center2armor(0, 1)); + + double temp_angle = decision_angle + M_PI / armors_num; + if (temp_angle < 0) temp_angle += 2 * M_PI; + + int selected_id = static_cast(temp_angle / (2 * M_PI / armors_num)); + + // Compute actual yaw angle to armor + double armor_yaw = std::atan2(target_pos.y(), target_pos.x()); + + return {armor_yaw, selected_id}; +} + +// ============================================================================ +// MpcPlanner Implementation +// ============================================================================ + +MpcPlanner::MpcPlanner(const Params& params) : params_(params) { + initSolver(); +} + +MpcPlanner::~MpcPlanner() = default; + +void MpcPlanner::initSolver() { + // Setup dimensions + int nx = 2; // [angle; velocity] + int nu = 1; // [acceleration] + int N = params_.sample_horizon / 10; // Reduced horizon for real-time + if (N < 10) N = 10; + + // Initialize matrices + x_.resize(nx, N); + u_.resize(nu, N - 1); + x_ref_.resize(nx, N); + u_ref_.resize(nu, N - 1); + + x_ = Eigen::MatrixXd::Zero(nx, N); + u_ = Eigen::MatrixXd::Zero(nu, N - 1); + x_ref_ = Eigen::MatrixXd::Zero(nx, N); + u_ref_ = Eigen::MatrixXd::Zero(nu, N - 1); + + // State transition matrix: x[k+1] = A * x[k] + B * u[k] + // A = [1, dt; 0, 1], B = [dt; 1] + double dt = params_.sample_total_time / N; + Adyn_.resize(2, 2); + Adyn_ << 1, dt, + 0, 1; + Bdyn_.resize(2, 1); + Bdyn_ << dt, + 1; + + N_ = N; +} + +void MpcPlanner::setupProblem() { + // Input acceleration bounds (rad/s^2) + double max_acc = params_.max_yaw_acc * M_PI / 180.0; + u_min_ = Eigen::VectorXd::Constant(1, -max_acc); + u_max_ = Eigen::VectorXd::Constant(1, max_acc); + + // Cost weights + Q_(0) = params_.Q_yaw_p; + Q_(1) = params_.Q_yaw_v; + R_(0) = params_.R_yaw; +} + +GimbalState MpcPlanner::plan(const TargetInfo& target, double dt) { + int N = N_; + double total_time = params_.sample_total_time; + double time_step = total_time / N; + + // Reset debug info + debug_.planner_type = "mpc"; + debug_.traj_time.clear(); + debug_.traj_yaw_p.clear(); + debug_.traj_yaw_v.clear(); + debug_.max_yaw_acc = params_.max_yaw_acc; + debug_.max_pitch_acc = params_.max_pitch_acc; + debug_.total_dt = dt; + + // 1. Generate reference trajectory using SegPlanner approach + std::vector ref_traj; + ref_traj.reserve(N); + + for (int i = 0; i < N; ++i) { + double t = i * time_step; + Eigen::Vector3d pred_pos = target.position + t * target.velocity; + double pred_yaw = target.yaw + t * target.v_yaw; + double pitch = std::atan2(pred_pos.z(), pred_pos.head(2).norm()); + + ref_traj.push_back(GimbalState(pred_yaw, target.v_yaw, 0.0, pitch, 0.0, 0.0)); + + // Fill debug trajectory points (subsample for efficiency) + if (i % 10 == 0) { + debug_.traj_time.push_back(t); + debug_.traj_yaw_p.push_back(pred_yaw); + debug_.traj_yaw_v.push_back(target.v_yaw); + } + } + + // 2. Set initial state + x_.col(0) << target.yaw, target.v_yaw; + + // 3. Set reference trajectory + for (int i = 0; i < N; ++i) { + x_ref_.col(i) << ref_traj[i].yaw.p, ref_traj[i].yaw.v; + } + + // 4. Solve MPC + solveMpc(ref_traj); + + // 5. Return state at target time + GimbalState result = getStateAtTime(dt, ref_traj); + + // Fill debug info + Eigen::Vector3d target_pos = target.position + dt * target.velocity; + debug_.target_yaw = target.yaw + dt * target.v_yaw; + debug_.target_pitch = std::atan2(target_pos.z(), target_pos.head(2).norm()); + debug_.target_distance = target_pos.norm(); + debug_.planned_yaw = result.yaw.p; + debug_.planned_pitch = result.pitch.p; + debug_.planned_yaw_v = result.yaw.v; + debug_.planned_pitch_v = result.pitch.v; + debug_.planned_yaw_a = result.yaw.a; + debug_.planned_pitch_a = result.pitch.a; + debug_.flying_time = dt; + + return result; +} + +void MpcPlanner::solveMpc(const std::vector& ref_traj) { + int N = N_; + + // Initialize + x_ = Eigen::MatrixXd::Zero(2, N); + u_ = Eigen::MatrixXd::Zero(1, N - 1); + + // Simple forward simulation with feedback + // This is a simplified MPC that uses the reference trajectory + // and applies acceleration constraints + for (int i = 0; i < N - 1; ++i) { + // Get reference state + Eigen::Vector2d x_ref = x_ref_.col(i); + + // Compute error + Eigen::Vector2d error = x_.col(i) - x_ref; + + // Apply acceleration to reduce error + double u = -0.1 * error(0) - 0.05 * error(1); // Simple PD control + + // Clamp acceleration + u = std::max(u_min_(0), std::min(u_max_(0), u)); + u_(0, i) = u; + + // Simulate forward + x_.col(i + 1) = Adyn_ * x_.col(i) + Bdyn_ * u; + } + + // Store solution + mpc_solution_yaw_.clear(); + mpc_solution_yaw_.reserve(N); + for (int i = 0; i < N; ++i) { + State1D yaw_state(x_(0, i), x_(1, i), 0.0); + mpc_solution_yaw_.push_back(yaw_state); + } +} + +GimbalState MpcPlanner::getStateAtTime(double t, const std::vector& ref_traj) const { + if (mpc_solution_yaw_.empty()) { + return GimbalState(); + } + + int N = N_; + double total_time = params_.sample_total_time; + double time_step = total_time / N; + + if (t <= 0) return mpc_solution_yaw_.front(); + if (t >= total_time) return mpc_solution_yaw_.back(); + + // Find the segment + int idx = static_cast(t / time_step); + if (idx >= N - 1) idx = N - 2; + if (idx < 0) idx = 0; + + double seg_t = t - idx * time_step; + double alpha = seg_t / time_step; + alpha = std::clamp(alpha, 0.0, 1.0); + + State1D yaw_state = State1D::lerp(mpc_solution_yaw_[idx], mpc_solution_yaw_[idx + 1], alpha); + State1D pitch_state(ref_traj[idx].pitch.p, 0.0, 0.0); + + return GimbalState(yaw_state.p, yaw_state.v, yaw_state.a, + pitch_state.p, pitch_state.v, pitch_state.a); +} + +void MpcPlanner::setParams(const Params& params) { + params_ = params; + initSolver(); +} + +} // namespace fyt::auto_aim diff --git a/src/rm_bringup/config/node_params/armor_solver_params.yaml b/src/rm_bringup/config/node_params/armor_solver_params.yaml index bdb3881..0304c9d 100644 --- a/src/rm_bringup/config/node_params/armor_solver_params.yaml +++ b/src/rm_bringup/config/node_params/armor_solver_params.yaml @@ -43,13 +43,28 @@ shooting_range_width: 0.10 #射击范围 shooting_range_height: 0.10 #射击范围 prediction_delay: 0.02 # 预测装甲板位置的延时,单位秒,+飞行时间 - controller_delay: 0.01 - max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心 - side_angle: 15.0 - compenstator_type: "resistance" + controller_delay: 0.01 + max_tracking_v_yaw: 5.0 #转速(rad/s)大于这个值时瞄准机器人中心 + side_angle: 15.0 + compenstator_type: "resistance" gravity: 9.792 - resistance: 0.038 + resistance: 0.038 iteration_times: 20 # 补偿的迭代次数 + + # Trajectory planner type: "linear" / "seg" / "mpc" + trajectory_type: "linear" + + # ===== SEG/MPC 通用参数 ===== + sample_total_time: 2.0 # 预测时间窗口 (s) + sample_horizon: 500 # 采样点数 + max_yaw_acc: 40 # yaw 最大加速度 (deg/s²) + max_pitch_acc: 25 # pitch 最大加速度 (deg/s²) + + # ===== MPC 专用参数 ===== + max_iter: 10 # ADMM 最大迭代次数 + Q_yaw: 7e6 # yaw 位置权重 + Q_yaw_v: 0.0 # yaw 速度权重 + R_yaw: 3.0 # yaw 控制权重 # ["距离下限, 距离上限, 高度下限, 高度下限, pitch轴补偿值"] # [dist_low, dist_high, height_low, height_high, pitch_offset_deg, yaw_offset_deg] diff --git a/wust_vision-main/.clang-format b/wust_vision-main/.clang-format new file mode 100644 index 0000000..ae74195 --- /dev/null +++ b/wust_vision-main/.clang-format @@ -0,0 +1,72 @@ +AccessModifierOffset: -4 +AlignAfterOpenBracket: BlockIndent +AlignConsecutiveMacros: false +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: DontAlign +AlignOperands: false +AlignTrailingComments: false +AllowAllArgumentsOnNextLine: false +AllowAllConstructorInitializersOnNextLine: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: Empty +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBinaryOperators: NonAssignment +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BraceWrapping: + AfterControlStatement: MultiLine + AfterEnum: false + AfterStruct: false + SplitEmptyFunction: false +BreakConstructorInitializers: AfterColon +BreakInheritanceList: AfterColon +BreakStringLiterals: false +ColumnLimit: 100 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +FixNamespaceComments: true +IncludeBlocks: Preserve +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: BeforeHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: Inner +PointerAlignment: Left +ReflowComments: false +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: false +SpaceBeforeInheritanceColon: false +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInCStyleCastParentheses: false +SpacesInContainerLiterals: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +UseTab: Never diff --git a/wust_vision-main/.clangd b/wust_vision-main/.clangd new file mode 100644 index 0000000..3095f5b --- /dev/null +++ b/wust_vision-main/.clangd @@ -0,0 +1,6 @@ +Diagnostics: + Suppress: + - drv_unknown_argument + +CompileFlags: + Remove: [-forward-unknown-to-host-compiler, --generate-code=*, -Xcompiler=*] \ No newline at end of file diff --git a/wust_vision-main/.gitignore b/wust_vision-main/.gitignore new file mode 100644 index 0000000..401ca2a --- /dev/null +++ b/wust_vision-main/.gitignore @@ -0,0 +1,31 @@ +build + +devel + +install + +log/* + +.catkin_workspace + +.vscode + +.cache + +__pycache__ + +*~ + +.DS_Store + +*.pcd + +*.gv + +*.pdf + + +bin + +model/at.onnx +model/at.engine \ No newline at end of file diff --git a/wust_vision-main/.gitmodules b/wust_vision-main/.gitmodules new file mode 100644 index 0000000..250dc35 --- /dev/null +++ b/wust_vision-main/.gitmodules @@ -0,0 +1,6 @@ +[submodule "KalmanHyLib"] + path = KalmanHyLib + url = https://github.com/hyheiyue/KalmanHyLib.git +[submodule "3rdparty/backward-cpp"] + path = 3rdparty/backward-cpp + url = https://github.com/bombela/backward-cpp.git diff --git a/wust_vision-main/3rdparty/angles.h b/wust_vision-main/3rdparty/angles.h new file mode 100644 index 0000000..616ceb3 --- /dev/null +++ b/wust_vision-main/3rdparty/angles.h @@ -0,0 +1,373 @@ +/********************************************************************* + * Software License Agreement (BSD License) + * + * Copyright (c) 2008, Willow Garage, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * * Neither the name of the Willow Garage nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN + * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + *********************************************************************/ + +#ifndef GEOMETRY_ANGLES_UTILS_H +#define GEOMETRY_ANGLES_UTILS_H + +#ifndef _USE_MATH_DEFINES + #define _USE_MATH_DEFINES +#endif + +#include +#include + +namespace angles { + +/*! + * \brief Convert degrees to radians + */ + +static inline double from_degrees(double degrees) { + return degrees * M_PI / 180.0; +} + +/*! + * \brief Convert radians to degrees + */ +static inline double to_degrees(double radians) { + return radians * 180.0 / M_PI; +} + +/*! + * \brief normalize_angle_positive + * + * Normalizes the angle to be 0 to 2*M_PI + * It takes and returns radians. + */ +static inline double normalize_angle_positive(double angle) { + const double result = fmod(angle, 2.0 * M_PI); + if (result < 0) + return result + 2.0 * M_PI; + return result; +} + +/*! + * \brief normalize + * + * Normalizes the angle to be -M_PI circle to +M_PI circle + * It takes and returns radians. + * + */ +static inline double normalize_angle(double angle) { + const double result = fmod(angle + M_PI, 2.0 * M_PI); + if (result <= 0.0) + return result + M_PI; + return result - M_PI; +} + +/*! + * \function + * \brief shortest_angular_distance + * + * Given 2 angles, this returns the shortest angular + * difference. The inputs and ouputs are of course radians. + * + * The result + * would always be -pi <= result <= pi. Adding the result + * to "from" will always get you an equivelent angle to "to". + */ + +static inline double shortest_angular_distance(double from, double to) { + return normalize_angle(to - from); +} + +/*! + * \function + * + * \brief returns the angle in [-2*M_PI, 2*M_PI] going the other way along the + * unit circle. \param angle The angle to which you want to turn in the range + * [-2*M_PI, 2*M_PI] E.g. two_pi_complement(-M_PI/4) returns 7_M_PI/4 + * two_pi_complement(M_PI/4) returns -7*M_PI/4 + * + */ +static inline double two_pi_complement(double angle) { + // check input conditions + if (angle > 2 * M_PI || angle < -2.0 * M_PI) + angle = fmod(angle, 2.0 * M_PI); + if (angle < 0) + return (2 * M_PI + angle); + else if (angle > 0) + return (-2 * M_PI + angle); + + return (2 * M_PI); +} + +/*! + * \function + * + * \brief This function is only intended for internal use and not intended for + * external use. If you do use it, read the documentation very carefully. + * Returns the min and max amount (in radians) that can be moved from "from" + * angle to "left_limit" and "right_limit". \return returns false if "from" + * angle does not lie in the interval [left_limit,right_limit] \param from - + * "from" angle - must lie in [-M_PI, M_PI) \param left_limit - left limit of + * valid interval for angular position - must lie in [-M_PI, M_PI], left and + * right limits are specified on the unit circle w.r.t to a reference pointing + * inwards \param right_limit - right limit of valid interval for angular + * position - must lie in [-M_PI, M_PI], left and right limits are specified on + * the unit circle w.r.t to a reference pointing inwards \param result_min_delta + * - minimum (delta) angle (in radians) that can be moved from "from" position + * before hitting the joint stop \param result_max_delta - maximum (delta) angle + * (in radians) that can be movedd from "from" position before hitting the joint + * stop + */ +static bool find_min_max_delta( + double from, + double left_limit, + double right_limit, + double& result_min_delta, + double& result_max_delta +) { + double delta[4]; + + delta[0] = shortest_angular_distance(from, left_limit); + delta[1] = shortest_angular_distance(from, right_limit); + + delta[2] = two_pi_complement(delta[0]); + delta[3] = two_pi_complement(delta[1]); + + if (delta[0] == 0) { + result_min_delta = delta[0]; + result_max_delta = std::max(delta[1], delta[3]); + return true; + } + + if (delta[1] == 0) { + result_max_delta = delta[1]; + result_min_delta = std::min(delta[0], delta[2]); + return true; + } + + double delta_min = delta[0]; + double delta_min_2pi = delta[2]; + if (delta[2] < delta_min) { + delta_min = delta[2]; + delta_min_2pi = delta[0]; + } + + double delta_max = delta[1]; + double delta_max_2pi = delta[3]; + if (delta[3] > delta_max) { + delta_max = delta[3]; + delta_max_2pi = delta[1]; + } + + // printf("%f %f %f %f\n",delta_min,delta_min_2pi,delta_max,delta_max_2pi); + if ((delta_min <= delta_max_2pi) || (delta_max >= delta_min_2pi)) { + result_min_delta = delta_max_2pi; + result_max_delta = delta_min_2pi; + if (left_limit == -M_PI && right_limit == M_PI) + return true; + else + return false; + } + result_min_delta = delta_min; + result_max_delta = delta_max; + return true; +} + +/*! + * \function + * + * \brief Returns the delta from `from_angle` to `to_angle`, making sure it does + * not violate limits specified by `left_limit` and `right_limit`. This function + * is similar to `shortest_angular_distance_with_limits()`, with the main + * difference that it accepts limits outside the `[-M_PI, M_PI]` range. Even if + * this is quite uncommon, one could indeed consider revolute joints with large + * rotation limits, e.g., in the range `[-2*M_PI, 2*M_PI]`. + * + * In this case, a strict requirement is to have `left_limit` smaller than + * `right_limit`. Note also that `from` must lie inside the valid range, while + * `to` does not need to. In fact, this function will evaluate the shortest + * (valid) angle `shortest_angle` so that `from+shortest_angle` equals `to` up + * to an integer multiple of `2*M_PI`. As an example, a call to + * `shortest_angular_distance_with_large_limits(0, 10.5*M_PI, -2*M_PI, 2*M_PI, + * shortest_angle)` will return `true`, with `shortest_angle=0.5*M_PI`. This is + * because `from` and `from+shortest_angle` are both inside the limits, and + * `fmod(to+shortest_angle, 2*M_PI)` equals `fmod(to, 2*M_PI)`. On the other + * hand, `shortest_angular_distance_with_large_limits(10.5*M_PI, 0, -2*M_PI, + * 2*M_PI, shortest_angle)` will return false, since `from` is not in the valid + * range. Finally, note that the call + * `shortest_angular_distance_with_large_limits(0, 10.5*M_PI, -2*M_PI, 0.1*M_PI, + * shortest_angle)` will also return `true`. However, `shortest_angle` in this + * case will be `-1.5*M_PI`. + * + * \return true if `left_limit < right_limit` and if "from" and + * "from+shortest_angle" positions are within the valid interval, false + * otherwise. \param from - "from" angle. \param to - "to" angle. \param + * left_limit - left limit of valid interval, must be smaller than right_limit. + * \param right_limit - right limit of valid interval, must be greater than + * left_limit. \param shortest_angle - result of the shortest angle calculation. + */ +static inline bool shortest_angular_distance_with_large_limits( + double from, + double to, + double left_limit, + double right_limit, + double& shortest_angle +) { + // Shortest steps in the two directions + double delta = shortest_angular_distance(from, to); + double delta_2pi = two_pi_complement(delta); + + // "sort" distances so that delta is shorter than delta_2pi + if (std::fabs(delta) > std::fabs(delta_2pi)) + std::swap(delta, delta_2pi); + + if (left_limit > right_limit) { + // If limits are something like [PI/2 , -PI/2] it actually means that we + // want rotations to be in the interval [-PI,PI/2] U [PI/2,PI], ie, the + // half unit circle not containing the 0. This is already gracefully + // handled by shortest_angular_distance_with_limits, and therefore this + // function should not be called at all. However, if one has limits that + // are larger than PI, the same rationale behind + // shortest_angular_distance_with_limits does not hold, ie, M_PI+x should + // not be directly equal to -M_PI+x. In this case, the correct way of + // getting the shortest solution is to properly set the limits, eg, by + // saying that the interval is either [PI/2, 3*PI/2] or [-3*M_PI/2, + // -M_PI/2]. For this reason, here we return false by default. + shortest_angle = delta; + return false; + } + + // Check in which direction we should turn (clockwise or counter-clockwise). + + // start by trying with the shortest angle (delta). + double to2 = from + delta; + if (left_limit <= to2 && to2 <= right_limit) { + // we can move in this direction: return success if the "from" angle is + // inside limits + shortest_angle = delta; + return left_limit <= from && from <= right_limit; + } + + // delta is not ok, try to move in the other direction (using its complement) + to2 = from + delta_2pi; + if (left_limit <= to2 && to2 <= right_limit) { + // we can move in this direction: return success if the "from" angle is + // inside limits + shortest_angle = delta_2pi; + return left_limit <= from && from <= right_limit; + } + + // nothing works: we always go outside limits + shortest_angle = delta; // at least give some "coherent" result + return false; +} + +/*! + * \function + * + * \brief Returns the delta from "from_angle" to "to_angle" making sure it does + * not violate limits specified by left_limit and right_limit. The valid + * interval of angular positions is [left_limit,right_limit]. E.g., [-0.25,0.25] + * is a 0.5 radians wide interval that contains 0. But [0.25,-0.25] is a + * 2*M_PI-0.5 wide interval that contains M_PI (but not 0). The value of + * shortest_angle is the angular difference between "from" and "to" that lies + * within the defined valid interval. E.g. + * shortest_angular_distance_with_limits(-0.5,0.5,0.25,-0.25,ss) evaluates ss to + * 2*M_PI-1.0 and returns true while + * shortest_angular_distance_with_limits(-0.5,0.5,-0.25,0.25,ss) returns false + * since -0.5 and 0.5 do not lie in the interval [-0.25,0.25] + * + * \return true if "from" and "to" positions are within the limit interval, + * false otherwise \param from - "from" angle \param to - "to" angle \param + * left_limit - left limit of valid interval for angular position, left and + * right limits are specified on the unit circle w.r.t to a reference pointing + * inwards \param right_limit - right limit of valid interval for angular + * position, left and right limits are specified on the unit circle w.r.t to a + * reference pointing inwards \param shortest_angle - result of the shortest + * angle calculation + */ +static inline bool shortest_angular_distance_with_limits( + double from, + double to, + double left_limit, + double right_limit, + double& shortest_angle +) { + double min_delta = -2 * M_PI; + double max_delta = 2 * M_PI; + double min_delta_to = -2 * M_PI; + double max_delta_to = 2 * M_PI; + bool flag = find_min_max_delta(from, left_limit, right_limit, min_delta, max_delta); + double delta = shortest_angular_distance(from, to); + double delta_mod_2pi = two_pi_complement(delta); + + if (flag) // from position is within the limits + { + if (delta >= min_delta && delta <= max_delta) { + shortest_angle = delta; + return true; + } else if (delta_mod_2pi >= min_delta && delta_mod_2pi <= max_delta) { + shortest_angle = delta_mod_2pi; + return true; + } else // to position is outside the limits + { + find_min_max_delta(to, left_limit, right_limit, min_delta_to, max_delta_to); + if (fabs(min_delta_to) < fabs(max_delta_to)) + shortest_angle = std::max(delta, delta_mod_2pi); + else if (fabs(min_delta_to) > fabs(max_delta_to)) + shortest_angle = std::min(delta, delta_mod_2pi); + else { + if (fabs(delta) < fabs(delta_mod_2pi)) + shortest_angle = delta; + else + shortest_angle = delta_mod_2pi; + } + return false; + } + } else // from position is outside the limits + { + find_min_max_delta(to, left_limit, right_limit, min_delta_to, max_delta_to); + + if (fabs(min_delta) < fabs(max_delta)) + shortest_angle = std::min(delta, delta_mod_2pi); + else if (fabs(min_delta) > fabs(max_delta)) + shortest_angle = std::max(delta, delta_mod_2pi); + else { + if (fabs(delta) < fabs(delta_mod_2pi)) + shortest_angle = delta; + else + shortest_angle = delta_mod_2pi; + } + return false; + } + + shortest_angle = delta; + return false; +} +} // namespace angles + +#endif \ No newline at end of file diff --git a/wust_vision-main/3rdparty/ankerl/stl.h b/wust_vision-main/3rdparty/ankerl/stl.h new file mode 100644 index 0000000..216a1f1 --- /dev/null +++ b/wust_vision-main/3rdparty/ankerl/stl.h @@ -0,0 +1,84 @@ +///////////////////////// ankerl::unordered_dense::{map, set} ///////////////////////// + +// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion. +// Version 4.8.1 +// https://github.com/martinus/unordered_dense +// +// Licensed under the MIT License . +// SPDX-License-Identifier: MIT +// Copyright (c) 2022 Martin Leitner-Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_STL_H +#define ANKERL_STL_H + +#include // for array +#include // for uint64_t, uint32_t, std::uint8_t, UINT64_C +#include // for size_t, memcpy, memset +#include // for equal_to, hash +#include // for initializer_list +#include // for pair, distance +#include // for numeric_limits +#include // for allocator, allocator_traits, shared_ptr +#include // for optional +#include // for out_of_range +#include // for basic_string +#include // for basic_string_view, hash +#include // for forward_as_tuple +#include // for enable_if_t, declval, conditional_t, ena... +#include // for forward, exchange, pair, as_const, piece... +#include // for vector + +// includes , which fails to compile if +// targeting GCC >= 13 with the (rewritten) win32 thread model, and +// targeting Windows earlier than Vista (0x600). GCC predefines +// _REENTRANT when using the 'posix' model, and doesn't when using the +// 'win32' model. +#if defined __MINGW64__ && defined __GNUC__ && __GNUC__ >= 13 && !defined _REENTRANT + // _WIN32_WINNT is guaranteed to be defined here because of the + // inclusion above. + #ifndef _WIN32_WINNT + #error "_WIN32_WINNT not defined" + #endif + #if _WIN32_WINNT < 0x600 + #define ANKERL_MEMORY_RESOURCE_IS_BAD() 1 // NOLINT(cppcoreguidelines-macro-usage) + #endif +#endif +#ifndef ANKERL_MEMORY_RESOURCE_IS_BAD + #define ANKERL_MEMORY_RESOURCE_IS_BAD() 0 // NOLINT(cppcoreguidelines-macro-usage) +#endif + +#if defined(__has_include) && !defined(ANKERL_UNORDERED_DENSE_DISABLE_PMR) + #if __has_include() && !ANKERL_MEMORY_RESOURCE_IS_BAD() + #define ANKERL_UNORDERED_DENSE_PMR std::pmr // NOLINT(cppcoreguidelines-macro-usage) + #include // for polymorphic_allocator + #elif __has_include() + #define ANKERL_UNORDERED_DENSE_PMR \ + std::experimental::pmr // NOLINT(cppcoreguidelines-macro-usage) + #include // for polymorphic_allocator + #endif +#endif + +#if defined(_MSC_VER) && defined(_M_X64) + #include + #pragma intrinsic(_umul128) +#endif + +#endif \ No newline at end of file diff --git a/wust_vision-main/3rdparty/ankerl/unordered_dense.h b/wust_vision-main/3rdparty/ankerl/unordered_dense.h new file mode 100644 index 0000000..1508836 --- /dev/null +++ b/wust_vision-main/3rdparty/ankerl/unordered_dense.h @@ -0,0 +1,2440 @@ +///////////////////////// ankerl::unordered_dense::{map, set} ///////////////////////// + +// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion. +// Version 4.8.1 +// https://github.com/martinus/unordered_dense +// +// Licensed under the MIT License . +// SPDX-License-Identifier: MIT +// Copyright (c) 2022 Martin Leitner-Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_UNORDERED_DENSE_H +#define ANKERL_UNORDERED_DENSE_H + +// see https://semver.org/spec/v2.0.0.html +#define ANKERL_UNORDERED_DENSE_VERSION_MAJOR \ + 4 // NOLINT(cppcoreguidelines-macro-usage) incompatible API changes +#define ANKERL_UNORDERED_DENSE_VERSION_MINOR \ + 8 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible functionality +#define ANKERL_UNORDERED_DENSE_VERSION_PATCH \ + 1 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible bug fixes + +// API versioning with inline namespace, see https://www.foonathan.net/2018/11/inline-namespaces/ + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) v##major##_##minor##_##patch +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT(major, minor, patch) \ + ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) +#define ANKERL_UNORDERED_DENSE_NAMESPACE \ + ANKERL_UNORDERED_DENSE_VERSION_CONCAT( \ + ANKERL_UNORDERED_DENSE_VERSION_MAJOR, \ + ANKERL_UNORDERED_DENSE_VERSION_MINOR, \ + ANKERL_UNORDERED_DENSE_VERSION_PATCH \ + ) + +#if defined(_MSVC_LANG) + #define ANKERL_UNORDERED_DENSE_CPP_VERSION _MSVC_LANG +#else + #define ANKERL_UNORDERED_DENSE_CPP_VERSION __cplusplus +#endif + +#if defined(__GNUC__) + // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_PACK(decl) decl __attribute__((__packed__)) +#elif defined(_MSC_VER) + // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_PACK(decl) __pragma(pack(push, 1)) decl __pragma(pack(pop)) +#endif + +// exceptions +#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND) + #define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 1 // NOLINT(cppcoreguidelines-macro-usage) +#else + #define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 0 // NOLINT(cppcoreguidelines-macro-usage) +#endif +#ifdef _MSC_VER + #define ANKERL_UNORDERED_DENSE_NOINLINE __declspec(noinline) +#else + #define ANKERL_UNORDERED_DENSE_NOINLINE __attribute__((noinline)) +#endif + +#if defined(__clang__) && defined(__has_attribute) + #if __has_attribute(__no_sanitize__) + #define ANKERL_UNORDERED_DENSE_DISABLE_UBSAN_UNSIGNED_INTEGER_CHECK \ + __attribute__((__no_sanitize__("unsigned-integer-overflow"))) + #endif +#endif + +#if !defined(ANKERL_UNORDERED_DENSE_DISABLE_UBSAN_UNSIGNED_INTEGER_CHECK) + #define ANKERL_UNORDERED_DENSE_DISABLE_UBSAN_UNSIGNED_INTEGER_CHECK +#endif + +#if ANKERL_UNORDERED_DENSE_CPP_VERSION < 201703L + #error ankerl::unordered_dense requires C++17 or higher +#else + + #if !defined(ANKERL_UNORDERED_DENSE_STD_MODULE) + // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_STD_MODULE 0 + #endif + + #if !ANKERL_UNORDERED_DENSE_STD_MODULE + #include "stl.h" + #endif + + #if __has_cpp_attribute(likely) && __has_cpp_attribute(unlikely) \ + && ANKERL_UNORDERED_DENSE_CPP_VERSION >= 202002L + #define ANKERL_UNORDERED_DENSE_LIKELY_ATTR \ + [[likely]] // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR \ + [[unlikely]] // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_LIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_UNLIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) + #else + #define ANKERL_UNORDERED_DENSE_LIKELY_ATTR // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR // NOLINT(cppcoreguidelines-macro-usage) + + #if defined(__GNUC__) || defined(__INTEL_COMPILER) || defined(__clang__) + #define ANKERL_UNORDERED_DENSE_LIKELY(x) \ + __builtin_expect(x, 1) // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_UNLIKELY(x) \ + __builtin_expect(x, 0) // NOLINT(cppcoreguidelines-macro-usage) + #else + #define ANKERL_UNORDERED_DENSE_LIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_UNLIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) + #endif + + #endif + +namespace ankerl::unordered_dense { +inline namespace ANKERL_UNORDERED_DENSE_NAMESPACE { + + namespace detail { + + #if ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() + + // make sure this is not inlined as it is slow and dramatically enlarges code, thus making other + // inlinings more difficult. Throws are also generally the slow path. + [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_key_not_found() { + throw std::out_of_range("ankerl::unordered_dense::map::at(): key not found"); + } + [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_bucket_overflow() { + throw std::overflow_error( + "ankerl::unordered_dense: reached max bucket size, cannot increase size" + ); + } + [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_too_many_elements() { + throw std::out_of_range("ankerl::unordered_dense::map::replace(): too many elements"); + } + + #else + + [[noreturn]] inline void on_error_key_not_found() { + abort(); + } + [[noreturn]] inline void on_error_bucket_overflow() { + abort(); + } + [[noreturn]] inline void on_error_too_many_elements() { + abort(); + } + + #endif + + } // namespace detail + + // hash /////////////////////////////////////////////////////////////////////// + + // This is a stripped-down implementation of wyhash: https://github.com/wangyi-fudan/wyhash + // No big-endian support (because different values on different machines don't matter), + // hardcodes seed and the secret, reformats the code, and clang-tidy fixes. + namespace detail::wyhash { + + inline void mum(std::uint64_t* a, std::uint64_t* b) { + #if defined(__SIZEOF_INT128__) + __uint128_t r = *a; + r *= *b; + *a = static_cast(r); + *b = static_cast(r >> 64U); + #elif defined(_MSC_VER) && defined(_M_X64) + *a = _umul128(*a, *b, b); + #else + std::uint64_t ha = *a >> 32U; + std::uint64_t hb = *b >> 32U; + std::uint64_t la = static_cast(*a); + std::uint64_t lb = static_cast(*b); + std::uint64_t hi {}; + std::uint64_t lo {}; + std::uint64_t rh = ha * hb; + std::uint64_t rm0 = ha * lb; + std::uint64_t rm1 = hb * la; + std::uint64_t rl = la * lb; + std::uint64_t t = rl + (rm0 << 32U); + auto c = static_cast(t < rl); + lo = t + (rm1 << 32U); + c += static_cast(lo < t); + hi = rh + (rm0 >> 32U) + (rm1 >> 32U) + c; + *a = lo; + *b = hi; + #endif + } + + // multiply and xor mix function, aka MUM + [[nodiscard]] inline auto mix(std::uint64_t a, std::uint64_t b) -> std::uint64_t { + mum(&a, &b); + return a ^ b; + } + + // read functions. WARNING: we don't care about endianness, so results are different on big endian! + [[nodiscard]] inline auto r8(const std::uint8_t* p) -> std::uint64_t { + std::uint64_t v {}; + std::memcpy(&v, p, 8U); + return v; + } + + [[nodiscard]] inline auto r4(const std::uint8_t* p) -> std::uint64_t { + std::uint32_t v {}; + std::memcpy(&v, p, 4); + return v; + } + + // reads 1, 2, or 3 bytes + [[nodiscard]] inline auto r3(const std::uint8_t* p, std::size_t k) -> std::uint64_t { + return (static_cast(p[0]) << 16U) + | (static_cast(p[k >> 1U]) << 8U) | p[k - 1]; + } + + [[maybe_unused]] [[nodiscard]] inline auto hash(void const* key, std::size_t len) + -> std::uint64_t { + static constexpr auto secret = std::array { UINT64_C(0xa0761d6478bd642f), + UINT64_C(0xe7037ed1a0b428db), + UINT64_C(0x8ebc6af09c88c6e3), + UINT64_C(0x589965cc75374cc3) }; + + auto const* p = static_cast(key); + std::uint64_t seed = secret[0]; + std::uint64_t a {}; + std::uint64_t b {}; + if (ANKERL_UNORDERED_DENSE_LIKELY(len <= 16)) + ANKERL_UNORDERED_DENSE_LIKELY_ATTR { + if (ANKERL_UNORDERED_DENSE_LIKELY(len >= 4)) + ANKERL_UNORDERED_DENSE_LIKELY_ATTR { + a = (r4(p) << 32U) | r4(p + ((len >> 3U) << 2U)); + b = (r4(p + len - 4) << 32U) | r4(p + len - 4 - ((len >> 3U) << 2U)); + } + else if (ANKERL_UNORDERED_DENSE_LIKELY(len > 0)) + ANKERL_UNORDERED_DENSE_LIKELY_ATTR { + a = r3(p, len); + b = 0; + } + else { + a = 0; + b = 0; + } + } + else { + std::size_t i = len; + if (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 48)) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + std::uint64_t see1 = seed; + std::uint64_t see2 = seed; + do { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + see1 = mix(r8(p + 16) ^ secret[2], r8(p + 24) ^ see1); + see2 = mix(r8(p + 32) ^ secret[3], r8(p + 40) ^ see2); + p += 48; + i -= 48; + } while (ANKERL_UNORDERED_DENSE_LIKELY(i > 48)); + seed ^= see1 ^ see2; + } + while (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 16)) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + i -= 16; + p += 16; + } + a = r8(p + i - 16); + b = r8(p + i - 8); + } + + return mix(secret[1] ^ len, mix(a ^ secret[1], b ^ seed)); + } + + [[nodiscard]] inline auto hash(std::uint64_t x) -> std::uint64_t { + return detail::wyhash::mix(x, UINT64_C(0x9E3779B97F4A7C15)); + } + + } // namespace detail::wyhash + + template + struct hash { + auto operator()(T const& obj) const + noexcept(noexcept(std::declval>().operator()(std::declval()))) + -> std::uint64_t { + return std::hash {}(obj); + } + }; + + template + struct hash::is_avalanching> { + using is_avalanching = void; + auto operator()(T const& obj) const + noexcept(noexcept(std::declval>().operator()(std::declval()))) + -> std::uint64_t { + return std::hash {}(obj); + } + }; + + template + struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string const& str) const noexcept -> std::uint64_t { + return detail::wyhash::hash(str.data(), sizeof(CharT) * str.size()); + } + }; + + template + struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string_view const& sv) const noexcept -> std::uint64_t { + return detail::wyhash::hash(sv.data(), sizeof(CharT) * sv.size()); + } + }; + + template + struct hash { + using is_avalanching = void; + auto operator()(T* ptr) const noexcept -> std::uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr)); + } + }; + + template + struct hash> { + using is_avalanching = void; + auto operator()(std::unique_ptr const& ptr) const noexcept -> std::uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } + }; + + template + struct hash> { + using is_avalanching = void; + auto operator()(std::shared_ptr const& ptr) const noexcept -> std::uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } + }; + + template + struct hash>> { + using is_avalanching = void; + auto operator()(Enum e) const noexcept -> std::uint64_t { + using underlying = std::underlying_type_t; + return detail::wyhash::hash(static_cast(e)); + } + }; + + template + struct tuple_hash_helper { + // Converts the value into 64bit. If it is an integral type, just cast it. Mixing is doing the rest. + // If it isn't an integral we need to hash it. + template + [[nodiscard]] constexpr static auto to64(Arg const& arg) -> std::uint64_t { + if constexpr (std::is_integral_v || std::is_enum_v) { + return static_cast(arg); + } else { + return hash {}(arg); + } + } + + [[nodiscard]] ANKERL_UNORDERED_DENSE_DISABLE_UBSAN_UNSIGNED_INTEGER_CHECK static auto + mix64(std::uint64_t state, std::uint64_t v) -> std::uint64_t { + return detail::wyhash::mix(state + v, std::uint64_t { 0x9ddfea08eb382d69 }); + } + + // Creates a buffer that holds all the data from each element of the tuple. If possible we memcpy the data directly. If + // not, we hash the object and use this for the array. Size of the array is known at compile time, and memcpy is optimized + // away, so filling the buffer is highly efficient. Finally, call wyhash with this buffer. + template + [[nodiscard]] static auto + calc_hash(T const& t, std::index_sequence /*unused*/) noexcept -> std::uint64_t { + auto h = std::uint64_t {}; + ((h = mix64(h, to64(std::get(t)))), ...); + return h; + } + }; + + template + struct hash>: tuple_hash_helper { + using is_avalanching = void; + auto operator()(std::tuple const& t) const noexcept -> std::uint64_t { + return tuple_hash_helper::calc_hash(t, std::index_sequence_for {}); + } + }; + + template + struct hash>: tuple_hash_helper { + using is_avalanching = void; + auto operator()(std::pair const& t) const noexcept -> std::uint64_t { + return tuple_hash_helper::calc_hash(t, std::index_sequence_for {}); + } + }; + + // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) + #define ANKERL_UNORDERED_DENSE_HASH_STATICCAST(T) \ + template<> \ + struct hash { \ + using is_avalanching = void; \ + auto operator()(T const& obj) const noexcept -> std::uint64_t { \ + return detail::wyhash::hash(static_cast(obj)); \ + } \ + } + + #if defined(__GNUC__) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wuseless-cast" + #endif + // see https://en.cppreference.com/w/cpp/utility/hash + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(bool); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(signed char); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned char); + #if ANKERL_UNORDERED_DENSE_CPP_VERSION >= 202002L && defined(__cpp_char8_t) + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char8_t); + #endif + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char16_t); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char32_t); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(wchar_t); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(short); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned short); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(int); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned int); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long long); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long); + ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long long); + + #if defined(__GNUC__) && !defined(__clang__) + #pragma GCC diagnostic pop + #endif + + // bucket_type ////////////////////////////////////////////////////////// + + namespace bucket_type { + + struct standard { + static constexpr std::uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr std::uint32_t fingerprint_mask = + dist_inc - 1; // mask for 1 byte of fingerprint + + std::uint32_t + m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + std::uint32_t m_value_idx; // index into the m_values vector. + }; + + ANKERL_UNORDERED_DENSE_PACK(struct big { + static constexpr std::uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr std::uint32_t fingerprint_mask = + dist_inc - 1; // mask for 1 byte of fingerprint + + std::uint32_t + m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + std::size_t m_value_idx; // index into the m_values vector. + }); + + } // namespace bucket_type + + namespace detail { + + struct nonesuch {}; + struct default_container_t {}; + + template class Op, class... Args> + struct detector { + using value_t = std::false_type; + using type = Default; + }; + + template class Op, class... Args> + struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; + }; + + template class Op, class... Args> + using is_detected = typename detail::detector::value_t; + + template class Op, class... Args> + constexpr bool is_detected_v = is_detected::value; + + template + using detect_avalanching = typename T::is_avalanching; + + template + using detect_is_transparent = typename T::is_transparent; + + template + using detect_iterator = typename T::iterator; + + template + using detect_reserve = decltype(std::declval().reserve(std::size_t {})); + + // enable_if helpers + + template + constexpr bool is_map_v = !std::is_void_v; + + // clang-format off +template +constexpr bool is_transparent_v = is_detected_v && is_detected_v; + // clang-format on + + template + constexpr bool is_neither_convertible_v = + !std::is_convertible_v && !std::is_convertible_v; + + template + constexpr bool has_reserve = is_detected_v; + + // base type for map has mapped_type + template + struct base_table_type_map { + using mapped_type = T; + }; + + // base type for set doesn't have mapped_type + struct base_table_type_set {}; + + } // namespace detail + + // Very much like std::deque, but faster for indexing (in most cases). As of now this doesn't implement the full std::vector + // API, but merely what's necessary to work as an underlying container for ankerl::unordered_dense::{map, set}. + // It allocates blocks of equal size and puts them into the m_blocks vector. That means it can grow simply by adding a new + // block to the back of m_blocks, and doesn't double its size like an std::vector. The disadvantage is that memory is not + // linear and thus there is one more indirection necessary for indexing. + template< + typename T, + typename Allocator = std::allocator, + std::size_t MaxSegmentSizeBytes = 4096> + class segmented_vector { + template + class iter_t; + + public: + using allocator_type = Allocator; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using difference_type = typename std::allocator_traits::difference_type; + using value_type = T; + using size_type = std::size_t; + using reference = T&; + using const_reference = T const&; + using iterator = iter_t; + using const_iterator = iter_t; + + private: + using vec_alloc = typename std::allocator_traits::template rebind_alloc; + std::vector m_blocks {}; + std::size_t m_size {}; + + // Calculates the maximum number for x in (s << x) <= max_val + static constexpr auto num_bits_closest(std::size_t max_val, std::size_t s) -> std::size_t { + auto f = std::size_t { 0 }; + while (s << (f + 1) <= max_val) { + ++f; + } + return f; + } + + using self_t = segmented_vector; + static constexpr auto num_bits = num_bits_closest(MaxSegmentSizeBytes, sizeof(T)); + static constexpr auto num_elements_in_block = 1U << num_bits; + static constexpr auto mask = num_elements_in_block - 1U; + + /** + * Iterator class doubles as const_iterator and iterator + */ + template + class iter_t { + using ptr_t = std::conditional_t< + IsConst, + segmented_vector::const_pointer const*, + segmented_vector::pointer*>; + ptr_t m_data {}; + std::size_t m_idx {}; + + template + friend class iter_t; + + public: + using difference_type = segmented_vector::difference_type; + using value_type = segmented_vector::value_type; + using reference = std::conditional_t; + using pointer = std:: + conditional_t; + using iterator_category = std::forward_iterator_tag; + + iter_t() noexcept = default; + + template> + // NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions) + constexpr iter_t(iter_t const& other) noexcept: + m_data(other.m_data), + m_idx(other.m_idx) {} + + constexpr iter_t(ptr_t data, std::size_t idx) noexcept: m_data(data), m_idx(idx) {} + + template> + constexpr auto operator=(iter_t const& other) noexcept -> iter_t& { + m_data = other.m_data; + m_idx = other.m_idx; + return *this; + } + + constexpr auto operator++() noexcept -> iter_t& { + ++m_idx; + return *this; + } + + constexpr auto operator++(int) noexcept -> iter_t { + iter_t prev(*this); + this->operator++(); + return prev; + } + + constexpr auto operator--() noexcept -> iter_t& { + --m_idx; + return *this; + } + + constexpr auto operator--(int) noexcept -> iter_t { + iter_t prev(*this); + this->operator--(); + return prev; + } + + [[nodiscard]] constexpr auto operator+(difference_type diff) const noexcept -> iter_t { + return { m_data, + static_cast(static_cast(m_idx) + diff) }; + } + + constexpr auto operator+=(difference_type diff) noexcept -> iter_t& { + m_idx += diff; + return *this; + } + + [[nodiscard]] constexpr auto operator-(difference_type diff) const noexcept -> iter_t { + return { m_data, + static_cast(static_cast(m_idx) - diff) }; + } + + constexpr auto operator-=(difference_type diff) noexcept -> iter_t& { + m_idx -= diff; + return *this; + } + + template + [[nodiscard]] constexpr auto operator-(iter_t const& other) const noexcept + -> difference_type { + return static_cast(m_idx) + - static_cast(other.m_idx); + } + + constexpr auto operator*() const noexcept -> reference { + return m_data[m_idx >> num_bits][m_idx & mask]; + } + + constexpr auto operator->() const noexcept -> pointer { + return &m_data[m_idx >> num_bits][m_idx & mask]; + } + + template + [[nodiscard]] constexpr auto operator==(iter_t const& o) const noexcept -> bool { + return m_idx == o.m_idx; + } + + template + [[nodiscard]] constexpr auto operator!=(iter_t const& o) const noexcept -> bool { + return !(*this == o); + } + + template + [[nodiscard]] constexpr auto operator<(iter_t const& o) const noexcept -> bool { + return m_idx < o.m_idx; + } + + template + [[nodiscard]] constexpr auto operator>(iter_t const& o) const noexcept -> bool { + return o < *this; + } + + template + [[nodiscard]] constexpr auto operator<=(iter_t const& o) const noexcept -> bool { + return !(o < *this); + } + + template + [[nodiscard]] constexpr auto operator>=(iter_t const& o) const noexcept -> bool { + return !(*this < o); + } + }; + + // slow path: need to allocate a new segment every once in a while + void increase_capacity() { + auto ba = Allocator(m_blocks.get_allocator()); + pointer block = std::allocator_traits::allocate(ba, num_elements_in_block); + m_blocks.push_back(block); + } + + // Moves everything from other + void append_everything_from(segmented_vector&& other + ) { // NOLINT(cppcoreguidelines-rvalue-reference-param-not-moved) + reserve(size() + other.size()); + for (auto&& o: other) { + emplace_back(std::move(o)); + } + } + + // Copies everything from other + void append_everything_from(segmented_vector const& other) { + reserve(size() + other.size()); + for (auto const& o: other) { + emplace_back(o); + } + } + + void dealloc() { + auto ba = Allocator(m_blocks.get_allocator()); + for (auto ptr: m_blocks) { + std::allocator_traits::deallocate(ba, ptr, num_elements_in_block); + } + } + + [[nodiscard]] static constexpr auto calc_num_blocks_for_capacity(std::size_t capacity) { + return (capacity + num_elements_in_block - 1U) / num_elements_in_block; + } + + void resize_shrink(std::size_t new_size) { + if constexpr (!std::is_trivially_destructible_v) { + for (std::size_t ix = new_size; ix < m_size; ++ix) { + operator[](ix).~T(); + } + } + m_size = new_size; + } + + public: + segmented_vector() = default; + + // NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions) + segmented_vector(Allocator alloc): m_blocks(vec_alloc(alloc)) {} + + segmented_vector(segmented_vector&& other, Allocator alloc): segmented_vector(alloc) { + *this = std::move(other); + } + + segmented_vector(segmented_vector const& other, Allocator alloc): + m_blocks(vec_alloc(alloc)) { + append_everything_from(other); + } + + segmented_vector(segmented_vector&& other) noexcept: + segmented_vector(std::move(other), other.get_allocator()) {} + + segmented_vector(segmented_vector const& other) { + append_everything_from(other); + } + + auto operator=(segmented_vector const& other) -> segmented_vector& { + if (this == &other) { + return *this; + } + clear(); + append_everything_from(other); + return *this; + } + + auto operator=(segmented_vector&& other) noexcept -> segmented_vector& { + clear(); + dealloc(); + if (other.get_allocator() == get_allocator()) { + m_blocks = std::move(other.m_blocks); + m_size = std::exchange(other.m_size, {}); + } else { + // make sure to construct with other's allocator! + m_blocks = std::vector(vec_alloc(other.get_allocator())); + append_everything_from(std::move(other)); + } + return *this; + } + + ~segmented_vector() { + clear(); + dealloc(); + } + + [[nodiscard]] constexpr auto size() const -> std::size_t { + return m_size; + } + + [[nodiscard]] constexpr auto capacity() const -> std::size_t { + return m_blocks.size() * num_elements_in_block; + } + + // Indexing is highly performance critical + [[nodiscard]] constexpr auto operator[](std::size_t i) const noexcept -> T const& { + return m_blocks[i >> num_bits][i & mask]; + } + + [[nodiscard]] constexpr auto operator[](std::size_t i) noexcept -> T& { + return m_blocks[i >> num_bits][i & mask]; + } + + [[nodiscard]] constexpr auto begin() -> iterator { + return { m_blocks.data(), 0U }; + } + [[nodiscard]] constexpr auto begin() const -> const_iterator { + return { m_blocks.data(), 0U }; + } + [[nodiscard]] constexpr auto cbegin() const -> const_iterator { + return { m_blocks.data(), 0U }; + } + + [[nodiscard]] constexpr auto end() -> iterator { + return { m_blocks.data(), m_size }; + } + [[nodiscard]] constexpr auto end() const -> const_iterator { + return { m_blocks.data(), m_size }; + } + [[nodiscard]] constexpr auto cend() const -> const_iterator { + return { m_blocks.data(), m_size }; + } + + [[nodiscard]] constexpr auto back() -> reference { + return operator[](m_size - 1); + } + [[nodiscard]] constexpr auto back() const -> const_reference { + return operator[](m_size - 1); + } + + void pop_back() { + back().~T(); + --m_size; + } + + [[nodiscard]] auto empty() const { + return 0 == m_size; + } + + void reserve(std::size_t new_capacity) { + m_blocks.reserve(calc_num_blocks_for_capacity(new_capacity)); + while (new_capacity > capacity()) { + increase_capacity(); + } + } + + void resize(std::size_t const count) { + if (count < m_size) { + resize_shrink(count); + } else if (count > m_size) { + std::size_t const new_elems = count - m_size; + reserve(count); + for (std::size_t ix = 0; ix < new_elems; ++ix) { + emplace_back(); + } + } + } + + void resize(std::size_t const count, value_type const& value) { + if (count < m_size) { + resize_shrink(count); + } else if (count > m_size) { + std::size_t const new_elems = count - m_size; + reserve(count); + for (std::size_t ix = 0; ix < new_elems; ++ix) { + emplace_back(value); + } + } + } + + [[nodiscard]] auto get_allocator() const -> allocator_type { + return allocator_type { m_blocks.get_allocator() }; + } + + template + auto emplace_back(Args&&... args) -> reference { + if (m_size == capacity()) { + increase_capacity(); + } + auto* ptr = static_cast(&operator[](m_size)); + auto& ref = *new (ptr) T(std::forward(args)...); + ++m_size; + return ref; + } + + void clear() { + if constexpr (!std::is_trivially_destructible_v) { + for (std::size_t i = 0, s = size(); i < s; ++i) { + operator[](i).~T(); + } + } + m_size = 0; + } + + void shrink_to_fit() { + auto ba = Allocator(m_blocks.get_allocator()); + auto num_blocks_required = calc_num_blocks_for_capacity(m_size); + while (m_blocks.size() > num_blocks_required) { + std::allocator_traits::deallocate( + ba, + m_blocks.back(), + num_elements_in_block + ); + m_blocks.pop_back(); + } + m_blocks.shrink_to_fit(); + } + }; + + namespace detail { + + // This is it, the table. Doubles as map and set, and uses `void` for T when its used as a set. + template< + class Key, + class T, // when void, treat it as a set. + class Hash, + class KeyEqual, + class AllocatorOrContainer, + class Bucket, + class BucketContainer, + bool IsSegmented> + class table: + public std::conditional_t, base_table_type_map, base_table_type_set> { + using underlying_value_type = std::conditional_t, std::pair, Key>; + using underlying_container_type = std::conditional_t< + IsSegmented, + segmented_vector, + std::vector>; + + public: + using value_container_type = std::conditional_t< + is_detected_v, + AllocatorOrContainer, + underlying_container_type>; + + private: + using bucket_alloc = typename std::allocator_traits< + typename value_container_type::allocator_type>::template rebind_alloc; + using default_bucket_container_type = std::conditional_t< + IsSegmented, + segmented_vector, + std::vector>; + + using bucket_container_type = std::conditional_t< + std::is_same_v, + default_bucket_container_type, + BucketContainer>; + + static constexpr std::uint8_t initial_shifts = + 64 - 2; // 2^(64-m_shift) number of buckets + static constexpr float default_max_load_factor = 0.8F; + + public: + using key_type = Key; + using value_type = typename value_container_type::value_type; + using size_type = typename value_container_type::size_type; + using difference_type = typename value_container_type::difference_type; + using hasher = Hash; + using key_equal = KeyEqual; + using allocator_type = typename value_container_type::allocator_type; + using reference = typename value_container_type::reference; + using const_reference = typename value_container_type::const_reference; + using pointer = typename value_container_type::pointer; + using const_pointer = typename value_container_type::const_pointer; + using const_iterator = typename value_container_type::const_iterator; + using iterator = std:: + conditional_t, typename value_container_type::iterator, const_iterator>; + using bucket_type = Bucket; + + private: + using value_idx_type = decltype(Bucket::m_value_idx); + using dist_and_fingerprint_type = decltype(Bucket::m_dist_and_fingerprint); + + static_assert( + std::is_trivially_destructible_v, + "assert there's no need to call destructor / std::destroy" + ); + static_assert( + std::is_trivially_copyable_v, + "assert we can just memset / memcpy" + ); + + value_container_type + m_values {}; // Contains all the key-value pairs in one densely stored container. No holes. + bucket_container_type m_buckets {}; + std::size_t m_max_bucket_capacity = 0; + float m_max_load_factor = default_max_load_factor; + Hash m_hash {}; + KeyEqual m_equal {}; + std::uint8_t m_shifts = initial_shifts; + + [[nodiscard]] auto next(value_idx_type bucket_idx) const -> value_idx_type { + if (ANKERL_UNORDERED_DENSE_UNLIKELY(bucket_idx + 1U == bucket_count())) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + return 0; + } + + return static_cast(bucket_idx + 1U); + } + + // Helper to access bucket through pointer types + [[nodiscard]] static constexpr auto + at(bucket_container_type& bucket, std::size_t offset) -> Bucket& { + return bucket[offset]; + } + + [[nodiscard]] static constexpr auto + at(const bucket_container_type& bucket, std::size_t offset) -> const Bucket& { + return bucket[offset]; + } + + // use the dist_inc and dist_dec functions so that std::uint16_t types work without warning + [[nodiscard]] static constexpr auto dist_inc(dist_and_fingerprint_type x) + -> dist_and_fingerprint_type { + return static_cast(x + Bucket::dist_inc); + } + + [[nodiscard]] static constexpr auto dist_dec(dist_and_fingerprint_type x) + -> dist_and_fingerprint_type { + return static_cast(x - Bucket::dist_inc); + } + + // The goal of mixed_hash is to always produce a high quality 64bit hash. + template + [[nodiscard]] constexpr auto mixed_hash(K const& key) const -> std::uint64_t { + if constexpr (is_detected_v) { + // we know that the hash is good because is_avalanching. + if constexpr (sizeof(decltype(m_hash(key))) < sizeof(std::uint64_t)) { + // 32bit hash and is_avalanching => multiply with a constant to avalanche bits upwards + return m_hash(key) * UINT64_C(0x9ddfea08eb382d69); + } else { + // 64bit and is_avalanching => only use the hash itself. + return m_hash(key); + } + } else { + // not is_avalanching => apply wyhash + return wyhash::hash(m_hash(key)); + } + } + + [[nodiscard]] constexpr auto dist_and_fingerprint_from_hash(std::uint64_t hash) const + -> dist_and_fingerprint_type { + return Bucket::dist_inc + | (static_cast(hash) & Bucket::fingerprint_mask); + } + + [[nodiscard]] constexpr auto bucket_idx_from_hash(std::uint64_t hash) const + -> value_idx_type { + return static_cast(hash >> m_shifts); + } + + [[nodiscard]] static constexpr auto get_key(value_type const& vt) -> key_type const& { + if constexpr (is_map_v) { + return vt.first; + } else { + return vt; + } + } + + template + [[nodiscard]] auto next_while_less(K const& key) const -> Bucket { + auto hash = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(hash); + auto bucket_idx = bucket_idx_from_hash(hash); + + while (dist_and_fingerprint < at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + return { dist_and_fingerprint, bucket_idx }; + } + + void place_and_shift_up(Bucket bucket, value_idx_type place) { + while (0 != at(m_buckets, place).m_dist_and_fingerprint) { + bucket = std::exchange(at(m_buckets, place), bucket); + bucket.m_dist_and_fingerprint = dist_inc(bucket.m_dist_and_fingerprint); + place = next(place); + } + at(m_buckets, place) = bucket; + } + + void erase_and_shift_down(value_idx_type bucket_idx) { + // shift down until either empty or an element with correct spot is found + auto next_bucket_idx = next(bucket_idx); + while (at(m_buckets, next_bucket_idx).m_dist_and_fingerprint >= Bucket::dist_inc * 2 + ) { + auto& next_bucket = at(m_buckets, next_bucket_idx); + at(m_buckets, bucket_idx) = { dist_dec(next_bucket.m_dist_and_fingerprint), + next_bucket.m_value_idx }; + bucket_idx = std::exchange(next_bucket_idx, next(next_bucket_idx)); + } + at(m_buckets, bucket_idx) = {}; + } + + [[nodiscard]] static constexpr auto calc_num_buckets(std::uint8_t shifts) + -> std::size_t { + return (std::min)(max_bucket_count(), std::size_t { 1 } << (64U - shifts)); + } + + [[nodiscard]] constexpr auto calc_shifts_for_size(std::size_t s) const -> std::uint8_t { + auto shifts = initial_shifts; + while (shifts > 0 + && static_cast( + static_cast(calc_num_buckets(shifts)) * max_load_factor() + ) < s) + { + --shifts; + } + return shifts; + } + + // assumes m_values has data, m_buckets=m_buckets_end=nullptr, m_shifts is INITIAL_SHIFTS + void copy_buckets(table const& other) { + // assumes m_values has already the correct data copied over. + if (empty()) { + // when empty, at least allocate an initial buckets and clear them. + allocate_buckets_from_shift(); + clear_buckets(); + } else { + m_shifts = other.m_shifts; + allocate_buckets_from_shift(); + if constexpr (IsSegmented || !std::is_same_v) + { + for (auto i = 0UL; i < bucket_count(); ++i) { + at(m_buckets, i) = at(other.m_buckets, i); + } + } else { + std::memcpy( + m_buckets.data(), + other.m_buckets.data(), + sizeof(Bucket) * bucket_count() + ); + } + } + } + + /** + * True when no element can be added any more without increasing the size + */ + [[nodiscard]] auto is_full() const -> bool { + return size() > m_max_bucket_capacity; + } + + void deallocate_buckets() { + m_buckets.clear(); + m_buckets.shrink_to_fit(); + m_max_bucket_capacity = 0; + } + + void allocate_buckets_from_shift() { + auto num_buckets = calc_num_buckets(m_shifts); + if constexpr (IsSegmented || !std::is_same_v) + { + if constexpr (has_reserve) { + m_buckets.reserve(num_buckets); + } + for (std::size_t i = m_buckets.size(); i < num_buckets; ++i) { + m_buckets.emplace_back(); + } + } else { + m_buckets.resize(num_buckets); + } + if (num_buckets == max_bucket_count()) { + // reached the maximum, make sure we can use each bucket + m_max_bucket_capacity = max_bucket_count(); + } else { + m_max_bucket_capacity = static_cast( + static_cast(num_buckets) * max_load_factor() + ); + } + } + + void clear_buckets() { + if constexpr (IsSegmented || !std::is_same_v) + { + for (auto&& e: m_buckets) { + std::memset(&e, 0, sizeof(e)); + } + } else { + std::memset(m_buckets.data(), 0, sizeof(Bucket) * bucket_count()); + } + } + + void clear_and_fill_buckets_from_values() { + clear_buckets(); + for (value_idx_type value_idx = 0, + end_idx = static_cast(m_values.size()); + value_idx < end_idx; + ++value_idx) + { + auto const& key = get_key(m_values[value_idx]); + auto [dist_and_fingerprint, bucket] = next_while_less(key); + + // we know for certain that key has not yet been inserted, so no need to check it. + place_and_shift_up({ dist_and_fingerprint, value_idx }, bucket); + } + } + + void increase_size() { + if (m_max_bucket_capacity == max_bucket_count()) { + // remove the value again, we can't add it! + m_values.pop_back(); + on_error_bucket_overflow(); + } + --m_shifts; + if constexpr (!IsSegmented || std::is_same_v) + { + deallocate_buckets(); + } + allocate_buckets_from_shift(); + clear_and_fill_buckets_from_values(); + } + + template + void do_erase(value_idx_type bucket_idx, Op handle_erased_value) { + auto const value_idx_to_remove = at(m_buckets, bucket_idx).m_value_idx; + erase_and_shift_down(bucket_idx); + handle_erased_value(std::move(m_values[value_idx_to_remove])); + + // update m_values + if (value_idx_to_remove != m_values.size() - 1) { + // no luck, we'll have to replace the value with the last one and update the index accordingly + auto& val = m_values[value_idx_to_remove]; + val = std::move(m_values.back()); + + // update the values_idx of the moved entry. No need to play the info game, just look until we find the values_idx + bucket_idx = bucket_idx_from_hash(mixed_hash(get_key(val))); + auto const values_idx_back = static_cast(m_values.size() - 1); + while (values_idx_back != at(m_buckets, bucket_idx).m_value_idx) { + bucket_idx = next(bucket_idx); + } + at(m_buckets, bucket_idx).m_value_idx = value_idx_to_remove; + } + m_values.pop_back(); + } + + template + auto do_erase_key(K&& key, Op handle_erased_value) + -> std::size_t { // NOLINT(cppcoreguidelines-missing-std-forward) + if (empty()) { + return 0; + } + + auto [dist_and_fingerprint, bucket_idx] = next_while_less(key); + + while (dist_and_fingerprint == at(m_buckets, bucket_idx).m_dist_and_fingerprint + && !m_equal(key, get_key(m_values[at(m_buckets, bucket_idx).m_value_idx]))) + { + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + + if (dist_and_fingerprint != at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + return 0; + } + do_erase(bucket_idx, handle_erased_value); + return 1; + } + + template + auto do_insert_or_assign(K&& key, M&& mapped) -> std::pair { + auto it_isinserted = try_emplace(std::forward(key), std::forward(mapped)); + if (!it_isinserted.second) { + it_isinserted.first->second = std::forward(mapped); + } + return it_isinserted; + } + + template + auto do_place_element( + dist_and_fingerprint_type dist_and_fingerprint, + value_idx_type bucket_idx, + Args&&... args + ) -> std::pair { + // emplace the new value. If that throws an exception, no harm done; index is still in a valid state + m_values.emplace_back(std::forward(args)...); + + auto value_idx = static_cast(m_values.size() - 1); + if (ANKERL_UNORDERED_DENSE_UNLIKELY(is_full())) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + increase_size(); + } + else { + place_and_shift_up({ dist_and_fingerprint, value_idx }, bucket_idx); + } + + // place element and shift up until we find an empty spot + return { begin() + static_cast(value_idx), true }; + } + + template + auto do_try_emplace(K&& key, Args&&... args) -> std::pair { + auto hash = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(hash); + auto bucket_idx = bucket_idx_from_hash(hash); + + while (true) { + auto* bucket = &at(m_buckets, bucket_idx); + if (dist_and_fingerprint == bucket->m_dist_and_fingerprint) { + if (m_equal(key, get_key(m_values[bucket->m_value_idx]))) { + return { begin() + static_cast(bucket->m_value_idx), + false }; + } + } else if (dist_and_fingerprint > bucket->m_dist_and_fingerprint) { + return do_place_element( + dist_and_fingerprint, + bucket_idx, + std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...) + ); + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + } + + template + auto do_find(K const& key) -> iterator { + if (ANKERL_UNORDERED_DENSE_UNLIKELY(empty())) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + return end(); + } + + auto mh = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(mh); + auto bucket_idx = bucket_idx_from_hash(mh); + auto* bucket = &at(m_buckets, bucket_idx); + + // unrolled loop. *Always* check a few directly, then enter the loop. This is faster. + if (dist_and_fingerprint == bucket->m_dist_and_fingerprint + && m_equal(key, get_key(m_values[bucket->m_value_idx]))) + { + return begin() + static_cast(bucket->m_value_idx); + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + bucket = &at(m_buckets, bucket_idx); + + if (dist_and_fingerprint == bucket->m_dist_and_fingerprint + && m_equal(key, get_key(m_values[bucket->m_value_idx]))) + { + return begin() + static_cast(bucket->m_value_idx); + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + bucket = &at(m_buckets, bucket_idx); + + while (true) { + if (dist_and_fingerprint == bucket->m_dist_and_fingerprint) { + if (m_equal(key, get_key(m_values[bucket->m_value_idx]))) { + return begin() + static_cast(bucket->m_value_idx); + } + } else if (dist_and_fingerprint > bucket->m_dist_and_fingerprint) { + return end(); + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + bucket = &at(m_buckets, bucket_idx); + } + } + + template + auto do_find(K const& key) const -> const_iterator { + return const_cast(this)->do_find(key + ); // NOLINT(cppcoreguidelines-pro-type-const-cast) + } + + template, bool> = true> + auto do_at(K const& key) -> Q& { + if (auto it = find(key); ANKERL_UNORDERED_DENSE_LIKELY(end() != it)) + ANKERL_UNORDERED_DENSE_LIKELY_ATTR { + return it->second; + } + on_error_key_not_found(); + } + + template, bool> = true> + auto do_at(K const& key) const -> Q const& { + return const_cast(this)->at(key + ); // NOLINT(cppcoreguidelines-pro-type-const-cast) + } + + public: + explicit table( + std::size_t bucket_count, + Hash const& hash = Hash(), + KeyEqual const& equal = KeyEqual(), + allocator_type const& alloc_or_container = allocator_type() + ): + m_values(alloc_or_container), + m_buckets(alloc_or_container), + m_hash(hash), + m_equal(equal) { + if (0 != bucket_count) { + reserve(bucket_count); + } else { + allocate_buckets_from_shift(); + clear_buckets(); + } + } + + table(): table(0) {} + + table(std::size_t bucket_count, allocator_type const& alloc): + table(bucket_count, Hash(), KeyEqual(), alloc) {} + + table(std::size_t bucket_count, Hash const& hash, allocator_type const& alloc): + table(bucket_count, hash, KeyEqual(), alloc) {} + + explicit table(allocator_type const& alloc): table(0, Hash(), KeyEqual(), alloc) {} + + template + table( + InputIt first, + InputIt last, + size_type bucket_count = 0, + Hash const& hash = Hash(), + KeyEqual const& equal = KeyEqual(), + allocator_type const& alloc = allocator_type() + ): + table(bucket_count, hash, equal, alloc) { + insert(first, last); + } + + template + table(InputIt first, InputIt last, size_type bucket_count, allocator_type const& alloc): + table(first, last, bucket_count, Hash(), KeyEqual(), alloc) {} + + template + table( + InputIt first, + InputIt last, + size_type bucket_count, + Hash const& hash, + allocator_type const& alloc + ): + table(first, last, bucket_count, hash, KeyEqual(), alloc) {} + + table(table const& other): table(other, other.m_values.get_allocator()) {} + + table(table const& other, allocator_type const& alloc): + m_values(other.m_values, alloc), + m_max_load_factor(other.m_max_load_factor), + m_hash(other.m_hash), + m_equal(other.m_equal) { + copy_buckets(other); + } + + table(table&& other) noexcept: + table(std::move(other), other.m_values.get_allocator()) {} + + table(table&& other, allocator_type const& alloc) noexcept: m_values(alloc) { + *this = std::move(other); + } + + table( + std::initializer_list ilist, + std::size_t bucket_count = 0, + Hash const& hash = Hash(), + KeyEqual const& equal = KeyEqual(), + allocator_type const& alloc = allocator_type() + ): + table(bucket_count, hash, equal, alloc) { + insert(ilist); + } + + table( + std::initializer_list ilist, + size_type bucket_count, + allocator_type const& alloc + ): + table(ilist, bucket_count, Hash(), KeyEqual(), alloc) {} + + table( + std::initializer_list init, + size_type bucket_count, + Hash const& hash, + allocator_type const& alloc + ): + table(init, bucket_count, hash, KeyEqual(), alloc) {} + + ~table() = default; + + auto operator=(table const& other) -> table& { + if (&other != this) { + deallocate_buckets( + ); // deallocate before m_values is set (might have another allocator) + m_values = other.m_values; + m_max_load_factor = other.m_max_load_factor; + m_hash = other.m_hash; + m_equal = other.m_equal; + m_shifts = initial_shifts; + copy_buckets(other); + } + return *this; + } + + auto operator=(table&& other + ) noexcept(noexcept(std::is_nothrow_move_assignable_v&& + std::is_nothrow_move_assignable_v&& + std::is_nothrow_move_assignable_v)) -> table& { + if (&other != this) { + deallocate_buckets( + ); // deallocate before m_values is set (might have another allocator) + m_values = std::move(other.m_values); + other.m_values.clear(); + + // we can only reuse m_buckets when both maps have the same allocator! + if (get_allocator() == other.get_allocator()) { + m_buckets = std::move(other.m_buckets); + other.m_buckets.clear(); + m_max_bucket_capacity = std::exchange(other.m_max_bucket_capacity, 0); + m_shifts = std::exchange(other.m_shifts, initial_shifts); + m_max_load_factor = + std::exchange(other.m_max_load_factor, default_max_load_factor); + m_hash = std::exchange(other.m_hash, {}); + m_equal = std::exchange(other.m_equal, {}); + other.allocate_buckets_from_shift(); + other.clear_buckets(); + } else { + // set max_load_factor *before* copying the other's buckets, so we have the same + // behavior + m_max_load_factor = other.m_max_load_factor; + + // copy_buckets sets m_buckets, m_num_buckets, m_max_bucket_capacity, m_shifts + copy_buckets(other); + // clear's the other's buckets so other is now already usable. + other.clear_buckets(); + m_hash = other.m_hash; + m_equal = other.m_equal; + } + // map "other" is now already usable, it's empty. + } + return *this; + } + + auto operator=(std::initializer_list ilist) -> table& { + clear(); + insert(ilist); + return *this; + } + + auto get_allocator() const noexcept -> allocator_type { + return m_values.get_allocator(); + } + + // iterators ////////////////////////////////////////////////////////////// + + auto begin() noexcept -> iterator { + return m_values.begin(); + } + + auto begin() const noexcept -> const_iterator { + return m_values.begin(); + } + + auto cbegin() const noexcept -> const_iterator { + return m_values.cbegin(); + } + + auto end() noexcept -> iterator { + return m_values.end(); + } + + auto cend() const noexcept -> const_iterator { + return m_values.cend(); + } + + auto end() const noexcept -> const_iterator { + return m_values.end(); + } + + // capacity /////////////////////////////////////////////////////////////// + + [[nodiscard]] auto empty() const noexcept -> bool { + return m_values.empty(); + } + + [[nodiscard]] auto size() const noexcept -> std::size_t { + return m_values.size(); + } + + [[nodiscard]] static constexpr auto max_size() noexcept -> std::size_t { + if constexpr ((std::numeric_limits::max)() == (std::numeric_limits::max)()) + { + return std::size_t { 1 } << (sizeof(value_idx_type) * 8 - 1); + } else { + return std::size_t { 1 } << (sizeof(value_idx_type) * 8); + } + } + + // modifiers ////////////////////////////////////////////////////////////// + + void clear() { + m_values.clear(); + clear_buckets(); + } + + auto insert(value_type const& value) -> std::pair { + return emplace(value); + } + + auto insert(value_type&& value) -> std::pair { + return emplace(std::move(value)); + } + + template< + class P, + std::enable_if_t, bool> = true> + auto insert(P&& value) -> std::pair { + return emplace(std::forward

(value)); + } + + auto insert(const_iterator /*hint*/, value_type const& value) -> iterator { + return insert(value).first; + } + + auto insert(const_iterator /*hint*/, value_type&& value) -> iterator { + return insert(std::move(value)).first; + } + + template< + class P, + std::enable_if_t, bool> = true> + auto insert(const_iterator /*hint*/, P&& value) -> iterator { + return insert(std::forward

(value)).first; + } + + template + void insert(InputIt first, InputIt last) { + while (first != last) { + insert(*first); + ++first; + } + } + + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } + + // nonstandard API: *this is emptied. + // Also see "A Standard flat_map" https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p0429r9.pdf + auto extract() && -> value_container_type { + return std::move(m_values); + } + + // nonstandard API: + // Discards the internally held container and replaces it with the one passed. Erases non-unique elements. + auto replace(value_container_type&& container) { + if (ANKERL_UNORDERED_DENSE_UNLIKELY(container.size() > max_size())) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + on_error_too_many_elements(); + } + auto shifts = calc_shifts_for_size(container.size()); + if (0 == bucket_count() || shifts < m_shifts + || container.get_allocator() != m_values.get_allocator()) + { + m_shifts = shifts; + deallocate_buckets(); + allocate_buckets_from_shift(); + } + clear_buckets(); + + m_values = std::move(container); + + // can't use clear_and_fill_buckets_from_values() because container elements might not be unique + auto value_idx = value_idx_type {}; + + // loop until we reach the end of the container. duplicated entries will be replaced with back(). + while (value_idx != static_cast(m_values.size())) { + auto const& key = get_key(m_values[value_idx]); + + auto hash = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(hash); + auto bucket_idx = bucket_idx_from_hash(hash); + + bool key_found = false; + while (true) { + auto const& bucket = at(m_buckets, bucket_idx); + if (dist_and_fingerprint > bucket.m_dist_and_fingerprint) { + break; + } + if (dist_and_fingerprint == bucket.m_dist_and_fingerprint + && m_equal(key, get_key(m_values[bucket.m_value_idx]))) + { + key_found = true; + break; + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + + if (key_found) { + if (value_idx != static_cast(m_values.size() - 1)) { + m_values[value_idx] = std::move(m_values.back()); + } + m_values.pop_back(); + } else { + place_and_shift_up({ dist_and_fingerprint, value_idx }, bucket_idx); + ++value_idx; + } + } + } + + template, bool> = true> + auto insert_or_assign(Key const& key, M&& mapped) -> std::pair { + return do_insert_or_assign(key, std::forward(mapped)); + } + + template, bool> = true> + auto insert_or_assign(Key&& key, M&& mapped) -> std::pair { + return do_insert_or_assign(std::move(key), std::forward(mapped)); + } + + template< + typename K, + typename M, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto insert_or_assign(K&& key, M&& mapped) -> std::pair { + return do_insert_or_assign(std::forward(key), std::forward(mapped)); + } + + template, bool> = true> + auto insert_or_assign(const_iterator /*hint*/, Key const& key, M&& mapped) -> iterator { + return do_insert_or_assign(key, std::forward(mapped)).first; + } + + template, bool> = true> + auto insert_or_assign(const_iterator /*hint*/, Key&& key, M&& mapped) -> iterator { + return do_insert_or_assign(std::move(key), std::forward(mapped)).first; + } + + template< + typename K, + typename M, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto insert_or_assign(const_iterator /*hint*/, K&& key, M&& mapped) -> iterator { + return do_insert_or_assign(std::forward(key), std::forward(mapped)).first; + } + + // Single arguments for unordered_set can be used without having to construct the value_type + template< + class K, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto emplace(K&& key) -> std::pair { + auto hash = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(hash); + auto bucket_idx = bucket_idx_from_hash(hash); + + while (dist_and_fingerprint <= at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + if (dist_and_fingerprint == at(m_buckets, bucket_idx).m_dist_and_fingerprint + && m_equal(key, m_values[at(m_buckets, bucket_idx).m_value_idx])) + { + // found it, return without ever actually creating anything + return { + begin( + ) + static_cast(at(m_buckets, bucket_idx).m_value_idx), + false + }; + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + + // value is new, insert element first, so when exception happens we are in a valid state + return do_place_element(dist_and_fingerprint, bucket_idx, std::forward(key)); + } + + template + auto emplace(Args&&... args) -> std::pair { + // we have to instantiate the value_type to be able to access the key. + // 1. emplace_back the object so it is constructed. 2. If the key is already there, pop it later in the loop. + auto& key = get_key(m_values.emplace_back(std::forward(args)...)); + auto hash = mixed_hash(key); + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(hash); + auto bucket_idx = bucket_idx_from_hash(hash); + + while (dist_and_fingerprint <= at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + if (dist_and_fingerprint == at(m_buckets, bucket_idx).m_dist_and_fingerprint + && m_equal(key, get_key(m_values[at(m_buckets, bucket_idx).m_value_idx]))) + { + m_values.pop_back(); // value was already there, so get rid of it + return { + begin( + ) + static_cast(at(m_buckets, bucket_idx).m_value_idx), + false + }; + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + + // value is new, place the bucket and shift up until we find an empty spot + auto value_idx = static_cast(m_values.size() - 1); + if (ANKERL_UNORDERED_DENSE_UNLIKELY(is_full())) + ANKERL_UNORDERED_DENSE_UNLIKELY_ATTR { + // increase_size just rehashes all the data we have in m_values + increase_size(); + } + else { + // place element and shift up until we find an empty spot + place_and_shift_up({ dist_and_fingerprint, value_idx }, bucket_idx); + } + return { begin() + static_cast(value_idx), true }; + } + + template + auto emplace_hint(const_iterator /*hint*/, Args&&... args) -> iterator { + return emplace(std::forward(args)...).first; + } + + template, bool> = true> + auto try_emplace(Key const& key, Args&&... args) -> std::pair { + return do_try_emplace(key, std::forward(args)...); + } + + template, bool> = true> + auto try_emplace(Key&& key, Args&&... args) -> std::pair { + return do_try_emplace(std::move(key), std::forward(args)...); + } + + template, bool> = true> + auto try_emplace(const_iterator /*hint*/, Key const& key, Args&&... args) -> iterator { + return do_try_emplace(key, std::forward(args)...).first; + } + + template, bool> = true> + auto try_emplace(const_iterator /*hint*/, Key&& key, Args&&... args) -> iterator { + return do_try_emplace(std::move(key), std::forward(args)...).first; + } + + template< + typename K, + typename... Args, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t< + is_map_v< + Q> && is_transparent_v && is_neither_convertible_v, + bool> = true> + auto try_emplace(K&& key, Args&&... args) -> std::pair { + return do_try_emplace(std::forward(key), std::forward(args)...); + } + + template< + typename K, + typename... Args, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t< + is_map_v< + Q> && is_transparent_v && is_neither_convertible_v, + bool> = true> + auto try_emplace(const_iterator /*hint*/, K&& key, Args&&... args) -> iterator { + return do_try_emplace(std::forward(key), std::forward(args)...).first; + } + + // Replaces the key at the given iterator with new_key. This does not change any other data in the underlying table, so + // all iterators and references remain valid. However, this operation can fail if new_key already exists in the table. + // In that case, returns {iterator to the already existing new_key, false} and no change is made. + // + // In the case of a set, this effectively removes the old key and inserts the new key at the same spot, which is more + // efficient than removing the old key and inserting the new key because it avoids repositioning the last element. + template + auto replace_key(iterator it, K&& new_key) -> std::pair { + auto const new_key_hash = mixed_hash(new_key); + + // first, check if new_key already exists and return if so + auto dist_and_fingerprint = dist_and_fingerprint_from_hash(new_key_hash); + auto bucket_idx = bucket_idx_from_hash(new_key_hash); + while (dist_and_fingerprint <= at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + auto const& bucket = at(m_buckets, bucket_idx); + if (dist_and_fingerprint == bucket.m_dist_and_fingerprint + && m_equal(new_key, get_key(m_values[bucket.m_value_idx]))) + { + return { begin() + static_cast(bucket.m_value_idx), + false }; + } + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + + // const_cast is needed because iterator for the set is always const, so adding another get_key overload is not + // feasible. + auto& target_key = const_cast(get_key(*it)); + auto const old_key_bucket_idx = bucket_idx_from_hash(mixed_hash(target_key)); + + // Replace the key before doing any bucket changes. If it throws, no harm done, we are still in a valid state as we + // have not modified any buckets yet. + target_key = std::forward(new_key); + + auto const value_idx = static_cast(it - begin()); + + // Find the bucket containing our value_idx. It's guaranteed we find it, so no other stopping condition needed. + bucket_idx = old_key_bucket_idx; + while (value_idx != at(m_buckets, bucket_idx).m_value_idx) { + bucket_idx = next(bucket_idx); + } + erase_and_shift_down(bucket_idx); + + // place the new bucket + dist_and_fingerprint = dist_and_fingerprint_from_hash(new_key_hash); + bucket_idx = bucket_idx_from_hash(new_key_hash); + while (dist_and_fingerprint < at(m_buckets, bucket_idx).m_dist_and_fingerprint) { + dist_and_fingerprint = dist_inc(dist_and_fingerprint); + bucket_idx = next(bucket_idx); + } + place_and_shift_up({ dist_and_fingerprint, value_idx }, bucket_idx); + + return { it, true }; + } + + auto erase(iterator it) -> iterator { + auto hash = mixed_hash(get_key(*it)); + auto bucket_idx = bucket_idx_from_hash(hash); + + auto const value_idx_to_remove = static_cast(it - cbegin()); + while (at(m_buckets, bucket_idx).m_value_idx != value_idx_to_remove) { + bucket_idx = next(bucket_idx); + } + + do_erase(bucket_idx, [](value_type const& /*unused*/) -> void {}); + return begin() + static_cast(value_idx_to_remove); + } + + auto extract(iterator it) -> value_type { + auto hash = mixed_hash(get_key(*it)); + auto bucket_idx = bucket_idx_from_hash(hash); + + auto const value_idx_to_remove = static_cast(it - cbegin()); + while (at(m_buckets, bucket_idx).m_value_idx != value_idx_to_remove) { + bucket_idx = next(bucket_idx); + } + + auto tmp = std::optional {}; + do_erase(bucket_idx, [&tmp](value_type&& val) -> void { tmp = std::move(val); }); + return std::move(tmp).value(); + } + + template, bool> = true> + auto erase(const_iterator it) -> iterator { + return erase(begin() + (it - cbegin())); + } + + template, bool> = true> + auto extract(const_iterator it) -> value_type { + return extract(begin() + (it - cbegin())); + } + + auto erase(const_iterator first, const_iterator last) -> iterator { + auto const idx_first = first - cbegin(); + auto const idx_last = last - cbegin(); + auto const first_to_last = std::distance(first, last); + auto const last_to_end = std::distance(last, cend()); + + // remove elements from left to right which moves elements from the end back + auto const mid = idx_first + (std::min)(first_to_last, last_to_end); + auto idx = idx_first; + while (idx != mid) { + erase(begin() + idx); + ++idx; + } + + // all elements from the right are moved, now remove the last element until all done + idx = idx_last; + while (idx != mid) { + --idx; + erase(begin() + idx); + } + + return begin() + idx_first; + } + + auto erase(Key const& key) -> std::size_t { + return do_erase_key(key, [](value_type const& /*unused*/) -> void {}); + } + + auto extract(Key const& key) -> std::optional { + auto tmp = std::optional {}; + do_erase_key(key, [&tmp](value_type&& val) -> void { tmp = std::move(val); }); + return tmp; + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto erase(K&& key) -> std::size_t { + return do_erase_key(std::forward(key), [](value_type const& /*unused*/) -> void { + }); + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto extract(K&& key) -> std::optional { + auto tmp = std::optional {}; + do_erase_key(std::forward(key), [&tmp](value_type&& val) -> void { + tmp = std::move(val); + }); + return tmp; + } + + void swap(table& other + ) noexcept(noexcept(std::is_nothrow_swappable_v&& + std::is_nothrow_swappable_v&& + std::is_nothrow_swappable_v)) { + using std::swap; + swap(other, *this); + } + + // lookup ///////////////////////////////////////////////////////////////// + + template, bool> = true> + auto at(key_type const& key) -> Q& { + return do_at(key); + } + + template< + typename K, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto at(K const& key) -> Q& { + return do_at(key); + } + + template, bool> = true> + auto at(key_type const& key) const -> Q const& { + return do_at(key); + } + + template< + typename K, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto at(K const& key) const -> Q const& { + return do_at(key); + } + + template, bool> = true> + auto operator[](Key const& key) -> Q& { + return try_emplace(key).first->second; + } + + template, bool> = true> + auto operator[](Key&& key) -> Q& { + return try_emplace(std::move(key)).first->second; + } + + template< + typename K, + typename Q = T, + typename H = Hash, + typename KE = KeyEqual, + std::enable_if_t && is_transparent_v, bool> = true> + auto operator[](K&& key) -> Q& { + return try_emplace(std::forward(key)).first->second; + } + + auto count(Key const& key) const -> std::size_t { + return find(key) == end() ? 0 : 1; + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto count(K const& key) const -> std::size_t { + return find(key) == end() ? 0 : 1; + } + + auto find(Key const& key) -> iterator { + return do_find(key); + } + + auto find(Key const& key) const -> const_iterator { + return do_find(key); + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto find(K const& key) -> iterator { + return do_find(key); + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto find(K const& key) const -> const_iterator { + return do_find(key); + } + + auto contains(Key const& key) const -> bool { + return find(key) != end(); + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto contains(K const& key) const -> bool { + return find(key) != end(); + } + + auto equal_range(Key const& key) -> std::pair { + auto it = do_find(key); + return { it, it == end() ? end() : it + 1 }; + } + + auto equal_range(const Key& key) const -> std::pair { + auto it = do_find(key); + return { it, it == end() ? end() : it + 1 }; + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto equal_range(K const& key) -> std::pair { + auto it = do_find(key); + return { it, it == end() ? end() : it + 1 }; + } + + template< + class K, + class H = Hash, + class KE = KeyEqual, + std::enable_if_t, bool> = true> + auto equal_range(K const& key) const -> std::pair { + auto it = do_find(key); + return { it, it == end() ? end() : it + 1 }; + } + + // bucket interface /////////////////////////////////////////////////////// + + auto bucket_count() const noexcept -> std::size_t { // NOLINT(modernize-use-nodiscard) + return m_buckets.size(); + } + + static constexpr auto max_bucket_count() noexcept + -> std::size_t { // NOLINT(modernize-use-nodiscard) + return max_size(); + } + + // hash policy //////////////////////////////////////////////////////////// + + [[nodiscard]] auto load_factor() const -> float { + return bucket_count() + ? static_cast(size()) / static_cast(bucket_count()) + : 0.0F; + } + + [[nodiscard]] auto max_load_factor() const -> float { + return m_max_load_factor; + } + + void max_load_factor(float ml) { + m_max_load_factor = ml; + if (bucket_count() != max_bucket_count()) { + m_max_bucket_capacity = static_cast( + static_cast(bucket_count()) * max_load_factor() + ); + } + } + + void rehash(std::size_t count) { + count = (std::min)(count, max_size()); + auto shifts = calc_shifts_for_size((std::max)(count, size())); + if (shifts != m_shifts) { + m_shifts = shifts; + deallocate_buckets(); + m_values.shrink_to_fit(); + allocate_buckets_from_shift(); + clear_and_fill_buckets_from_values(); + } + } + + void reserve(std::size_t capa) { + capa = (std::min)(capa, max_size()); + if constexpr (has_reserve) { + // std::deque doesn't have reserve(). Make sure we only call when available + m_values.reserve(capa); + } + auto shifts = calc_shifts_for_size((std::max)(capa, size())); + if (0 == bucket_count() || shifts < m_shifts) { + m_shifts = shifts; + deallocate_buckets(); + allocate_buckets_from_shift(); + clear_and_fill_buckets_from_values(); + } + } + + // observers ////////////////////////////////////////////////////////////// + + auto hash_function() const -> hasher { + return m_hash; + } + + auto key_eq() const -> key_equal { + return m_equal; + } + + // nonstandard API: expose the underlying values container + [[nodiscard]] auto values() const noexcept -> value_container_type const& { + return m_values; + } + + // non-member functions /////////////////////////////////////////////////// + + friend auto operator==(table const& a, table const& b) -> bool { + if (&a == &b) { + return true; + } + if (a.size() != b.size()) { + return false; + } + for (auto const& b_entry: b) { + auto it = a.find(get_key(b_entry)); + if constexpr (is_map_v) { + // map: check that key is here, then also check that value is the same + if (a.end() == it || !(b_entry.second == it->second)) { + return false; + } + } else { + // set: only check that the key is here + if (a.end() == it) { + return false; + } + } + } + return true; + } + + friend auto operator!=(table const& a, table const& b) -> bool { + return !(a == b); + } + }; + + } // namespace detail + + template< + class Key, + class T, + class Hash = hash, + class KeyEqual = std::equal_to, + class AllocatorOrContainer = std::allocator>, + class Bucket = bucket_type::standard, + class BucketContainer = detail::default_container_t> + using map = + detail::table; + + template< + class Key, + class T, + class Hash = hash, + class KeyEqual = std::equal_to, + class AllocatorOrContainer = std::allocator>, + class Bucket = bucket_type::standard, + class BucketContainer = detail::default_container_t> + using segmented_map = + detail::table; + + template< + class Key, + class Hash = hash, + class KeyEqual = std::equal_to, + class AllocatorOrContainer = std::allocator, + class Bucket = bucket_type::standard, + class BucketContainer = detail::default_container_t> + using set = detail:: + table; + + template< + class Key, + class Hash = hash, + class KeyEqual = std::equal_to, + class AllocatorOrContainer = std::allocator, + class Bucket = bucket_type::standard, + class BucketContainer = detail::default_container_t> + using segmented_set = detail:: + table; + + #if defined(ANKERL_UNORDERED_DENSE_PMR) + + namespace pmr { + + template< + class Key, + class T, + class Hash = hash, + class KeyEqual = std::equal_to, + class Bucket = bucket_type::standard> + using map = detail::table< + Key, + T, + Hash, + KeyEqual, + ANKERL_UNORDERED_DENSE_PMR::polymorphic_allocator>, + Bucket, + detail::default_container_t, + false>; + + template< + class Key, + class T, + class Hash = hash, + class KeyEqual = std::equal_to, + class Bucket = bucket_type::standard> + using segmented_map = detail::table< + Key, + T, + Hash, + KeyEqual, + ANKERL_UNORDERED_DENSE_PMR::polymorphic_allocator>, + Bucket, + detail::default_container_t, + true>; + + template< + class Key, + class Hash = hash, + class KeyEqual = std::equal_to, + class Bucket = bucket_type::standard> + using set = detail::table< + Key, + void, + Hash, + KeyEqual, + ANKERL_UNORDERED_DENSE_PMR::polymorphic_allocator, + Bucket, + detail::default_container_t, + false>; + + template< + class Key, + class Hash = hash, + class KeyEqual = std::equal_to, + class Bucket = bucket_type::standard> + using segmented_set = detail::table< + Key, + void, + Hash, + KeyEqual, + ANKERL_UNORDERED_DENSE_PMR::polymorphic_allocator, + Bucket, + detail::default_container_t, + true>; + + } // namespace pmr + + #endif + + // deduction guides /////////////////////////////////////////////////////////// + + // deduction guides for alias templates are only possible since C++20 + // see https://en.cppreference.com/w/cpp/language/class_template_argument_deduction + +} // namespace ANKERL_UNORDERED_DENSE_NAMESPACE +} // namespace ankerl::unordered_dense + +// std extensions ///////////////////////////////////////////////////////////// + +namespace std { // NOLINT(cert-dcl58-cpp) + +template< + class Key, + class T, + class Hash, + class KeyEqual, + class AllocatorOrContainer, + class Bucket, + class Pred, + class BucketContainer, + bool IsSegmented> +// NOLINTNEXTLINE(cert-dcl58-cpp) +auto erase_if( + ankerl::unordered_dense::detail:: + table& + map, + Pred pred +) -> std::size_t { + using map_t = ankerl::unordered_dense::detail:: + table; + + // going back to front because erase() invalidates the end iterator + auto const old_size = map.size(); + auto idx = old_size; + while (idx) { + --idx; + auto it = map.begin() + static_cast(idx); + if (pred(*it)) { + map.erase(it); + } + } + + return old_size - map.size(); +} + +} // namespace std + +#endif +#endif \ No newline at end of file diff --git a/wust_vision-main/CMakeLists.txt b/wust_vision-main/CMakeLists.txt new file mode 100644 index 0000000..daf003d --- /dev/null +++ b/wust_vision-main/CMakeLists.txt @@ -0,0 +1,160 @@ +cmake_minimum_required(VERSION 3.14) +cmake_policy(SET CMP0072 NEW) +project(wust_vision LANGUAGES CXX) + + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/bin) +set(CMAKE_BUILD_TYPE "Release") +message(STATUS "--------------------CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}--------------------") + + +option(BUILD_WITH_TRT "Enable TensorRT backend" ON) +option(BUILD_WITH_OPENVINO "Enable OpenVINO backend" ON) +option(BUILD_WITH_NCNN "Enable NCNN backend" OFF) +option(BUILD_WITH_ORT "Enable ORT backend" ON) + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}/cmake) + + +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS tasks/*.cpp) + + +# list(FILTER SOURCES EXCLUDE REGEX "tasks/auto_guidance/.*") +# list(FILTER SOURCES EXCLUDE REGEX "tasks/auto_sniper/.*") + +find_package(OpenCV REQUIRED) +find_package(Eigen3 REQUIRED) +find_package(fmt REQUIRED) +find_package(yaml-cpp REQUIRED) +find_package(Ceres REQUIRED) +find_package(HikSDK REQUIRED) +find_package(wust_vl REQUIRED) + +set(BUILD_DEFINITIONS "") +set(BUILD_LIBS "") + + +if(BUILD_WITH_OPENVINO) + find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) + list(APPEND BUILD_DEFINITIONS USE_OPENVINO) + list(APPEND BUILD_LIBS openvino::runtime openvino::frontend::onnx) +endif() + +if(BUILD_WITH_TRT) + find_package(CUDAToolkit REQUIRED) + set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};/home/hy/TensorRT-10.6.0.26") + find_package(TensorRT REQUIRED) + include_directories(${TensorRT_INCLUDE_DIR}) + include_directories(/usr/local/cuda/include) + set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) + set(CMAKE_CUDA_ARCHITECTURES 86) + enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 14) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + set(CMAKE_CUDA_COMPILER_LAUNCHER ccache) + add_subdirectory(${CMAKE_SOURCE_DIR}/cuda_infer ${CMAKE_BINARY_DIR}/cuda_infer_build) + list(APPEND BUILD_DEFINITIONS USE_TRT) + list(APPEND BUILD_LIBS TensorRT::TensorRT TensorRT::nvonnxparser CUDA::cudart cuda_infer) +endif() + +if(BUILD_WITH_NCNN) + find_package(ncnn REQUIRED) + list(APPEND BUILD_DEFINITIONS USE_NCNN) + list(APPEND BUILD_LIBS ncnn) +endif() + +if(BUILD_WITH_ORT) + set(ort_root_path "/home/hy/onnxruntime-linux-x64-gpu-1.22.0") + find_package(Ort REQUIRED) + include_directories(${Ort_INCLUDE_DIR}) + list(APPEND BUILD_DEFINITIONS USE_ORT) + list(APPEND BUILD_LIBS ${Ort_LIB}) +endif() + + +add_library(${PROJECT_NAME} SHARED ${SOURCES}) +target_compile_definitions(${PROJECT_NAME} PUBLIC ${BUILD_DEFINITIONS}) +target_include_directories(${PROJECT_NAME} PUBLIC + ${PROJECT_SOURCE_DIR} + ${OpenCV_INCLUDE_DIRS} +) +target_link_libraries(${PROJECT_NAME} + wust_vl::wust_vl + HikSDK::HikSDK + yaml-cpp + fmt::fmt + Eigen3::Eigen + Ceres::ceres + ${OpenCV_LIBS} + ${BUILD_LIBS} +) +set(ROS2_PACKAGES + ament_cmake + rosidl_typesupport_cpp + rclcpp + geometry_msgs + tf2_ros + tf2_geometry_msgs + sensor_msgs + visualization_msgs + sentry_interfaces + nav_msgs +) + + + +function(add_common_executable exe_name src_file) + add_executable(${exe_name} ${src_file}) + target_link_libraries(${exe_name} ${PROJECT_NAME}) +endfunction() +add_common_executable(standard src/standard.cpp) +# add_common_executable(dart src/dart.cpp) +add_common_executable(test_usbcamera test/test_usbcamera.cpp) + + +foreach(pkg IN LISTS ROS2_PACKAGES) + find_package(${pkg} QUIET) +endforeach() + + +set(ROS2_FULL_FOUND TRUE) +foreach(pkg IN LISTS ROS2_PACKAGES) + if(NOT ${pkg}_FOUND) + set(ROS2_FULL_FOUND FALSE) + message(WARNING "ROS2 package ${pkg} not found") + endif() +endforeach() + +if(ROS2_FULL_FOUND) + message(STATUS "ROS2 full environment found, compiling ROS2 targets ...") + list(APPEND BUILD_DEFINITIONS USE_ROS2) + + set(ROS2_INCLUDES) + foreach(pkg IN LISTS ROS2_PACKAGES) + if(TARGET ${pkg}::${pkg}) + list(APPEND ROS2_INCLUDES ${${pkg}_INCLUDE_DIRS}) + endif() + endforeach() + include_directories(${ROS2_INCLUDES}) + # find_package(Open3D REQUIRED) + + # target_link_libraries(${PROJECT_NAME} + # Open3D::Open3D + # ) + ament_target_dependencies(${PROJECT_NAME} ${ROS2_PACKAGES}) + target_compile_definitions(${PROJECT_NAME} PUBLIC ${BUILD_DEFINITIONS}) + macro(add_ros2_executable exe_name src_file) + add_executable(${exe_name} ${src_file}) + target_link_libraries(${exe_name} ${PROJECT_NAME} ) + endmacro() + + + add_ros2_executable(sentry src/sentry.cpp) + add_ros2_executable(nav test/nav.cpp) +else() + message(WARNING "ROS2 dependencies incomplete, skipping ROS2 targets ...") +endif() + diff --git a/wust_vision-main/README.md b/wust_vision-main/README.md new file mode 100644 index 0000000..b57077e --- /dev/null +++ b/wust_vision-main/README.md @@ -0,0 +1,252 @@ +# WUST_VISION +武汉科技大学崇实战队视觉代码仓库 + +## 写在前面 +本项目基于[中南大学FYT战队2024赛季视觉框架开源](https://github.com/CSU-FYT-Vision/FYT2024_vision),华南师范大学PIONEER战队@chenjunnn[rm_vision](https://github.com/chenjunnn/rm_vision)修改与适配,参考了深圳北理莫斯科大学北极熊战队/四川大学火锅战队/沈阳航空航天大学TUP战队/北京科技大学Reborn战队/同济大学superpower战队/河北科技大学Actor&Thinker战队的部分代码与模型,感谢以上开源为本队以及本人的帮助(排名不分先后) + +## 依赖 +* [wust_vl](https://github.com/WUST-RM/wust_vl) +* OpenCV +* [OpenVINO](https://flowus.cn/7a2a3341-74a1-4db9-bced-99fe5d05ab75)/[TensorRT-cuda](https://flowus.cn/e98af178-de0b-4546-808d-a6f1ff199d62)/[NCNN](https://flowus.cn/664f6bee-8ea9-4d54-8a78-e2c0bf38ee9f)/[OnnxRunetime](https://flowus.cn/8fbecbbf-c0f9-49bb-bac5-7b4923f55c99)连接为简单部署文档 +* fmt +* ceres +* Eigen3 +* nlohmann +* yaml-cpp +## 环境配置 +``` +./script/install_depences.sh +``` +## Quick Start +``` +git clone --recurse-submodules https://github.com/WUST-RM/wust_vision.git +cd wust_vision +sudo ./run.sh run xx /rebuild/build #编译并运行xx可执行文件/删除build缓存重新编译/仅编译 +``` +### 注意:本项目可选择编译OpenVINO/TensorRT-cuda/NCNN/OnnxRunetime,需在build缓存前在[CMakeLists.txt](CMakeLists.txt)中修改对应编译选项,修改后需rebuild重新编译,无OpenVINO/TensorRT-cuda/NCNN/OnnxRunetime环境仍可以使用OpenCV的装甲板识别,装甲板的识别方案需要在[config/auto_aim.yaml](config/auto_aim.yaml)中修改 +## 文件树 +``` +. +├── 3rdparty +│   ├── angles.h +│   └── backward-cpp +│ +├── cmake +│   ├── FindG2O.cmake +│   ├── FindHikSDK.cmake +│   ├── FindOrt.cmake +│   └── FindTensorRT.cmake +├── CMakeLists.txt +├── config +│   ├── auto_aim.yaml +│   ├── auto_buff.yaml +│   ├── auto_guidance.yaml +│   ├── auto_sniper.yaml +│   ├── camera_info.yaml +│   ├── camera.yaml +│   ├── common.yaml +│   ├── config -> /home/hy/wust_vision/config +│   ├── detect_ncnn.yaml +│   ├── detect_opencv.yaml +│   ├── detect_openvino.yaml +│   ├── detect_ort.yaml +│   ├── detect_trt.yaml +│   └── guard.sh +├── cuda_infer +│   ├── armor_infer.cu +│   ├── armor_infer.hpp +│   ├── CMakeLists.txt +│   ├── letter_box.cu +│   └── letter_box.hpp +├── env.bash +├── format.sh +├── KalmanHyLib +│   ├── adaptive_extended_kalman_filter.hpp +│   ├── error_state_extended_kalman_filter.hpp +│   ├── extended_kalman_filter.hpp +│   ├── kalman_hybird_lib.hpp +│   ├── README.md +│   └── unscented_kalman_filter.hpp +├── model +│ +├── README.md +├── read_shm_image_mmap_only.py +├── ros2 +│   ├── CMakeLists.txt +│   ├── ros2.cpp +│   └── ros2.hpp +├── run.sh +├── script +│   ├── install_depences.sh +│   ├── rsync.sh +│   ├── setup_devenv.sh +│   └── setup_service.sh +├── src +│   ├── dart.cpp +│   ├── hero.cpp +│   ├── sentry.cpp +│   ├── sim.cpp +│   └── standard.cpp +├── static +│   ├── 崇实战队logo图标.png +│   ├── css +│   │   └── style.css +│   ├── js +│   │   ├── chart_logic.js +│   │   ├── json_view.js +│   │   └── main.js +│   └── logo.JPG +├── tasks +│   ├── auto_aim +│   │   ├── armor_control +│   │   │   ├── aimer.cpp +│   │   │   ├── aimer.hpp +│   │   │   ├── planner.cpp +│   │   │   ├── planner.hpp +│   │   │   ├── shooter.cpp +│   │   │   ├── shooter.hpp +│   │   │   └── tinympc +│ │ │ +│   │   ├── armor_detect +│   │   │   ├── armor_detect_common.cpp +│   │   │   ├── armor_detect_common.hpp +│   │   │   ├── armor_detector_base.hpp +│   │   │   ├── armor_infer.cpp +│   │   │   ├── armor_infer.hpp +│   │   │   ├── armor_pose_estimator.cpp +│   │   │   ├── armor_pose_estimator.hpp +│   │   │   ├── detector_factory.hpp +│   │   │   ├── light_corner_corrector.cpp +│   │   │   ├── light_corner_corrector.hpp +│   │   │   ├── ncnn +│   │   │   │   ├── armor_detector_ncnn.cpp +│   │   │   │   ├── armor_detector_ncnn.hpp +│   │   │   │   ├── armor_detector_ncnn_wrapper.cpp +│   │   │   │   └── armor_detector_ncnn_wrapper.hpp +│   │   │   ├── number_classifier.cpp +│   │   │   ├── number_classifier.hpp +│   │   │   ├── onnxruntime +│   │   │   │   ├── armor_detector_onnxruntime.cpp +│   │   │   │   ├── armor_detector_onnxruntime.hpp +│   │   │   │   ├── armor_detector_onnxruntime_wrapper.cpp +│   │   │   │   └── armor_detector_onnxruntime_wrapper.hpp +│   │   │   ├── opencv +│   │   │   │   ├── armor_detector_opencv.cpp +│   │   │   │   ├── armor_detector_opencv.hpp +│   │   │   │   ├── armor_detector_opencv_wrapper.cpp +│   │   │   │   └── armor_detector_opencv_wrapper.hpp +│   │   │   ├── openvino +│   │   │   │   ├── armor_detector_openvino.cpp +│   │   │   │   ├── armor_detector_openvino.hpp +│   │   │   │   ├── armor_detector_openvino_wrapper.cpp +│   │   │   │   └── armor_detector_openvino_wrapper.hpp +│   │   │   └── tensorrt +│   │   │   ├── armor_detector_tensorrt.cpp +│   │   │   ├── armor_detector_tensorrt.hpp +│   │   │   ├── armor_detector_tensorrt_wrapper.cpp +│   │   │   └── armor_detector_tensorrt_wrapper.hpp +│   │   ├── armor_optimize +│   │   │   ├── ba_solver.cpp +│   │   │   └── ba_solver.hpp +│   │   ├── armor_tracker +│   │   │   ├── motion_models +│   │   │   │   ├── acc_model.hpp +│   │   │   │   ├── motion_modela.hpp +│   │   │   │   ├── motion_modelonea.hpp +│   │   │   │   ├── motion_modeloneca.hpp +│   │   │   │   ├── motion_modeloneypd.hpp +│   │   │   │   ├── motion_modelr.hpp +│   │   │   │   ├── motion_modelrypd.hpp +│   │   │   │   ├── motion_modelypd.hpp +│   │   │   │   └── motion_modelypdv2.hpp +│   │   │   ├── target.cpp +│   │   │   ├── target.hpp +│   │   │   ├── tracker_manager.cpp +│   │   │   ├── tracker_manager.hpp +│   │   │   ├── trackerv3.cpp +│   │   │   └── trackerv3.hpp +│   │   ├── auto_aim.cpp +│   │   ├── auto_aim_fsm.hpp +│   │   ├── auto_aim.hpp +│   │   ├── CMakeLists.txt +│   │   ├── type.cpp +│   │   └── type.hpp +│   ├── auto_buff +│   │   ├── auto_buff.cpp +│   │   ├── auto_buff.hpp +│   │   ├── CMakeLists.txt +│   │   ├── rune_control +│   │   │   ├── aimer.cpp +│   │   │   └── aimer.hpp +│   │   ├── rune_detector +│   │   │   ├── rune_detector.cpp +│   │   │   └── rune_detector.hpp +│   │   ├── rune_optimize +│   │   │   ├── ba_solver.cpp +│   │   │   └── ba_solver.hpp +│   │   ├── rune_tracker +│   │   │   ├── motion_models +│   │   │   │   └── motion_modelrypd.hpp +│   │   │   ├── rune_target.cpp +│   │   │   ├── rune_target.hpp +│   │   │   ├── rune_tracker.cpp +│   │   │   ├── rune_tracker.hpp +│   │   │   └── spd_fitter.hpp +│   │   ├── type.cpp +│   │   └── type.hpp +│   ├── auto_guidance +│   │   ├── auto_guidance.cpp +│   │   ├── auto_guidance.hpp +│   │   ├── CMakeLists.txt +│   │   ├── debug.cpp +│   │   ├── debug.hpp +│   │   ├── guidance_detector +│   │   │   ├── detector_base.hpp +│   │   │   ├── detector_factory.hpp +│   │   │   ├── green_light_infer.cpp +│   │   │   ├── green_light_infer.hpp +│   │   │   ├── opencv +│   │   │   │   ├── guidance_detector_opencv.cpp +│   │   │   │   └── guidance_detector_opencv.hpp +│   │   │   └── openvino +│   │   │   ├── guidance_detector_openvino.cpp +│   │   │   └── guidance_detector_openvino.hpp +│   │   ├── guidance_tracker +│   │   │   ├── guidance_target.cpp +│   │   │   ├── guidance_target.hpp +│   │   │   ├── guidance_tracker.cpp +│   │   │   ├── guidance_tracker.hpp +│   │   │   └── motion_models +│   │   │   └── imgbox_model.hpp +│   │   └── type.hpp +│   ├── auto_offset +│   │   ├── auto_offset.cpp +│   │   ├── auto_offset.hpp +│   │   └── CMakeLists.txt +│   ├── auto_sniper +│   │   ├── auto_sniper.cpp +│   │   ├── auto_sniper.hpp +│   │   ├── CMakeLists.txt +│   │   └── trajectory_solver.hpp +│   ├── CMakeLists.txt +│   ├── debug.cpp +│   ├── debug.hpp +│   ├── main_base.hpp +│   ├── packet_typedef.hpp +│   ├── sinple_img_rotate_saver.hpp +│   ├── type_common.cpp +│   ├── type_common.hpp +│   ├── utils.cpp +│   ├── utils.hpp +│   ├── vision_base.cpp +│   └── vision_base.hpp +├── templates +│   └── index.html +├── test +│   ├── control.cpp +│   └── test_usbcamera.cpp +└── web.py + + + +``` diff --git a/wust_vision-main/cmake/FindHikSDK.cmake b/wust_vision-main/cmake/FindHikSDK.cmake new file mode 100644 index 0000000..e07ba0d --- /dev/null +++ b/wust_vision-main/cmake/FindHikSDK.cmake @@ -0,0 +1,147 @@ +# -------------------------------------------------------------------------------------------- +# FindHikSDK.cmake +# +# This module finds the HikRobot / Hikvision MVS Camera SDK. +# +# It defines the following variables: +# HikSDK_FOUND +# HikSDK_INCLUDE_DIR +# HikSDK_LIB +# +# And the following imported target: +# hiksdk +# -------------------------------------------------------------------------------------------- + +# ========================= +# 1. SDK 根路径 +# ========================= +if(WIN32) + set(HikSDK_Path "$ENV{MVCAM_COMMON_RUNENV}") +else() + set(HikSDK_Path "$ENV{MVCAM_SDK_PATH}") +endif() + +if(NOT HikSDK_Path OR HikSDK_Path STREQUAL "") + message(STATUS "HikSDK: MVCAM_SDK_PATH is not set") + set(HikSDK_FOUND FALSE) + return() +endif() + +# ========================= +# 2. 查找头文件 +# ========================= +find_path( + HikSDK_INCLUDE_DIR + NAMES + MvCameraControl.h + CameraParams.h + PixelType.h + MvErrorDefine.h + MvISPErrorDefine.h + PATHS + "${HikSDK_Path}/include" + "${HikSDK_Path}/Includes" + NO_DEFAULT_PATH +) + +# ========================= +# 3. 查找库文件(关键修复点) +# ========================= +if(UNIX) + find_library( + HikSDK_LIB + NAMES + MvCameraControl + libMvCameraControl.so + PATHS + "${HikSDK_Path}/lib" + "${HikSDK_Path}/lib64" + "${HikSDK_Path}/lib/arm" + "${HikSDK_Path}/lib/arm64" + "${HikSDK_Path}/lib/aarch64" + "${HikSDK_Path}/lib/x86" + "${HikSDK_Path}/lib/x64" + "${HikSDK_Path}/lib/64" + "${HikSDK_Path}/lib/32" + NO_DEFAULT_PATH + ) +endif() + +# ========================= +# 4. Windows(完整但不影响 Linux) +# ========================= +if(WIN32) + find_library( + HikSDK_LIB + NAMES MvCameraControl + PATHS + "${HikSDK_Path}/Libraries" + "${HikSDK_Path}/Libraries/win64" + "${HikSDK_Path}/Libraries/win32" + NO_DEFAULT_PATH + ) + + find_file( + HikSDK_DLL + NAMES MvCameraControl.dll + PATHS + "${HikSDK_Path}/Runtime" + "C:/Program Files (x86)/Common Files/MVS/Runtime" + NO_DEFAULT_PATH + ) +endif() + +# ========================= +# 5. 创建导入库目标 +# ========================= +if(HikSDK_LIB AND HikSDK_INCLUDE_DIR) + + if(NOT TARGET HikSDK::HikSDK) + + add_library(HikSDK::HikSDK SHARED IMPORTED GLOBAL) + + if(WIN32) + set_target_properties(HikSDK::HikSDK PROPERTIES + IMPORTED_IMPLIB "${HikSDK_LIB}" + IMPORTED_LOCATION "${HikSDK_DLL}" + INTERFACE_INCLUDE_DIRECTORIES "${HikSDK_INCLUDE_DIR}" + ) + else() + set_target_properties(HikSDK::HikSDK PROPERTIES + IMPORTED_LOCATION "${HikSDK_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${HikSDK_INCLUDE_DIR}" + ) + endif() + + endif() + +endif() +set(HikSDK_LIBS HikSDK::HikSDK) +set(HikSDK_INCLUDE_DIRS ${HikSDK_INCLUDE_DIR}) + + +# ========================= +# 6. 标准 find_package 处理 +# ========================= +include(FindPackageHandleStandardArgs) + +find_package_handle_standard_args( + HikSDK + REQUIRED_VARS HikSDK_LIB HikSDK_INCLUDE_DIR +) + +# ========================= +# 7. 调试输出(很重要) +# ========================= +if(HikSDK_FOUND) + message(STATUS "HikSDK found:") + message(STATUS " Include dir : ${HikSDK_INCLUDE_DIR}") + message(STATUS " Library : ${HikSDK_LIB}") +else() + message(STATUS "HikSDK NOT found") +endif() + +set(HikSDK_LIBS hiksdk) +set(HikSDK_INCLUDE_DIRS ${HikSDK_INCLUDE_DIR}) + +mark_as_advanced(HikSDK_LIB HikSDK_INCLUDE_DIR HikSDK_DLL) diff --git a/wust_vision-main/cmake/FindOrt.cmake b/wust_vision-main/cmake/FindOrt.cmake new file mode 100644 index 0000000..ebe75bb --- /dev/null +++ b/wust_vision-main/cmake/FindOrt.cmake @@ -0,0 +1,88 @@ +# -------------------------------------------------------------------------------------------- +# This file is used to find the ONNX-Runtime SDK, which provides the following variables: +# +# Cache Variables: +# - Ort_HEADER_FILES: Names of SDK header files +# +# Advanced Variables: +# - Ort_INCLUDE_DIR: Directory where SDK header files are located +# - Ort_LIB: Path to the SDK library file (import library on Windows, shared library +# on Linux) +# - Ort_DLL: Path to the SDK dynamic library file (only on Windows) +# +# Local Variables: +# - Ort_LIBS: CMake target name for the SDK, which is "onnxruntime" +# - Ort_INCLUDE_DIRS: Directory where SDK header files are located +# -------------------------------------------------------------------------------------------- + +# ------------------------------------------------------------------------------ +# find onnxruntime root path +# ------------------------------------------------------------------------------ +if(NOT ort_root_path) + set(ort_root_path "/usr/local") +endif() + +# ------------------------------------------------------------------------------ +# find onnxruntime include directory +# ------------------------------------------------------------------------------ +set(Ort_HEADER_FILES + cpu_provider_factory.h onnxruntime_run_options_config_keys.h + onnxruntime_c_api.h onnxruntime_session_options_config_keys.h + onnxruntime_cxx_api.h provider_options.h + onnxruntime_cxx_inline.h + CACHE INTERNAL "ONNX Runtime header files" +) + +find_path( + Ort_INCLUDE_DIR + PATHS "${ort_root_path}/include" + NAMES ${Ort_HEADER_FILES} + NO_DEFAULT_PATH +) + +# ------------------------------------------------------------------------------ +# find onnxruntime library file +# ------------------------------------------------------------------------------ +find_library( + Ort_LIB + NAMES "libonnxruntime.so" + PATHS "${ort_root_path}/lib" + NO_DEFAULT_PATH +) + +# ------------------------------------------------------------------------------ +# create imported target: onnxruntime +# ------------------------------------------------------------------------------ +if(NOT TARGET onnxruntime) + add_library(onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION "${Ort_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${Ort_INCLUDE_DIR}" + ) +endif() + +mark_as_advanced(Ort_INCLUDE_DIR Ort_LIB) + +# ------------------------------------------------------------------------------ +# set onnxruntime cmake variables and version variables +# ------------------------------------------------------------------------------ +set(Ort_LIBS "onnxruntime") +set(Ort_INCLUDE_DIRS "${Ort_INCLUDE_DIR}") + +if(Ort_INCLUDE_DIR) + file(STRINGS "${Ort_INCLUDE_DIR}/onnxruntime_c_api.h" Ort_VERSION + REGEX "#define ORT_API_VERSION [0-9]+" + ) + string(REGEX REPLACE "#define ORT_API_VERSION ([0-9]+)" "1.\\1" Ort_VERSION "${Ort_VERSION}") +endif() + +# ------------------------------------------------------------------------------ +# handle the package +# ------------------------------------------------------------------------------ +include(FindPackageHandleStandardArgs) + +find_package_handle_standard_args( + Ort + VERSION_VAR Ort_VERSION + REQUIRED_VARS Ort_LIB Ort_INCLUDE_DIR +) \ No newline at end of file diff --git a/wust_vision-main/cmake/FindTensorRT.cmake b/wust_vision-main/cmake/FindTensorRT.cmake new file mode 100644 index 0000000..e57a6e8 --- /dev/null +++ b/wust_vision-main/cmake/FindTensorRT.cmake @@ -0,0 +1,66 @@ +# FindTensorRT.cmake -- Locate NVIDIA TensorRT + +include(FindPackageHandleStandardArgs) + +# TensorRT root +if (DEFINED TensorRT_ROOT) + list(APPEND _TensorRT_SEARCH_PATHS + ${TensorRT_ROOT} + "$ENV{TensorRT_ROOT}" + ) +endif() +list(APPEND _TensorRT_SEARCH_PATHS /usr /usr/local) + +# Header +find_path(TensorRT_INCLUDE_DIR + NAMES NvInfer.h + PATHS ${_TensorRT_SEARCH_PATHS} + PATH_SUFFIXES include +) + +# Core library +find_library(TensorRT_LIBRARY + NAMES nvinfer + PATHS ${_TensorRT_SEARCH_PATHS} + PATH_SUFFIXES lib lib64 lib/x64 +) + +find_package_handle_standard_args(TensorRT + REQUIRED_VARS TensorRT_INCLUDE_DIR TensorRT_LIBRARY +) + +if (TensorRT_FOUND) + set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR}) + set(TensorRT_LIBRARIES ${TensorRT_LIBRARY}) + + # Optional components + foreach(_comp IN ITEMS nvinfer_plugin nvonnxparser nvparsers) + find_library(TensorRT_${_comp}_LIBRARY + NAMES ${_comp} + PATHS ${_TensorRT_SEARCH_PATHS} + PATH_SUFFIXES lib lib64 lib/x64 + ) + if (TensorRT_${_comp}_LIBRARY) + list(APPEND TensorRT_LIBRARIES ${TensorRT_${_comp}_LIBRARY}) + endif() + endforeach() + + # Core target + add_library(TensorRT::TensorRT UNKNOWN IMPORTED) + set_target_properties(TensorRT::TensorRT PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}" + IMPORTED_LOCATION "${TensorRT_LIBRARY}" + ) + + # Component targets + foreach(_comp IN ITEMS nvinfer_plugin nvonnxparser nvparsers) + if (TensorRT_${_comp}_LIBRARY) + add_library(TensorRT::${_comp} UNKNOWN IMPORTED) + set_target_properties(TensorRT::${_comp} PROPERTIES + IMPORTED_LOCATION "${TensorRT_${_comp}_LIBRARY}" + ) + endif() + endforeach() + + message(STATUS "Found TensorRT at ${TensorRT_INCLUDE_DIR}") +endif() diff --git a/wust_vision-main/config/auto_aim.yaml b/wust_vision-main/config/auto_aim.yaml new file mode 100644 index 0000000..cdf0f25 --- /dev/null +++ b/wust_vision-main/config/auto_aim.yaml @@ -0,0 +1,131 @@ +armor_detect_backend: tensorrt +max_detect_armors: 11 + + +armor_map: + SENTRY: 4 + NO1: 2 + NO2: 3 + NO3: 1 + NO4: 5 + NO5: 9 + OUTPOST: 6 + BASE: 7 + UNKNOWN: -1 + +armor_where: + yaw_opt: + mode: golden + golden_search_side_deg: 60 + distance_fix_a2: 0.0 + +armor_tracker: + lost_time_thres: 2.0 + tracking_thres: 10 + max_yaw_diff_deg: 60.0 + max_dis_diff: 3.0 + match_gate: 50 + qxyz_common: [200.0, 200.0, 1.0] + qyaw_common: 100.0 + qxyz_output: [10.0, 10.0, 0.5] + qyaw_output: 0.01 + q_r: 0.0000001 + q_l: 0.0000001 + q_h: 0.0000001 + q_outpost_dz: 0.5 + yp_r: 2e-3 + dis_r_front: 0.5 + dis_r_side: 2.5 + dis2_r_ratio: 0.1 + yaw_r_base_front: 0.2 + yaw_r_base_side: 0.06 + yaw_r_log_ratio: 0.02 + esekf_iter_num: 5 + + + +auto_aim_fsm: + single_whole_up: 1.5 + single_whole_down: 1.0 + whole_pair_up: 7.5 + whole_pair_down: 6.5 #1s内至少换2次板子 2pi 否则轨迹规划无意义 + pair_center_up: 16.5 + pair_center_down: 15.0 + transfer_thresh: 50 + +very_aimer: + fuck_test: false + fuck_test_thresh: 0.5 + type: "seg" #seg or mpc + + sample_total_time: 2.0 + + sample_horizon: 500 + + control_delay: 0.2 + delay_enable_fire_error: 0.0035 + + max_yaw_acc: 40 + #全向yaw加速度测算37 + max_pitch_acc: 25 + + + prediction_delay: 0.00 + comming_angle: 60 + leaving_angle: 20 + + yaw_limit_deg: 60 + shooting_range_h: 0.12 + shooting_range_small_w: 0.12 + shooting_range_big_w: 0.24 + min_enable_pitch_deg: 0.25 + min_enable_yaw_deg: 0.25 + base_offset: + yaw: -1.0 + pitch: -3.5 + trajectory_offset: + - d_max: 3 + d_min: 0 + h_max: 1.5 + h_min: -1 + pitch_off: 0 + yaw_off: 0 + - d_max: 4.5 + d_min: 3 + h_max: 1.5 + h_min: -1 + pitch_off: 0 + yaw_off: 0 + - d_max: 4.5 + d_min: 9 + h_max: 1.5 + h_min: -1 + pitch_off: 0 + yaw_off: 0 + #---mpc only + max_iter: 10 + Q_yaw: [7e6, 0] + R_yaw: [3.0] + Q_pitch: [7e6, 0] + R_pitch: [3.0] + + + +trajectory_compensator: + compenstator_type: resistance + gravity: 9.8 + iteration_times: 20 + resistance: 0.092 + k1: 0.0190 #大弹丸double k1_c = 0.47; double k1 = k1_c * 1.169 * (2 * M_PI * 0.02125 * 0.02125) / 2 / 0.041; + + + +auto_exposure: + enable: true + target_brightness: 25.0 + tolerance: 3.0 + step_gain: 15.0 + decay_step: 1.0 + exposure_min: 100.0 + exposure_max: 1500.0 + control_interval_ms: 300 diff --git a/wust_vision-main/config/auto_buff.yaml b/wust_vision-main/config/auto_buff.yaml new file mode 100644 index 0000000..dfbb63a --- /dev/null +++ b/wust_vision-main/config/auto_buff.yaml @@ -0,0 +1,76 @@ +rune_detector: + rune_center_min_area: 10 + rune_center_max_area: 2000 + rune_center_1x1ratio_tol: 0.7 + rune_center_fill_ratio_min: 0.3 + + rune_target_min_area: 100 + rune_target_max_area: 3000 + rune_target_max_square_ratio: 1.3 + rune_target_cluster_radius: 70 + + bin_threshold: 50 + color_diff_threshold: 40 + + +rune_where: + roll_opt: + mode: golden + golden_search_side_deg: 60 + +rune_tracker: + lost_time_thres: 2.0 + tracking_thres: 5 + max_dis_diff: 1.0 + match_gate: 15.0 + esekf_iter_num: 2 + q_roll: 10.0 + q_xyz: 0.5 + q_yaw: 0.5 + yp_r: 0.01 + dis_r: 0.05 + yaw_r: 0.1 + roll_r: 0.2 + big_window_sec: 2.0 + +aimer: + prediction_delay: -0.00 + shooting_range_h: 0.1 + shooting_range_w: 0.1 + min_enable_pitch_deg: 0.25 + min_enable_yaw_deg: 0.25 + base_offset: + yaw: -3.0 + pitch: -0.0 + trajectory_offset: + - d_max: 9 + d_min: 0 + h_max: 1.5 + h_min: -1 + pitch_off: -0 + yaw_off: -0 + - d_max: 10 + d_min: 9 + h_max: 0.4 + h_min: -1 + pitch_off: -0 + yaw_off: 0 + +trajectory_compensator: + compenstator_type: resistance + gravity: 9.8 + iteration_times: 20 + resistance: 0.092 + k1: 0.0190 #大弹丸double k1_c = 0.47; double k1 = k1_c * 1.169 * (2 * M_PI * 0.02125 * 0.02125) / 2 / 0.041; + + + +auto_exposure: + enable: false + target_brightness: 15.0 + tolerance: 3.0 + step_gain: 15.0 + decay_step: 1.0 + exposure_min: 100.0 + exposure_max: 2500.0 + control_interval_ms: 300 diff --git a/wust_vision-main/config/auto_guidance.yaml b/wust_vision-main/config/auto_guidance.yaml new file mode 100644 index 0000000..36bdbc2 --- /dev/null +++ b/wust_vision-main/config/auto_guidance.yaml @@ -0,0 +1,83 @@ +backend: opencv + +max_infer_running: 6 + +detector: + openvino: + model_path: /home/hy/2026_dart_guide/assets/Katrin.xml + top_k: 128 + conf_threshold: 0.2 + nms_threshold: 0.5 + device_type: CPU + use_throughputmode: false + opencv: + gui: true + HSV: + lowH: 25 + highH: 85 + lowS: 50 + highS: 255 + lowV: 150 + highV: 255 + contours: + min_area: 100 + max_area: 50000 + min_aspect_ratio: 0.5 + min_fill_ratio: 0.7 + +tracker: + lost_time_thres: 2.0 + tracking_thres: 10 + target: + xy_r: 0.05 + wh_r: 0.05 + q_xy: 100 + q_wh: 100 + iter_num: 2 + max_dis_diff: 2.0 + + +logger: + log_level: INFO + log_path: /home/hy/wust_log + use_logcli: true + use_logfile: true + use_simplelog: true + +control: + control_rate: 100 + communication_delay_us: 1100 + pitch_avg_windows: 1 + device_name: /dev/ttyACM_RMc + use_serial: true + + + +camera: + type: hik_camera + camera_info_path: ${VISION_ROOT}/config/camera_info.yaml + + hik_camera: + target_sn: DA1094119 #DA3038891 + acquisition_frame_rate: 80 + adc_bit_depth: Bits_12 + exposure_time: 5000 + gain: 16.9 + gamma: 0.5 #Bayer用不了 + pixel_format: BayerRG8 + use_raw: false + acquisition_frame_rate_enable: false + reverse_x: false + reverse_y: false + + video_player: + fps: 12 + loop: true + use_cvt: true + #path: /home/hy/wust_data/video/aaa.mp4 + #path: /home/hy/wust_data/video/jiao.avi + #path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi + #path: /home/hy/wust_data/record/20251111_134416_187.avi + path: /home/hy/下载/sp_vision_25/records/1.avi + #path: /home/hy/wust_data/video/Sentry.mp4 + start_frame: 0 diff --git a/wust_vision-main/config/auto_sniper.yaml b/wust_vision-main/config/auto_sniper.yaml new file mode 100644 index 0000000..c1a1ea8 --- /dev/null +++ b/wust_vision-main/config/auto_sniper.yaml @@ -0,0 +1,27 @@ +map: + voxel_size: 0.1 + min_pos: [-5.0,-10.0,-1.0] + max_pos: [20.0,10.0,1.0] + + +solver: + k1: 0.0190 + g: 9.81 + target_armor_z: 0.5 + +offset_helper: + order: 5 + yaw_base_offset: 0.0 + pitch_base_offset: 0.0 + offset_table: + - distance: 2.0 + yaw: 0.1 + pitch: -0.2 + + - distance: 3.0 + yaw: 0.15 + pitch: -0.25 + + - distance: 4.0 + yaw: 0.2 + pitch: -0.3 \ No newline at end of file diff --git a/wust_vision-main/config/camera.yaml b/wust_vision-main/config/camera.yaml new file mode 100644 index 0000000..5c41a14 --- /dev/null +++ b/wust_vision-main/config/camera.yaml @@ -0,0 +1,50 @@ +type: hik_camera + +camera_info_path: ${VISION_ROOT}/config/camera_info.yaml + +hik_camera: + target_sn: DA3038934 + adc_bit_depth: Bits_8 + pixel_format: BayerRG8 #RGB8Packed + reverse_x: false + reverse_y: false + width: 1440 + height: 1080 + offset_x: 0 + offset_y: 0 + acquisition_frame_rate_enable: true + acquisition_frame_rate: 250 + exposure_time: 2000 + gain: 16.9 + gamma: 0.7 #Bayer用不了 + trigger_type: none + trigger_source: "" + trigger_activation: 0 + use_raw: false + use_rgb: false + use_ea: false + use_cuda_cvt: true + + +video_player: + fps: 60 + loop: true + use_cvt: true + trigger_mode: false + #path: /home/hy/wust_data/video/aaa.mp4 + #path: /home/hy/wust_data/video/jiao.avi + #path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi + #path: /home/hy/wust_data/record/20251111_134416_187.avi + path: /home/hy/data/video_save/5.21成都.mp4 + #path: /home/hy/wust_data/video/Sentry.mp4 + start_frame: 0 + +uvc: + device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:3.4:1.0-video-index0" + fps: 60 + width: 1280 + height: 720 + exposure: 100.0 + gain: 10.0 + gamma: 100 + trigger_mode: false \ No newline at end of file diff --git a/wust_vision-main/config/camera_info.yaml b/wust_vision-main/config/camera_info.yaml new file mode 100644 index 0000000..910d4bb --- /dev/null +++ b/wust_vision-main/config/camera_info.yaml @@ -0,0 +1,82 @@ +#8mm +# image_width: 1440 +# image_height: 1080 +# camera_name: narrow_stereo +# camera_matrix: +# rows: 3 +# cols: 3 +# data: [2419.64498, 0. , 719.05623, +# 0. , 2414.38209, 538.71651, +# 0. , 0. , 1. ] +# distortion_model: plumb_bob +# distortion_coefficients: +# rows: 1 +# cols: 5 +# data: [-0.033521, 0.087954, -0.002045, -0.001438, 0.000000] +# rectification_matrix: +# rows: 3 +# cols: 3 +# data: [1., 0., 0., +# 0., 1., 0., +# 0., 0., 1.] +# projection_matrix: +# rows: 3 +# cols: 4 +# data: [2413.22632, 0. , 717.5196 , 0. , +# 0. , 2408.83228, 537.43444, 0. , +# 0. , 0. , 1. , 0. ] +#6mm + +image_width: 1440 +image_height: 1080 +camera_name: narrow_stereo +camera_matrix: + rows: 3 + cols: 3 + data: [1814.19888, 0. , 730.27359, + 0. , 1807.37177, 529.57031, + 0. , 0. , 1. ] +distortion_model: plumb_bob +distortion_coefficients: + rows: 1 + cols: 5 + data: [-0.089138, 0.114766, 0.000034, 0.000238, 0.000000] +rectification_matrix: + rows: 3 + cols: 3 + data: [1., 0., 0., + 0., 1., 0., + 0., 0., 1.] +projection_matrix: + rows: 3 + cols: 4 + data: [1793.54229, 0. , 730.67194, 0. , + 0. , 1794.48784, 529.46193, 0. , + 0. , 0. , 1. , 0. ] +#12mm +# image_width: 1440 +# image_height: 1080 +# camera_name: narrow_stereo +# camera_matrix: +# rows: 3 +# cols: 3 +# data: [3655.11292, 0. , 748.66803, +# 0. , 3648.34439, 556.97683, +# 0. , 0. , 1. ] +# distortion_model: plumb_bob +# distortion_coefficients: +# rows: 1 +# cols: 5 +# data: [-0.058763, 1.430269, 0.001191, 0.001641, 0.000000] +# rectification_matrix: +# rows: 3 +# cols: 3 +# data: [1., 0., 0., +# 0., 1., 0., +# 0., 0., 1.] +# projection_matrix: +# rows: 3 +# cols: 4 +# data: [3659.22949, 0. , 748.71496, 0. , +# 0. , 3652.02661, 556.81908, 0. , +# 0. , 0. , 1. , 0. ] diff --git a/wust_vision-main/config/common.yaml b/wust_vision-main/config/common.yaml new file mode 100644 index 0000000..0b8a790 --- /dev/null +++ b/wust_vision-main/config/common.yaml @@ -0,0 +1,34 @@ +max_infer_running: 6 +detect_color: 0 +attack_mode: 0 +debug_fps: 60 +tf: + R_camera2gimbal: [0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0] + t_camera2gimbal: + [0.000434347168653831, 0.0141232476052895, 0.00736106400231024] +control: + control_rate: 1000 + communication_delay_us: 100.000 + device_name: /dev/stm32_acm + use_serial: true + yaw_ramp: 0.0 #send_cmd = cmd.pos+cmd.vel*ramp + pitch_ramp: 0.0 + +logger: + log_level: DEBUG + log_path: log + use_logcli: true + use_logfile: false + use_simplelog: true + +shoot: + bullet_speed: 26.0 + rate: 3 + +record: + use_record: false + folder_path: record + use_rotate_reader: false + read_csv_path: "" + + diff --git a/wust_vision-main/config/detect_ml.yaml b/wust_vision-main/config/detect_ml.yaml new file mode 100644 index 0000000..942de76 --- /dev/null +++ b/wust_vision-main/config/detect_ml.yaml @@ -0,0 +1,66 @@ +armor_detector: + cv: #装甲板灯条完整无损坏才可以用 + enable: true + armor: + max_angle: 40 + max_large_center_distance: 8 + max_small_center_distance: 3.5 + min_large_center_distance: 3.5 + min_light_ratio: 0.5 + min_small_center_distance: 0.05 + light: + binary_thres: 150 + expand_ratio_h: 1.9 + expand_ratio_w: 1.2 + max_angle: 45 + max_ratio: 0.4 + min_ratio: 0.01 + max_pts_error: 20.0 + max_angle_diff: 20.0 + color_diff_thresh: 40 + classify: + enable: true + label_path: ${VISION_ROOT}/model/label.txt + model_path: ${VISION_ROOT}/model/mlp.onnx + backend: opencv + threshold: 0.5 + + tensorrt: + conf_threshold: 0.2 + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + device_id: 0 + nms_threshold: 0.3 + top_k: 128 + max_infer_running: 3 + min_free_mem_ratio: 0.1 + use_cuda_pre: true + log_time: false + model_type: tup + openvino: + conf_threshold: 0.2 + device_name: GPU + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + nms_threshold: 0.3 + top_k: 128 + use_throughputmode: true + model_type: tup + onnxruntime: + conf_threshold: 0.2 + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + nms_threshold: 0.3 + top_k: 128 + model_type: tup + provider: CUDA + ncnn: + conf_threshold: 0.2 + model_path_param: ${VISION_ROOT}/model/opt-1208-001.param + model_path_bin: ${VISION_ROOT}/model/opt-1208-001.bin + input_name: images + output_name: output + use_gpu: true + use_lightmode: false + device_id: 0 + cpu_threads: 20 + nms_threshold: 0.3 + top_k: 128 + model_type: tup diff --git a/wust_vision-main/config/detect_opencv.yaml b/wust_vision-main/config/detect_opencv.yaml new file mode 100644 index 0000000..3025688 --- /dev/null +++ b/wust_vision-main/config/detect_opencv.yaml @@ -0,0 +1,19 @@ +armor_detector: + light: + binary_thres: 120 + max_angle: 40 + max_ratio: 0.4 + min_ratio: 0.001 + color_diff_thresh: 20 + max_angle_diff: 10.0 + armor: + min_light_ratio: 0.8 + min_small_center_distance: 0.8 + max_small_center_distance: 3.5 + min_large_center_distance: 3.5 + max_large_center_distance: 8.0 + max_angle: 40.0 + classify: + label_path: ${VISION_ROOT}/model/label.txt + model_path: ${VISION_ROOT}/model/reborn_number_classifier.onnx + threshold: 0.5 diff --git a/wust_vision-main/config/guard.sh b/wust_vision-main/config/guard.sh new file mode 100755 index 0000000..e000b74 --- /dev/null +++ b/wust_vision-main/config/guard.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +TARGET="$1" +shift +ARGS=("${@:3}") + + +echo "[GUARD] target: $TARGET" +echo "[GUARD] args: ${ARGS[*]}" +echo "[GUARD] Starting monitor..." + +while true; do + echo "[GUARD] Launching program..." + "$TARGET" "${ARGS[@]}" + RET=$? + + if [ $RET -eq 0 ]; then + echo "[GUARD] Program exited normally. Stopping guard." + exit 0 + fi + + echo "[GUARD] Crash detected, restarting in 1 second..." + sleep 1 +done diff --git a/wust_vision-main/config/omni/camera0.yaml b/wust_vision-main/config/omni/camera0.yaml new file mode 100644 index 0000000..158587f --- /dev/null +++ b/wust_vision-main/config/omni/camera0.yaml @@ -0,0 +1,51 @@ +yaw_in_big_yaw_deg: 90.0 +type: uvc + +camera_info_path: ${VISION_ROOT}/config/camera_info.yaml + +hik_camera: + target_sn: DA1094119 + adc_bit_depth: Bits_8 + pixel_format: BayerRG8 #RGB8Packed + reverse_x: false + reverse_y: false + width: 1440 + height: 1080 + offset_x: 0 + offset_y: 0 + acquisition_frame_rate_enable: true + acquisition_frame_rate: 250 + exposure_time: 2000 + gain: 16.9 + gamma: 0.7 #Bayer用不了 + trigger_type: none + trigger_source: "" + trigger_activation: 0 + use_raw: false + use_rgb: false + use_ea: false + use_cuda_cvt: false + + +video_player: + fps: 60 + loop: true + use_cvt: true + trigger_mode: false + #path: /home/hy/wust_data/video/aaa.mp4 + #path: /home/hy/wust_data/video/jiao.avi + #path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi + #path: /home/hy/wust_data/record/20251111_134416_187.avi + path: /home/hy/rune_dl/runeA.mp4 + #path: /home/hy/wust_data/video/Sentry.mp4 + start_frame: 0 + +uvc: + device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:9.4:1.0-video-index0" + fps: 30 + width: 1280 + height: 720 + exposure: 200 + gain: 10.0 + gamma: 100 + trigger_mode: true \ No newline at end of file diff --git a/wust_vision-main/config/omni/camera1.yaml b/wust_vision-main/config/omni/camera1.yaml new file mode 100644 index 0000000..04832c5 --- /dev/null +++ b/wust_vision-main/config/omni/camera1.yaml @@ -0,0 +1,51 @@ +yaw_in_big_yaw_deg: 107.0 +type: uvc + +camera_info_path: ${VISION_ROOT}/config/omni/camera_info.yaml + +hik_camera: + target_sn: DA1094119 + adc_bit_depth: Bits_8 + pixel_format: BayerRG8 #RGB8Packed + reverse_x: false + reverse_y: false + width: 1440 + height: 1080 + offset_x: 0 + offset_y: 0 + acquisition_frame_rate_enable: true + acquisition_frame_rate: 250 + exposure_time: 2000 + gain: 16.9 + gamma: 0.7 #Bayer用不了 + trigger_type: none + trigger_source: "" + trigger_activation: 0 + use_raw: false + use_rgb: false + use_ea: false + use_cuda_cvt: false + + +video_player: + fps: 60 + loop: true + use_cvt: true + trigger_mode: false + #path: /home/hy/wust_data/video/aaa.mp4 + #path: /home/hy/wust_data/video/jiao.avi + #path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi + #path: /home/hy/wust_data/record/20251111_134416_187.avi + path: /home/hy/rune_dl/runeA.mp4 + #path: /home/hy/wust_data/video/Sentry.mp4 + start_frame: 0 + +uvc: + device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:9.1:1.0-video-index0" + fps: 30 + width: 1280 + height: 720 + exposure: 200 + gain: 10.0 + gamma: 100 + trigger_mode: true \ No newline at end of file diff --git a/wust_vision-main/config/omni/camera_info.yaml b/wust_vision-main/config/omni/camera_info.yaml new file mode 100644 index 0000000..1ede5f6 --- /dev/null +++ b/wust_vision-main/config/omni/camera_info.yaml @@ -0,0 +1,28 @@ + +# camera_matrix = [759.73071, 0.0, 336.19053, 0.0, 761.16771, 231.83002, 0.0, 0.0, 1.0] +# distortion_coefficients = [0.000281, 0.144018, 0.005509, -0.004330, 0.0] +image_width: 1440 +image_height: 1080 +camera_name: narrow_stereo +camera_matrix: + rows: 3 + cols: 3 + data: [759.73071, 0.0, 336.19053, 0.0, 761.16771, 231.83002, 0.0, 0.0, 1.0 ] +distortion_model: plumb_bob +distortion_coefficients: + rows: 1 + cols: 5 + data: [0.000281, 0.144018, 0.005509, -0.004330, 0.0] +rectification_matrix: + rows: 3 + cols: 3 + data: [1., 0., 0., + 0., 1., 0., + 0., 0., 1.] +projection_matrix: + rows: 3 + cols: 4 + data: [1793.54229, 0. , 730.67194, 0. , + 0. , 1794.48784, 529.46193, 0. , + 0. , 0. , 1. , 0. ] + diff --git a/wust_vision-main/config/omni/detect_ml.yaml b/wust_vision-main/config/omni/detect_ml.yaml new file mode 100644 index 0000000..6b6d9e6 --- /dev/null +++ b/wust_vision-main/config/omni/detect_ml.yaml @@ -0,0 +1,66 @@ +armor_detector: + cv: #装甲板灯条完整无损坏才可以用 + enable: false + armor: + max_angle: 40 + max_large_center_distance: 8 + max_small_center_distance: 3.5 + min_large_center_distance: 3.5 + min_light_ratio: 0.5 + min_small_center_distance: 0.05 + light: + binary_thres: 150 + expand_ratio_h: 1.9 + expand_ratio_w: 1.2 + max_angle: 45 + max_ratio: 0.4 + min_ratio: 0.01 + max_pts_error: 20.0 + max_angle_diff: 20.0 + color_diff_thresh: 0 + classify: + enable: true + label_path: ${VISION_ROOT}/model/label.txt + model_path: ${VISION_ROOT}/model/mlp.onnx + backend: opencv + threshold: 0.5 + + tensorrt: + conf_threshold: 0.2 + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + device_id: 0 + nms_threshold: 0.3 + top_k: 128 + max_infer_running: 3 + min_free_mem_ratio: 0.1 + use_cuda_pre: true + log_time: false + model_type: tup + openvino: + conf_threshold: 0.2 + device_name: GPU + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + nms_threshold: 0.3 + top_k: 128 + use_throughputmode: true + model_type: tup + onnxruntime: + conf_threshold: 0.2 + model_path: ${VISION_ROOT}/model/opt-1208-001.onnx + nms_threshold: 0.3 + top_k: 128 + model_type: tup + provider: CUDA + ncnn: + conf_threshold: 0.2 + model_path_param: ${VISION_ROOT}/model/opt-1208-001.param + model_path_bin: ${VISION_ROOT}/model/opt-1208-001.bin + input_name: images + output_name: output + use_gpu: true + use_lightmode: false + device_id: 0 + cpu_threads: 20 + nms_threshold: 0.3 + top_k: 128 + model_type: tup diff --git a/wust_vision-main/config/omni/detect_opencv.yaml b/wust_vision-main/config/omni/detect_opencv.yaml new file mode 100644 index 0000000..3025688 --- /dev/null +++ b/wust_vision-main/config/omni/detect_opencv.yaml @@ -0,0 +1,19 @@ +armor_detector: + light: + binary_thres: 120 + max_angle: 40 + max_ratio: 0.4 + min_ratio: 0.001 + color_diff_thresh: 20 + max_angle_diff: 10.0 + armor: + min_light_ratio: 0.8 + min_small_center_distance: 0.8 + max_small_center_distance: 3.5 + min_large_center_distance: 3.5 + max_large_center_distance: 8.0 + max_angle: 40.0 + classify: + label_path: ${VISION_ROOT}/model/label.txt + model_path: ${VISION_ROOT}/model/reborn_number_classifier.onnx + threshold: 0.5 diff --git a/wust_vision-main/config/omni/omni.yaml b/wust_vision-main/config/omni/omni.yaml new file mode 100644 index 0000000..9e736d7 --- /dev/null +++ b/wust_vision-main/config/omni/omni.yaml @@ -0,0 +1,37 @@ +armor_detect_backend: openvino + +cameras: ["${VISION_ROOT}/config/omni/camera0.yaml", "${VISION_ROOT}/config/omni/camera1.yaml"] + +fps: 60 +max_infer_running: 3 +active_time: 1.0 +min_score: 15.0 + +armor_where: + yaw_opt: + mode: golden + golden_search_side_deg: 60 + distance_fix_a2: 0.0 + +armor_tracker: + lost_time_thres: 1.0 + tracking_thres: 5 + max_yaw_diff_deg: 90.0 + max_dis_diff: 3.0 + match_gate: 50 + qxyz_common: [1.0, 1.0, 1.0] + qyaw_common: 1.0 + qxyz_output: [1.0, 1.0, 0.5] + qyaw_output: 0.01 + q_r: 0.0000001 + q_l: 0.0000001 + q_h: 0.0000001 + q_outpost_dz: 0.5 + yp_r: 2e-3 + dis_r_front: 0.5 + dis_r_side: 2.5 + dis2_r_ratio: 0.1 + yaw_r_base_front: 0.2 + yaw_r_base_side: 0.06 + yaw_r_log_ratio: 0.02 + esekf_iter_num: 1 \ No newline at end of file diff --git a/wust_vision-main/cuda_infer/CMakeLists.txt b/wust_vision-main/cuda_infer/CMakeLists.txt new file mode 100644 index 0000000..69ceec2 --- /dev/null +++ b/wust_vision-main/cuda_infer/CMakeLists.txt @@ -0,0 +1,61 @@ +cmake_minimum_required(VERSION 3.10) +cmake_policy(SET CMP0079 NEW) + +project(cuda_infer LANGUAGES CXX CUDA) + +# 设置标准 +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +set(CUDA_USE_STATIC_CUDA_RUNTIME OFF) +set(CMAKE_BUILD_TYPE "Release") +# 抑制过时 API 警告 +add_compile_options(-Wno-deprecated-declarations) + +# 禁用 .rsp 响应文件(避免 nvcc 报错) +set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS OFF) +set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES OFF) + +# 查找依赖 +find_package(CUDAToolkit REQUIRED) +find_package(OpenCV REQUIRED) +find_package(Eigen3 REQUIRED) + +# 收集源码 +file(GLOB_RECURSE CUDA_INFER_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/*.cu +) + +# 添加静态库 +add_library(cuda_infer STATIC ${CUDA_INFER_SRC}) +set_target_properties(cuda_infer PROPERTIES POSITION_INDEPENDENT_CODE ON) + +# 设置包含路径 +target_include_directories(cuda_infer PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${CUDAToolkit_INCLUDE_DIRS} + ${TensorRT_INCLUDE_DIR} + ${EIGEN3_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} +) + +# 设置 CUDA 编译选项 +target_compile_options(cuda_infer PRIVATE + $<$: + --generate-code=arch=compute_86,code=sm_86 + -Xcompiler=-fPIC + -O3 + -w + -Wno-deprecated-gpu-targets + -Wno-error=deprecated-declarations + > +) + +# 链接库 +target_link_libraries(cuda_infer PRIVATE + ${OpenCV_LIBS} + CUDA::cudart + TensorRT::TensorRT +) diff --git a/wust_vision-main/cuda_infer/armor_infer.cu b/wust_vision-main/cuda_infer/armor_infer.cu new file mode 100644 index 0000000..e8c2940 --- /dev/null +++ b/wust_vision-main/cuda_infer/armor_infer.cu @@ -0,0 +1,323 @@ +// armor_cuda_infer.cu +#include "armor_infer.hpp" +#include "letter_box.hpp" +#include +#include +#include +#include +#include +#include +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf( \ + stderr, \ + "CUDA error at %s:%d: %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err) \ + ); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +namespace armor_cuda_infer { +__global__ void nchw_float_to_hwc_uchar4( + const float* __restrict__ src, + uchar4* __restrict__ dst, + int W, + int H, + float norm +) { + const int x = blockIdx.x * blockDim.x + threadIdx.x; + const int y = blockIdx.y * blockDim.y + threadIdx.y; + if (x >= W || y >= H) + return; + + const int idx = y * W + x; + const int plane = W * H; + + float r = __ldg(src + idx + plane * 0); + float g = __ldg(src + idx + plane * 1); + float b = __ldg(src + idx + plane * 2); + + r = fminf(fmaxf(r / norm, 0.f), 255.f); + g = fminf(fmaxf(g / norm, 0.f), 255.f); + b = fminf(fmaxf(b / norm, 0.f), 255.f); + + dst[idx] = make_uchar4((unsigned char)b, (unsigned char)g, (unsigned char)r, 255); +} + +cv::Mat CudaInfer::tensorToMat(float* d_nchw, int W, int H, float norm, cudaStream_t stream) const { + static uchar4* d_hwc = nullptr; + static size_t cap = 0; + + const size_t need = W * H * sizeof(uchar4); + if (cap < need) { + if (d_hwc) + cudaFree(d_hwc); + cudaMalloc(&d_hwc, need); + cap = need; + } + + const dim3 block(TILE_W, TILE_H); + const dim3 grid((W + block.x - 1) / block.x, (H + block.y - 1) / block.y); + + nchw_float_to_hwc_uchar4<<>>(d_nchw, d_hwc, W, H, norm); + + cv::Mat img(H, W, CV_8UC4); + + cudaMemcpyAsync(img.data, d_hwc, need, cudaMemcpyDeviceToHost, stream); + + // cudaStreamSynchronize(stream); + return img; +} + +CudaInfer::CudaInfer() = default; +CudaInfer::~CudaInfer() { + release(); +} + +void CudaInfer::init(int max_src_w, int max_src_h, int input_w, int input_h) { + input_w_ = input_w; + input_h_ = input_h; + max_src_h_ = max_src_h; + max_src_w_ = max_src_w; + rellocMem(); +} +void CudaInfer::rellocMem() { + CUDA_CHECK(cudaMalloc(&d_input_bgr_, max_src_w_ * max_src_h_ * 3 * sizeof(unsigned char))); + CUDA_CHECK(cudaMallocPitch( + &d_input_bgr_pitched_, + &input_pitch_bytes_, + max_src_w_ * 3 * sizeof(unsigned char), + max_src_h_ + )); + CUDA_CHECK(cudaMalloc(&d_nchw_, input_w_ * input_h_ * 3 * sizeof(float))); + printf("Relloc memory for CudaInfer\n"); +} +void CudaInfer::getOutEnoughMem(int img_w, int img_h) { + if (img_w > max_src_w_ || img_h > max_src_h_) { + if (img_w > max_src_w_) { + max_src_w_ = img_w; + } + if (img_h > max_src_h_) { + max_src_h_ = img_h; + } + rellocMem(); + } +} + +void CudaInfer::release() { + if (d_input_bgr_) + cudaFree(d_input_bgr_), d_input_bgr_ = nullptr; + if (d_input_bgr_pitched_) + cudaFree(d_input_bgr_pitched_), d_input_bgr_pitched_ = nullptr; + if (d_nchw_) + cudaFree(d_nchw_), d_nchw_ = nullptr; +} + +float* CudaInfer::preprocess( + const unsigned char* input_bgr_host, + int img_w, + int img_h, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream +) { + if (!isInitialized()) { + throw std::runtime_error("CudaInfer not initialized properly."); + } + + if (!input_bgr_host || !d_input_bgr_ || !d_nchw_) { + fprintf(stderr, "[Error] Null pointer in preprocess input\n"); + return nullptr; + } + getOutEnoughMem(img_w, img_h); + float scale = fminf(input_w_ / (float)img_w, input_h_ / (float)img_h); + int rw = round(img_w * scale), rh = round(img_h * scale); + int pad_l = (input_w_ - rw) / 2, pad_t = (input_h_ - rh) / 2; + + tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1; + + size_t img_size = img_w * img_h * 3; + CUDA_CHECK( + cudaMemcpyAsync(d_input_bgr_, input_bgr_host, img_size, cudaMemcpyHostToDevice, stream) + ); + + dim3 threads(TILE_W, TILE_H); + dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H); + + letterbox_kernel_shared<<>>( + d_input_bgr_, + img_w, + img_h, + d_nchw_, + input_w_, + input_h_, + scale, + pad_t, + pad_l, + norm, + swap_rb + ); + + CUDA_CHECK(cudaGetLastError()); + return d_nchw_; +} +float* CudaInfer::preprocess_gpu( + const unsigned char* input_bgr_device, + int img_w, + int img_h, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream +) { + if (!isInitialized()) { + throw std::runtime_error("CudaInfer not initialized properly."); + } + + if (!input_bgr_device || !d_nchw_) { + fprintf(stderr, "[Error] Null pointer in preprocess input\n"); + return nullptr; + } + getOutEnoughMem(img_w, img_h); + float scale = fminf(input_w_ / (float)img_w, input_h_ / (float)img_h); + int rw = round(img_w * scale), rh = round(img_h * scale); + int pad_l = (input_w_ - rw) / 2, pad_t = (input_h_ - rh) / 2; + + tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1; + + size_t img_size = img_w * img_h * 3; + + dim3 threads(TILE_W, TILE_H); + dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H); + + letterbox_kernel_shared<<>>( + input_bgr_device, + img_w, + img_h, + d_nchw_, + input_w_, + input_h_, + scale, + pad_t, + pad_l, + norm, + swap_rb + ); + + CUDA_CHECK(cudaGetLastError()); + return d_nchw_; +} +float* CudaInfer::preprocess_pitched( + const unsigned char* input_bgr_host, + int img_w, + int img_h, + int host_step, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream +) { + if (!isInitialized()) { + throw std::runtime_error("CudaInfer not initialized properly."); + } + + if (!input_bgr_host || !d_nchw_) { + fprintf(stderr, "[Error] Null pointer in preprocess input\n"); + return nullptr; + } + getOutEnoughMem(img_w, img_h); + float scale = fminf((float)input_w_ / img_w, (float)input_h_ / img_h); + int rw = round(img_w * scale); + int rh = round(img_h * scale); + int pad_l = (input_w_ - rw) / 2; + int pad_t = (input_h_ - rh) / 2; + tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1; + CUDA_CHECK(cudaMemcpy2DAsync( + d_input_bgr_pitched_, + input_pitch_bytes_, + input_bgr_host, + host_step, + img_w * 3, + img_h, + cudaMemcpyHostToDevice, + stream + )); + dim3 threads(TILE_W, TILE_H); + dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H); + + letterbox_kernel_pitched<<>>( + d_input_bgr_pitched_, + input_pitch_bytes_, + img_w, + img_h, + d_nchw_, + input_w_, + input_h_, + scale, + pad_t, + pad_l, + norm, + swap_rb + ); + + CUDA_CHECK(cudaGetLastError()); + return d_nchw_; +} +float* CudaInfer::preprocess_pitched_gpu( + const unsigned char* input_bgr_device, + int img_w, + int img_h, + int input_step, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream +) { + if (!isInitialized()) { + throw std::runtime_error("CudaInfer not initialized properly."); + } + + if (!input_bgr_device || !d_nchw_) { + fprintf(stderr, "[Error] Null pointer in preprocess_pitched_gpu\n"); + return nullptr; + } + getOutEnoughMem(img_w, img_h); + float scale = fminf(static_cast(input_w_) / img_w, static_cast(input_h_) / img_h); + + int rw = static_cast(roundf(img_w * scale)); + int rh = static_cast(roundf(img_h * scale)); + + int pad_l = (input_w_ - rw) / 2; + int pad_t = (input_h_ - rh) / 2; + + tf_matrix << 1.f / scale, 0.f, -pad_l / scale, 0.f, 1.f / scale, -pad_t / scale, 0.f, 0.f, 1.f; + + dim3 threads(TILE_W, TILE_H); + dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H); + + letterbox_kernel_pitched<<>>( + input_bgr_device, + input_step, + img_w, + img_h, + d_nchw_, + input_w_, + input_h_, + scale, + pad_t, + pad_l, + norm, + swap_rb + ); + + CUDA_CHECK(cudaGetLastError()); + + return d_nchw_; +} + +} // namespace armor_cuda_infer \ No newline at end of file diff --git a/wust_vision-main/cuda_infer/armor_infer.hpp b/wust_vision-main/cuda_infer/armor_infer.hpp new file mode 100644 index 0000000..a273c95 --- /dev/null +++ b/wust_vision-main/cuda_infer/armor_infer.hpp @@ -0,0 +1,78 @@ +// armor_cuda_infer.hpp +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace armor_cuda_infer { + +class CudaInfer { +public: + CudaInfer(); + ~CudaInfer() noexcept; + + void init(int max_src_w, int max_src_h, int input_w, int input_h); + void release(); + bool isInitialized() const { + return d_input_bgr_ && d_nchw_ && d_input_bgr_pitched_; + } + void getOutEnoughMem(int img_w, int img_h); + void rellocMem(); + + float* preprocess( + const unsigned char* input_bgr_host, + int img_w, + int img_h, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream + ); + float* preprocess_pitched( + const unsigned char* input_bgr_host, + int img_w, + int img_h, + int host_step, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream + ); + float* preprocess_gpu( + const unsigned char* input_bgr_device, + int img_w, + int img_h, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream + ); + float* preprocess_pitched_gpu( + const unsigned char* input_bgr_device, + int img_w, + int img_h, + int host_step, + float norm, + bool swap_rb, + Eigen::Matrix3f& tf_matrix, + cudaStream_t stream + ); + cv::Mat tensorToMat(float* d_nchw, int W, int H, float norm, cudaStream_t stream) const; + +private: + CudaInfer(const CudaInfer&) = delete; + CudaInfer& operator=(const CudaInfer&) = delete; + unsigned char* d_input_bgr_ = nullptr; + float* d_nchw_ = nullptr; + unsigned char* d_input_bgr_pitched_ = nullptr; + size_t input_pitch_bytes_ = 0; + int input_w_; + int input_h_; + int max_src_w_, max_src_h_; +}; +} // namespace armor_cuda_infer \ No newline at end of file diff --git a/wust_vision-main/cuda_infer/letter_box.cu b/wust_vision-main/cuda_infer/letter_box.cu new file mode 100644 index 0000000..1937719 --- /dev/null +++ b/wust_vision-main/cuda_infer/letter_box.cu @@ -0,0 +1,147 @@ +#include "letter_box.hpp" +__global__ void letterbox_kernel_shared( + const uchar* __restrict__ input_bgr, + int in_w, + int in_h, + float* __restrict__ output_nchw, + int out_w, + int out_h, + float scale, + int pad_t, + int pad_l, + float norm, + bool swap_rb +) { + // global x/y + int x = blockIdx.x * TILE_W + threadIdx.x; + int y = blockIdx.y * TILE_H + threadIdx.y; + if (x >= out_w || y >= out_h) + return; + + // 共享内存 + halo + __shared__ uchar4 smem[TILE_H + 1][TILE_W + 1]; + + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int total_smem = (TILE_W + 1) * (TILE_H + 1); + int threads_per_block = blockDim.x * blockDim.y; + int iter = (total_smem + threads_per_block - 1) / threads_per_block; + + float inv_scale = 1.0f / scale; + float block_start_x = blockIdx.x * TILE_W - pad_l; + float block_start_y = blockIdx.y * TILE_H - pad_t; + + // load shared memory + for (int i = 0; i < iter; i++) { + int idx = tid + i * threads_per_block; + if (idx < total_smem) { + int sx = idx % (TILE_W + 1); + int sy = idx / (TILE_W + 1); + + float in_x = (block_start_x + sx) * inv_scale; + float in_y = (block_start_y + sy) * inv_scale; + + int ix = floorf(in_x); + int iy = floorf(in_y); + + uchar4 p = make_uchar4(114, 114, 114, 0); // padding BGR + if (ix >= 0 && iy >= 0 && ix < in_w && iy < in_h) { + int offset = (iy * in_w + ix) * 3; + p.x = input_bgr[offset]; // b + p.y = input_bgr[offset + 1]; // g + p.z = input_bgr[offset + 2]; // r + } + + smem[sy][sx] = p; + } + } + __syncthreads(); + + // 双线性插值 + float in_x = (x - pad_l) * inv_scale; + float in_y = (y - pad_t) * inv_scale; + int tx = threadIdx.x; + int ty = threadIdx.y; + float dx = in_x - floorf(in_x); + float dy = in_y - floorf(in_y); + float dx1 = 1.0f - dx, dy1 = 1.0f - dy; + + uchar4 p00 = smem[ty][tx]; + uchar4 p01 = smem[ty][tx + 1]; + uchar4 p10 = smem[ty + 1][tx]; + uchar4 p11 = smem[ty + 1][tx + 1]; + + float out_r = dx1 * dy1 * p00.z + dx * dy1 * p01.z + dx1 * dy * p10.z + dx * dy * p11.z; + float out_g = dx1 * dy1 * p00.y + dx * dy1 * p01.y + dx1 * dy * p10.y + dx * dy * p11.y; + float out_b = dx1 * dy1 * p00.x + dx * dy1 * p01.x + dx1 * dy * p10.x + dx * dy * p11.x; + + int out_idx = y * out_w + x; + if (swap_rb) { + output_nchw[out_idx + 0 * out_w * out_h] = out_r * norm; + output_nchw[out_idx + 1 * out_w * out_h] = out_g * norm; + output_nchw[out_idx + 2 * out_w * out_h] = out_b * norm; + } else { + output_nchw[out_idx + 0 * out_w * out_h] = out_b * norm; + output_nchw[out_idx + 1 * out_w * out_h] = out_g * norm; + output_nchw[out_idx + 2 * out_w * out_h] = out_r * norm; + } +} + +__global__ void letterbox_kernel_pitched( + const unsigned char* __restrict__ d_input_bgr, + size_t pitch, + int src_w, + int src_h, + float* __restrict__ d_nchw, + int OUT_W, + int OUT_H, + float scale, + int pad_t, + int pad_l, + float norm, + bool swap_rb +) { + int ox = blockIdx.x * blockDim.x + threadIdx.x; + int oy = blockIdx.y * blockDim.y + threadIdx.y; + if (ox >= OUT_W || oy >= OUT_H) + return; + + float fx = (ox - pad_l) / scale; + float fy = (oy - pad_t) / scale; + + int out_idx = oy * OUT_W + ox; + int plane = OUT_W * OUT_H; + + // clamp coordinates + fx = fmaxf(0.f, fminf(fx, src_w - 2.f)); + fy = fmaxf(0.f, fminf(fy, src_h - 2.f)); + + int x0 = floorf(fx), y0 = floorf(fy); + int x1 = x0 + 1, y1 = y0 + 1; + + float dx = fx - x0, dy = fy - y0; + float dx1 = 1.f - dx, dy1 = 1.f - dy; + + // row pointers + const uchar3* row0 = (const uchar3*)((const char*)d_input_bgr + y0 * pitch); + const uchar3* row1 = (const uchar3*)((const char*)d_input_bgr + y1 * pitch); + + uchar3 p00 = row0[x0]; + uchar3 p01 = row0[x1]; + uchar3 p10 = row1[x0]; + uchar3 p11 = row1[x1]; + + // bilinear interpolation + float r = dx1 * dy1 * p00.z + dx * dy1 * p01.z + dx1 * dy * p10.z + dx * dy * p11.z; + float g = dx1 * dy1 * p00.y + dx * dy1 * p01.y + dx1 * dy * p10.y + dx * dy * p11.y; + float b = dx1 * dy1 * p00.x + dx * dy1 * p01.x + dx1 * dy * p10.x + dx * dy * p11.x; + + if (swap_rb) { + d_nchw[out_idx + 0 * plane] = r * norm; + d_nchw[out_idx + 1 * plane] = g * norm; + d_nchw[out_idx + 2 * plane] = b * norm; + } else { + d_nchw[out_idx + 0 * plane] = b * norm; + d_nchw[out_idx + 1 * plane] = g * norm; + d_nchw[out_idx + 2 * plane] = r * norm; + } +} diff --git a/wust_vision-main/cuda_infer/letter_box.hpp b/wust_vision-main/cuda_infer/letter_box.hpp new file mode 100644 index 0000000..c1ee12f --- /dev/null +++ b/wust_vision-main/cuda_infer/letter_box.hpp @@ -0,0 +1,42 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +static constexpr int TILE_W = 32; +static constexpr int TILE_H = 32; +__global__ void letterbox_kernel_shared( + const uchar* __restrict__ input_bgr, + int in_w, + int in_h, + float* __restrict__ output_nchw, + int out_w, + int out_h, + float scale, + int pad_t, + int pad_l, + float norm, + bool swap_rb +); +__global__ void letterbox_kernel_pitched( + const unsigned char* __restrict__ d_input_bgr, + size_t pitch, + int src_w, + int src_h, + float* __restrict__ d_nchw, + int OUT_W, + int OUT_H, + float scale, + int pad_t, + int pad_l, + float norm, + bool swap_rb +); diff --git a/wust_vision-main/env.bash b/wust_vision-main/env.bash new file mode 100644 index 0000000..4a8d3d0 --- /dev/null +++ b/wust_vision-main/env.bash @@ -0,0 +1,19 @@ +# #!/bin/bash +export MVCAM_SDK_PATH=/opt/MVS +export MVCAM_COMMON_RUNENV=/opt/MVS/lib +export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol +export ALLUSERSPROFILE=/opt/MVS/MVFG +export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH +export MVCAM_SDK_PATH=/opt/MVS + +export MVCAM_SDK_VERSION= + +export MVCAM_COMMON_RUNENV=/opt/MVS/lib + +export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol + +export ALLUSERSPROFILE=/opt/MVS/MVFG +export LD_LIBRARY_PATH=/opt/MVS/lib/aarch64:$LD_LIBRARY_PATH + +WORK_DIR="$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +export VISION_ROOT="$WORK_DIR" \ No newline at end of file diff --git a/wust_vision-main/format.sh b/wust_vision-main/format.sh new file mode 100755 index 0000000..e630cb4 --- /dev/null +++ b/wust_vision-main/format.sh @@ -0,0 +1,3 @@ +find . -path ./build -prune -o \ + -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.c' -o -name '*.cu' -o -name '*.cpp' \) \ + -exec clang-format -i {} + diff --git a/wust_vision-main/model/0526.engine b/wust_vision-main/model/0526.engine new file mode 100644 index 0000000..c048411 Binary files /dev/null and b/wust_vision-main/model/0526.engine differ diff --git a/wust_vision-main/model/0526.onnx b/wust_vision-main/model/0526.onnx new file mode 100644 index 0000000..c6fc0dc Binary files /dev/null and b/wust_vision-main/model/0526.onnx differ diff --git a/wust_vision-main/model/0708.engine b/wust_vision-main/model/0708.engine new file mode 100644 index 0000000..a359ad2 Binary files /dev/null and b/wust_vision-main/model/0708.engine differ diff --git a/wust_vision-main/model/0708.onnx b/wust_vision-main/model/0708.onnx new file mode 100644 index 0000000..f1b71ad Binary files /dev/null and b/wust_vision-main/model/0708.onnx differ diff --git a/wust_vision-main/model/label.txt b/wust_vision-main/model/label.txt new file mode 100644 index 0000000..4a2c3fb --- /dev/null +++ b/wust_vision-main/model/label.txt @@ -0,0 +1,9 @@ +1 +2 +3 +4 +5 +outpost +guard +base +negative diff --git a/wust_vision-main/model/lenet.onnx b/wust_vision-main/model/lenet.onnx new file mode 100644 index 0000000..8a0b42f Binary files /dev/null and b/wust_vision-main/model/lenet.onnx differ diff --git a/wust_vision-main/model/mlp.onnx b/wust_vision-main/model/mlp.onnx new file mode 100644 index 0000000..2089380 Binary files /dev/null and b/wust_vision-main/model/mlp.onnx differ diff --git a/wust_vision-main/model/opt-0527-001.onnx b/wust_vision-main/model/opt-0527-001.onnx new file mode 100644 index 0000000..3f48f9b Binary files /dev/null and b/wust_vision-main/model/opt-0527-001.onnx differ diff --git a/wust_vision-main/model/opt-1208-001.bin b/wust_vision-main/model/opt-1208-001.bin new file mode 100644 index 0000000..03b2697 Binary files /dev/null and b/wust_vision-main/model/opt-1208-001.bin differ diff --git a/wust_vision-main/model/opt-1208-001.engine b/wust_vision-main/model/opt-1208-001.engine new file mode 100644 index 0000000..52d764c Binary files /dev/null and b/wust_vision-main/model/opt-1208-001.engine differ diff --git a/wust_vision-main/model/opt-1208-001.onnx b/wust_vision-main/model/opt-1208-001.onnx new file mode 100644 index 0000000..6a1916c Binary files /dev/null and b/wust_vision-main/model/opt-1208-001.onnx differ diff --git a/wust_vision-main/model/opt-1208-001.param b/wust_vision-main/model/opt-1208-001.param new file mode 100644 index 0000000..e95eb8f --- /dev/null +++ b/wust_vision-main/model/opt-1208-001.param @@ -0,0 +1,218 @@ +7767517 +216 240 +Input images 0 1 images +Convolution Conv_1 1 1 images input.4 0=16 1=6 11=6 2=1 12=1 3=2 13=2 4=2 14=2 15=2 16=2 5=1 6=1728 +HardSwish HardSwish_2 1 1 input.4 onnx::Conv_483 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_3 1 1 onnx::Conv_483 input.12 0=16 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=144 7=16 +Convolution Conv_4 1 1 input.12 input.20 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +HardSwish HardSwish_5 1 1 input.20 onnx::Conv_488 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_0 1 2 onnx::Conv_488 onnx::Conv_488_splitncnn_0 onnx::Conv_488_splitncnn_1 +ConvolutionDepthWise Conv_6 1 1 onnx::Conv_488_splitncnn_1 input.28 0=32 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=288 7=32 +Convolution Conv_7 1 1 input.28 input.36 0=16 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +HardSwish HardSwish_8 1 1 input.36 onnx::Concat_493 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_9 1 1 onnx::Conv_488_splitncnn_0 input.44 0=16 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +HardSwish HardSwish_10 1 1 input.44 onnx::Conv_496 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_11 1 1 onnx::Conv_496 input.52 0=16 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=144 7=16 +Convolution Conv_12 1 1 input.52 input.60 0=48 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768 +HardSwish HardSwish_13 1 1 input.60 onnx::Concat_501 0=2.000000e-01 1=5.000000e-01 +Concat Concat_14 2 1 onnx::Concat_493 onnx::Concat_501 x 0=0 +ShuffleChannel Reshape_19 1 1 x onnx::Slice_507 0=2 1=0 +Split splitncnn_1 1 2 onnx::Slice_507 onnx::Slice_507_splitncnn_0 onnx::Slice_507_splitncnn_1 +Crop Slice_24 1 1 onnx::Slice_507_splitncnn_1 onnx::Concat_512 -23309=1,0 -23310=1,32 -23311=1,0 +Crop Slice_29 1 1 onnx::Slice_507_splitncnn_0 onnx::Conv_517 -23309=1,32 -23310=1,2147483647 -23311=1,0 +Convolution Conv_30 1 1 onnx::Conv_517 input.68 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_31 1 1 input.68 onnx::Conv_520 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_32 1 1 onnx::Conv_520 input.76 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32 +Convolution Conv_33 1 1 input.76 input.84 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_34 1 1 input.84 onnx::Concat_525 0=2.000000e-01 1=5.000000e-01 +Concat Concat_35 2 1 onnx::Concat_512 onnx::Concat_525 x.3 0=0 +ShuffleChannel Reshape_40 1 1 x.3 onnx::Slice_531 0=2 1=0 +Split splitncnn_2 1 2 onnx::Slice_531 onnx::Slice_531_splitncnn_0 onnx::Slice_531_splitncnn_1 +Crop Slice_45 1 1 onnx::Slice_531_splitncnn_1 onnx::Concat_536 -23309=1,0 -23310=1,32 -23311=1,0 +Crop Slice_50 1 1 onnx::Slice_531_splitncnn_0 onnx::Conv_541 -23309=1,32 -23310=1,2147483647 -23311=1,0 +Convolution Conv_51 1 1 onnx::Conv_541 input.92 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_52 1 1 input.92 onnx::Conv_544 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_53 1 1 onnx::Conv_544 input.100 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32 +Convolution Conv_54 1 1 input.100 input.108 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_55 1 1 input.108 onnx::Concat_549 0=2.000000e-01 1=5.000000e-01 +Concat Concat_56 2 1 onnx::Concat_536 onnx::Concat_549 x.7 0=0 +ShuffleChannel Reshape_61 1 1 x.7 input.112 0=2 1=0 +Split splitncnn_3 1 4 input.112 input.112_splitncnn_0 input.112_splitncnn_1 input.112_splitncnn_2 input.112_splitncnn_3 +ConvolutionDepthWise Conv_62 1 1 input.112_splitncnn_3 input.120 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_63 1 1 input.120 input.128 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=2048 +HardSwish HardSwish_64 1 1 input.128 onnx::Concat_560 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_65 1 1 input.112_splitncnn_2 input.136 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=2048 +HardSwish HardSwish_66 1 1 input.136 onnx::Conv_563 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_67 1 1 onnx::Conv_563 input.144 0=32 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=288 7=32 +Convolution Conv_68 1 1 input.144 input.152 0=96 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=3072 +HardSwish HardSwish_69 1 1 input.152 onnx::Concat_568 0=2.000000e-01 1=5.000000e-01 +Concat Concat_70 2 1 onnx::Concat_560 onnx::Concat_568 x.11 0=0 +ShuffleChannel Reshape_75 1 1 x.11 onnx::Slice_574 0=2 1=0 +Split splitncnn_4 1 2 onnx::Slice_574 onnx::Slice_574_splitncnn_0 onnx::Slice_574_splitncnn_1 +Crop Slice_80 1 1 onnx::Slice_574_splitncnn_1 onnx::Concat_579 -23309=1,0 -23310=1,64 -23311=1,0 +Crop Slice_85 1 1 onnx::Slice_574_splitncnn_0 onnx::Conv_584 -23309=1,64 -23310=1,2147483647 -23311=1,0 +Convolution Conv_86 1 1 onnx::Conv_584 input.160 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_87 1 1 input.160 onnx::Conv_587 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_88 1 1 onnx::Conv_587 input.168 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_89 1 1 input.168 input.176 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_90 1 1 input.176 onnx::Concat_592 0=2.000000e-01 1=5.000000e-01 +Concat Concat_91 2 1 onnx::Concat_579 onnx::Concat_592 x.15 0=0 +ShuffleChannel Reshape_96 1 1 x.15 onnx::Slice_598 0=2 1=0 +Split splitncnn_5 1 2 onnx::Slice_598 onnx::Slice_598_splitncnn_0 onnx::Slice_598_splitncnn_1 +Crop Slice_101 1 1 onnx::Slice_598_splitncnn_1 onnx::Concat_603 -23309=1,0 -23310=1,64 -23311=1,0 +Crop Slice_106 1 1 onnx::Slice_598_splitncnn_0 onnx::Conv_608 -23309=1,64 -23310=1,2147483647 -23311=1,0 +Convolution Conv_107 1 1 onnx::Conv_608 input.184 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_108 1 1 input.184 onnx::Conv_611 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_109 1 1 onnx::Conv_611 input.192 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_110 1 1 input.192 input.200 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_111 1 1 input.200 onnx::Concat_616 0=2.000000e-01 1=5.000000e-01 +Concat Concat_112 2 1 onnx::Concat_603 onnx::Concat_616 x.19 0=0 +ShuffleChannel Reshape_117 1 1 x.19 input.204 0=2 1=0 +Split splitncnn_6 1 3 input.204 input.204_splitncnn_0 input.204_splitncnn_1 input.204_splitncnn_2 +ConvolutionDepthWise Conv_118 1 1 input.204_splitncnn_2 input.212 0=128 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=1152 7=128 +Convolution Conv_119 1 1 input.212 input.220 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 +HardSwish HardSwish_120 1 1 input.220 onnx::Concat_627 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_121 1 1 input.204_splitncnn_1 input.228 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 +HardSwish HardSwish_122 1 1 input.228 onnx::Conv_630 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_123 1 1 onnx::Conv_630 input.236 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_124 1 1 input.236 input.244 0=192 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=12288 +HardSwish HardSwish_125 1 1 input.244 onnx::Concat_635 0=2.000000e-01 1=5.000000e-01 +Concat Concat_126 2 1 onnx::Concat_627 onnx::Concat_635 x.23 0=0 +ShuffleChannel Reshape_131 1 1 x.23 onnx::Slice_641 0=2 1=0 +Split splitncnn_7 1 2 onnx::Slice_641 onnx::Slice_641_splitncnn_0 onnx::Slice_641_splitncnn_1 +Crop Slice_136 1 1 onnx::Slice_641_splitncnn_1 onnx::Concat_646 -23309=1,0 -23310=1,128 -23311=1,0 +Crop Slice_141 1 1 onnx::Slice_641_splitncnn_0 onnx::Conv_651 -23309=1,128 -23310=1,2147483647 -23311=1,0 +Convolution Conv_142 1 1 onnx::Conv_651 input.252 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_143 1 1 input.252 onnx::Conv_654 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_144 1 1 onnx::Conv_654 input.260 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128 +Convolution Conv_145 1 1 input.260 input.268 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_146 1 1 input.268 onnx::Concat_659 0=2.000000e-01 1=5.000000e-01 +Concat Concat_147 2 1 onnx::Concat_646 onnx::Concat_659 x.27 0=0 +ShuffleChannel Reshape_152 1 1 x.27 onnx::Slice_665 0=2 1=0 +Split splitncnn_8 1 2 onnx::Slice_665 onnx::Slice_665_splitncnn_0 onnx::Slice_665_splitncnn_1 +Crop Slice_157 1 1 onnx::Slice_665_splitncnn_1 onnx::Concat_670 -23309=1,0 -23310=1,128 -23311=1,0 +Crop Slice_162 1 1 onnx::Slice_665_splitncnn_0 onnx::Conv_675 -23309=1,128 -23310=1,2147483647 -23311=1,0 +Convolution Conv_163 1 1 onnx::Conv_675 input.276 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_164 1 1 input.276 onnx::Conv_678 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_165 1 1 onnx::Conv_678 input.284 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128 +Convolution Conv_166 1 1 input.284 input.292 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_167 1 1 input.292 onnx::Concat_683 0=2.000000e-01 1=5.000000e-01 +Concat Concat_168 2 1 onnx::Concat_670 onnx::Concat_683 x.31 0=0 +ShuffleChannel Reshape_173 1 1 x.31 input.296 0=2 1=0 +Convolution Conv_174 1 1 input.296 input.304 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768 +HardSwish HardSwish_175 1 1 input.304 onnx::Conv_692 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_9 1 2 onnx::Conv_692 onnx::Conv_692_splitncnn_0 onnx::Conv_692_splitncnn_1 +Convolution Conv_176 1 1 onnx::Conv_692_splitncnn_1 input.312 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 +HardSwish HardSwish_177 1 1 input.312 onnx::Resize_695 0=2.000000e-01 1=5.000000e-01 +Interp Resize_178 1 1 onnx::Resize_695 onnx::Concat_700 0=1 1=2.000000e+00 2=2.000000e+00 3=0 4=0 6=0 +ConvolutionDepthWise Conv_179 1 1 input.112_splitncnn_1 input.320 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_180 1 1 input.320 input.328 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_181 1 1 input.328 onnx::Concat_705 0=2.000000e-01 1=5.000000e-01 +Concat Concat_182 3 1 onnx::Concat_705 input.204_splitncnn_0 onnx::Concat_700 input.332 0=0 +Split splitncnn_10 1 2 input.332 input.332_splitncnn_0 input.332_splitncnn_1 +Convolution Conv_183 1 1 input.332_splitncnn_1 input.340 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_184 1 1 input.340 onnx::Conv_709 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_185 1 1 input.332_splitncnn_0 input.348 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_186 1 1 input.348 onnx::Concat_712 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_187 1 1 onnx::Conv_709 input.356 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_188 1 1 input.356 onnx::Conv_715 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_189 1 1 onnx::Conv_715 input.364 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64 +Convolution Conv_190 1 1 input.364 input.372 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_191 1 1 input.372 onnx::Concat_720 0=2.000000e-01 1=5.000000e-01 +Concat Concat_192 2 1 onnx::Concat_720 onnx::Concat_712 input.376 0=0 +Convolution Conv_193 1 1 input.376 input.384 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_194 1 1 input.384 onnx::Conv_724 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_11 1 3 onnx::Conv_724 onnx::Conv_724_splitncnn_0 onnx::Conv_724_splitncnn_1 onnx::Conv_724_splitncnn_2 +Convolution Conv_195 1 1 onnx::Conv_724_splitncnn_2 input.392 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 +HardSwish HardSwish_196 1 1 input.392 onnx::Resize_727 0=2.000000e-01 1=5.000000e-01 +Interp Resize_197 1 1 onnx::Resize_727 onnx::Concat_732 0=1 1=2.000000e+00 2=2.000000e+00 3=0 4=0 6=0 +Concat Concat_198 2 1 input.112_splitncnn_0 onnx::Concat_732 input.396 0=0 +Split splitncnn_12 1 2 input.396 input.396_splitncnn_0 input.396_splitncnn_1 +Convolution Conv_199 1 1 input.396_splitncnn_1 input.404 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_200 1 1 input.404 onnx::Conv_736 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_201 1 1 input.396_splitncnn_0 input.412 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_202 1 1 input.412 onnx::Concat_739 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_203 1 1 onnx::Conv_736 input.420 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_204 1 1 input.420 onnx::Conv_742 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_205 1 1 onnx::Conv_742 input.428 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32 +Convolution Conv_206 1 1 input.428 input.436 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024 +HardSwish HardSwish_207 1 1 input.436 onnx::Concat_747 0=2.000000e-01 1=5.000000e-01 +Concat Concat_208 2 1 onnx::Concat_747 onnx::Concat_739 input.440 0=0 +Convolution Conv_209 1 1 input.440 input.448 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_210 1 1 input.448 onnx::Conv_751 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_211 1 1 onnx::Conv_724_splitncnn_1 input.456 0=128 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=1152 7=128 +Convolution Conv_212 1 1 input.456 input.464 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_213 1 1 input.464 onnx::Concat_756 0=2.000000e-01 1=5.000000e-01 +Concat Concat_214 2 1 onnx::Concat_756 onnx::Conv_692_splitncnn_0 input.468 0=0 +Split splitncnn_13 1 2 input.468 input.468_splitncnn_0 input.468_splitncnn_1 +Convolution Conv_215 1 1 input.468_splitncnn_1 input.476 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768 +HardSwish HardSwish_216 1 1 input.476 onnx::Conv_760 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_217 1 1 input.468_splitncnn_0 input.484 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768 +HardSwish HardSwish_218 1 1 input.484 onnx::Concat_763 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_219 1 1 onnx::Conv_760 input.492 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_220 1 1 input.492 onnx::Conv_766 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_221 1 1 onnx::Conv_766 input.500 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128 +Convolution Conv_222 1 1 input.500 input.508 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_223 1 1 input.508 onnx::Concat_771 0=2.000000e-01 1=5.000000e-01 +Concat Concat_224 2 1 onnx::Concat_771 onnx::Concat_763 input.512 0=0 +Convolution Conv_225 1 1 input.512 input.520 0=256 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=65536 +HardSwish HardSwish_226 1 1 input.520 onnx::Conv_775 0=2.000000e-01 1=5.000000e-01 +Convolution Conv_227 1 1 onnx::Conv_751 input.528 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 +HardSwish HardSwish_228 1 1 input.528 onnx::Conv_778 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_229 1 1 onnx::Conv_778 input.536 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64 +ConvolutionDepthWise Conv_230 1 1 input.536 input.544 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2 +HardSwish HardSwish_231 1 1 input.544 onnx::Conv_783 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_232 1 1 onnx::Conv_783 input.552 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128 +ConvolutionDepthWise Conv_233 1 1 input.552 input.560 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2 +HardSwish HardSwish_234 1 1 input.560 onnx::Slice_788 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_14 1 2 onnx::Slice_788 onnx::Slice_788_splitncnn_0 onnx::Slice_788_splitncnn_1 +Crop Slice_239 1 1 onnx::Slice_788_splitncnn_1 onnx::Conv_793 -23309=1,0 -23310=1,64 -23311=1,0 +Split splitncnn_15 1 2 onnx::Conv_793 onnx::Conv_793_splitncnn_0 onnx::Conv_793_splitncnn_1 +Crop Slice_244 1 1 onnx::Slice_788_splitncnn_0 onnx::Conv_798 -23309=1,64 -23310=1,2147483647 -23311=1,0 +Convolution Conv_245 1 1 onnx::Conv_798 onnx::Sigmoid_799 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768 +Convolution Conv_246 1 1 onnx::Conv_793_splitncnn_1 onnx::Concat_800 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +Convolution Conv_247 1 1 onnx::Conv_793_splitncnn_0 onnx::Sigmoid_801 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64 +Sigmoid Sigmoid_248 1 1 onnx::Sigmoid_801 onnx::Concat_802 +Sigmoid Sigmoid_249 1 1 onnx::Sigmoid_799 onnx::Concat_803 +Concat Concat_250 3 1 onnx::Concat_800 onnx::Concat_802 onnx::Concat_803 onnx::Shape_804 0=0 +Convolution Conv_251 1 1 onnx::Conv_724_splitncnn_0 input.568 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 +HardSwish HardSwish_252 1 1 input.568 onnx::Conv_807 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_253 1 1 onnx::Conv_807 input.576 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64 +ConvolutionDepthWise Conv_254 1 1 input.576 input.584 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2 +HardSwish HardSwish_255 1 1 input.584 onnx::Conv_812 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_256 1 1 onnx::Conv_812 input.592 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128 +ConvolutionDepthWise Conv_257 1 1 input.592 input.600 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2 +HardSwish HardSwish_258 1 1 input.600 onnx::Slice_817 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_16 1 2 onnx::Slice_817 onnx::Slice_817_splitncnn_0 onnx::Slice_817_splitncnn_1 +Crop Slice_263 1 1 onnx::Slice_817_splitncnn_1 onnx::Conv_822 -23309=1,0 -23310=1,64 -23311=1,0 +Split splitncnn_17 1 2 onnx::Conv_822 onnx::Conv_822_splitncnn_0 onnx::Conv_822_splitncnn_1 +Crop Slice_268 1 1 onnx::Slice_817_splitncnn_0 onnx::Conv_827 -23309=1,64 -23310=1,2147483647 -23311=1,0 +Convolution Conv_269 1 1 onnx::Conv_827 onnx::Sigmoid_828 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768 +Convolution Conv_270 1 1 onnx::Conv_822_splitncnn_1 onnx::Concat_829 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +Convolution Conv_271 1 1 onnx::Conv_822_splitncnn_0 onnx::Sigmoid_830 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64 +Sigmoid Sigmoid_272 1 1 onnx::Sigmoid_830 onnx::Concat_831 +Sigmoid Sigmoid_273 1 1 onnx::Sigmoid_828 onnx::Concat_832 +Concat Concat_274 3 1 onnx::Concat_829 onnx::Concat_831 onnx::Concat_832 onnx::Shape_833 0=0 +Convolution Conv_275 1 1 onnx::Conv_775 input.608 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384 +HardSwish HardSwish_276 1 1 input.608 onnx::Conv_836 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_277 1 1 onnx::Conv_836 input.616 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64 +ConvolutionDepthWise Conv_278 1 1 input.616 input.624 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2 +HardSwish HardSwish_279 1 1 input.624 onnx::Conv_841 0=2.000000e-01 1=5.000000e-01 +ConvolutionDepthWise Conv_280 1 1 onnx::Conv_841 input.632 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128 +ConvolutionDepthWise Conv_281 1 1 input.632 input.640 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2 +HardSwish HardSwish_282 1 1 input.640 onnx::Slice_846 0=2.000000e-01 1=5.000000e-01 +Split splitncnn_18 1 2 onnx::Slice_846 onnx::Slice_846_splitncnn_0 onnx::Slice_846_splitncnn_1 +Crop Slice_287 1 1 onnx::Slice_846_splitncnn_1 onnx::Conv_851 -23309=1,0 -23310=1,64 -23311=1,0 +Split splitncnn_19 1 2 onnx::Conv_851 onnx::Conv_851_splitncnn_0 onnx::Conv_851_splitncnn_1 +Crop Slice_292 1 1 onnx::Slice_846_splitncnn_0 onnx::Conv_856 -23309=1,64 -23310=1,2147483647 -23311=1,0 +Convolution Conv_293 1 1 onnx::Conv_856 onnx::Sigmoid_857 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768 +Convolution Conv_294 1 1 onnx::Conv_851_splitncnn_1 onnx::Concat_858 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512 +Convolution Conv_295 1 1 onnx::Conv_851_splitncnn_0 onnx::Sigmoid_859 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64 +Sigmoid Sigmoid_296 1 1 onnx::Sigmoid_859 onnx::Concat_860 +Sigmoid Sigmoid_297 1 1 onnx::Sigmoid_857 onnx::Concat_861 +Concat Concat_298 3 1 onnx::Concat_858 onnx::Concat_860 onnx::Concat_861 output.1 0=0 +Reshape Reshape_306 1 1 onnx::Shape_804 onnx::Concat_870 0=-1 1=21 +Reshape Reshape_314 1 1 onnx::Shape_833 onnx::Concat_878 0=-1 1=21 +Reshape Reshape_322 1 1 output.1 onnx::Concat_886 0=-1 1=21 +Concat Concat_323 3 1 onnx::Concat_870 onnx::Concat_878 onnx::Concat_886 onnx::Transpose_887 0=1 +Permute Transpose_324 1 1 onnx::Transpose_887 output 0=1 diff --git a/wust_vision-main/model/opt_1208_001.ncnn.bin b/wust_vision-main/model/opt_1208_001.ncnn.bin new file mode 100644 index 0000000..7a05589 Binary files /dev/null and b/wust_vision-main/model/opt_1208_001.ncnn.bin differ diff --git a/wust_vision-main/model/opt_1208_001.ncnn.param b/wust_vision-main/model/opt_1208_001.ncnn.param new file mode 100644 index 0000000..755c47c --- /dev/null +++ b/wust_vision-main/model/opt_1208_001.ncnn.param @@ -0,0 +1,193 @@ +7767517 +191 215 +Input in0 0 1 in0 +Convolution conv_69 1 1 in0 1 0=16 1=6 11=6 12=1 13=2 14=2 2=1 3=2 4=2 5=1 6=1728 +HardSwish hswish_0 1 1 1 2 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_128 1 1 2 3 0=16 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=144 7=16 +Convolution conv_70 1 1 3 4 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +HardSwish hswish_1 1 1 4 5 0=1.666667e-01 1=5.000000e-01 +Split splitncnn_0 1 2 5 6 7 +ConvolutionDepthWise convdw_129 1 1 7 8 0=32 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=288 7=32 +Convolution conv_71 1 1 8 9 0=16 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +HardSwish hswish_2 1 1 9 10 0=1.666667e-01 1=5.000000e-01 +Convolution conv_72 1 1 6 11 0=16 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +HardSwish hswish_3 1 1 11 12 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_130 1 1 12 13 0=16 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=144 7=16 +Convolution conv_73 1 1 13 14 0=48 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 +HardSwish hswish_4 1 1 14 15 0=1.666667e-01 1=5.000000e-01 +Concat cat_0 2 1 10 15 16 0=0 +ShuffleChannel channelshuffle_60 1 1 16 17 0=2 1=0 +Slice tensor_split_0 1 2 17 18 19 -23300=2,32,-233 1=0 +Convolution conv_74 1 1 19 20 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_5 1 1 20 21 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_131 1 1 21 22 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32 +Convolution conv_75 1 1 22 23 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_6 1 1 23 24 0=1.666667e-01 1=5.000000e-01 +Concat cat_1 2 1 18 24 25 0=0 +ShuffleChannel channelshuffle_61 1 1 25 26 0=2 1=0 +Slice tensor_split_1 1 2 26 27 28 -23300=2,32,-233 1=0 +Convolution conv_76 1 1 28 29 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_7 1 1 29 30 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_132 1 1 30 31 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32 +Convolution conv_77 1 1 31 32 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_8 1 1 32 33 0=1.666667e-01 1=5.000000e-01 +Concat cat_2 2 1 27 33 34 0=0 +ShuffleChannel channelshuffle_62 1 1 34 35 0=2 1=0 +Split splitncnn_1 1 4 35 36 37 38 39 +ConvolutionDepthWise convdw_133 1 1 39 40 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64 +Convolution conv_78 1 1 40 41 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=2048 +HardSwish hswish_9 1 1 41 42 0=1.666667e-01 1=5.000000e-01 +Convolution conv_79 1 1 37 43 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=2048 +HardSwish hswish_10 1 1 43 44 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_134 1 1 44 45 0=32 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=288 7=32 +Convolution conv_80 1 1 45 46 0=96 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=3072 +HardSwish hswish_11 1 1 46 47 0=1.666667e-01 1=5.000000e-01 +Concat cat_3 2 1 42 47 48 0=0 +ShuffleChannel channelshuffle_63 1 1 48 49 0=2 1=0 +Slice tensor_split_2 1 2 49 50 51 -23300=2,64,-233 1=0 +Convolution conv_81 1 1 51 52 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_12 1 1 52 53 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_135 1 1 53 54 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64 +Convolution conv_82 1 1 54 55 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_13 1 1 55 56 0=1.666667e-01 1=5.000000e-01 +Concat cat_4 2 1 50 56 57 0=0 +ShuffleChannel channelshuffle_64 1 1 57 58 0=2 1=0 +Slice tensor_split_3 1 2 58 59 60 -23300=2,64,-233 1=0 +Convolution conv_83 1 1 60 61 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_14 1 1 61 62 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_136 1 1 62 63 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64 +Convolution conv_84 1 1 63 64 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_15 1 1 64 65 0=1.666667e-01 1=5.000000e-01 +Concat cat_5 2 1 59 65 66 0=0 +ShuffleChannel channelshuffle_65 1 1 66 67 0=2 1=0 +Split splitncnn_2 1 3 67 68 69 70 +ConvolutionDepthWise convdw_137 1 1 70 71 0=128 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=1152 7=128 +Convolution conv_85 1 1 71 72 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 +HardSwish hswish_16 1 1 72 73 0=1.666667e-01 1=5.000000e-01 +Convolution conv_86 1 1 69 74 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 +HardSwish hswish_17 1 1 74 75 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_138 1 1 75 76 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64 +Convolution conv_87 1 1 76 77 0=192 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=12288 +HardSwish hswish_18 1 1 77 78 0=1.666667e-01 1=5.000000e-01 +Concat cat_6 2 1 73 78 79 0=0 +ShuffleChannel channelshuffle_66 1 1 79 80 0=2 1=0 +Slice tensor_split_4 1 2 80 81 82 -23300=2,128,-233 1=0 +Convolution conv_88 1 1 82 83 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_19 1 1 83 84 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_139 1 1 84 85 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128 +Convolution conv_89 1 1 85 86 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_20 1 1 86 87 0=1.666667e-01 1=5.000000e-01 +Concat cat_7 2 1 81 87 88 0=0 +ShuffleChannel channelshuffle_67 1 1 88 89 0=2 1=0 +Slice tensor_split_5 1 2 89 90 91 -23300=2,128,-233 1=0 +Convolution conv_90 1 1 91 92 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_21 1 1 92 93 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_140 1 1 93 94 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128 +Convolution conv_91 1 1 94 95 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_22 1 1 95 96 0=1.666667e-01 1=5.000000e-01 +Concat cat_8 2 1 90 96 97 0=0 +ShuffleChannel channelshuffle_68 1 1 97 98 0=2 1=0 +Convolution conv_92 1 1 98 99 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768 +HardSwish hswish_23 1 1 99 100 0=1.666667e-01 1=5.000000e-01 +Split splitncnn_3 1 2 100 101 102 +Convolution conv_93 1 1 102 103 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 +HardSwish hswish_24 1 1 103 104 0=1.666667e-01 1=5.000000e-01 +Interp interpolate_52 1 1 104 105 0=1 1=2.000000e+00 2=2.000000e+00 6=0 +ConvolutionDepthWise convdw_141 1 1 38 106 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64 +Convolution conv_94 1 1 106 107 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_25 1 1 107 108 0=1.666667e-01 1=5.000000e-01 +Concat cat_9 3 1 108 68 105 109 0=0 +Split splitncnn_4 1 2 109 110 111 +Convolution conv_95 1 1 111 112 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_26 1 1 112 113 0=1.666667e-01 1=5.000000e-01 +Convolution conv_96 1 1 110 114 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_27 1 1 114 115 0=1.666667e-01 1=5.000000e-01 +Convolution conv_97 1 1 113 116 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_28 1 1 116 117 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_142 1 1 117 118 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64 +Convolution conv_98 1 1 118 119 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_29 1 1 119 120 0=1.666667e-01 1=5.000000e-01 +Concat cat_10 2 1 120 115 121 0=0 +Convolution conv_99 1 1 121 122 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_30 1 1 122 123 0=1.666667e-01 1=5.000000e-01 +Split splitncnn_5 1 3 123 124 125 126 +Convolution conv_100 1 1 125 127 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 +HardSwish hswish_31 1 1 127 128 0=1.666667e-01 1=5.000000e-01 +Interp interpolate_53 1 1 128 129 0=1 1=2.000000e+00 2=2.000000e+00 6=0 +Concat cat_11 2 1 36 129 130 0=0 +Split splitncnn_6 1 2 130 131 132 +Convolution conv_101 1 1 132 133 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_32 1 1 133 134 0=1.666667e-01 1=5.000000e-01 +Convolution conv_102 1 1 131 135 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_33 1 1 135 136 0=1.666667e-01 1=5.000000e-01 +Convolution conv_103 1 1 134 137 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_34 1 1 137 138 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_143 1 1 138 139 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32 +Convolution conv_104 1 1 139 140 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024 +HardSwish hswish_35 1 1 140 141 0=1.666667e-01 1=5.000000e-01 +Concat cat_12 2 1 141 136 142 0=0 +Convolution conv_105 1 1 142 143 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_36 1 1 143 144 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_144 1 1 126 145 0=128 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=1152 7=128 +Convolution conv_106 1 1 145 146 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_37 1 1 146 147 0=1.666667e-01 1=5.000000e-01 +Concat cat_13 2 1 147 101 148 0=0 +Split splitncnn_7 1 2 148 149 150 +Convolution conv_107 1 1 150 151 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768 +HardSwish hswish_38 1 1 151 152 0=1.666667e-01 1=5.000000e-01 +Convolution conv_108 1 1 149 153 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768 +HardSwish hswish_39 1 1 153 154 0=1.666667e-01 1=5.000000e-01 +Convolution conv_109 1 1 152 155 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_40 1 1 155 156 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_145 1 1 156 157 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128 +Convolution conv_110 1 1 157 158 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_41 1 1 158 159 0=1.666667e-01 1=5.000000e-01 +Concat cat_14 2 1 159 154 160 0=0 +Convolution conv_111 1 1 160 161 0=256 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=65536 +HardSwish hswish_42 1 1 161 162 0=1.666667e-01 1=5.000000e-01 +Convolution conv_112 1 1 144 163 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 +HardSwish hswish_43 1 1 163 164 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_146 1 1 164 165 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64 +ConvolutionDepthWise convdw_147 1 1 165 166 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2 +HardSwish hswish_44 1 1 166 167 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_148 1 1 167 168 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128 +ConvolutionDepthWise convdw_149 1 1 168 169 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2 +HardSwish hswish_45 1 1 169 170 0=1.666667e-01 1=5.000000e-01 +Slice tensor_split_6 1 2 170 171 172 -23300=2,64,-233 1=0 +Split splitncnn_8 1 2 171 173 174 +Convolution conv_114 1 1 174 175 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +Convolution convsigmoid_0 1 1 173 176 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4 +Convolution convsigmoid_1 1 1 172 177 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4 +Concat cat_15 3 1 175 176 177 178 0=0 +Convolution conv_116 1 1 124 179 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 +HardSwish hswish_46 1 1 179 180 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_150 1 1 180 181 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64 +ConvolutionDepthWise convdw_151 1 1 181 182 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2 +HardSwish hswish_47 1 1 182 183 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_152 1 1 183 184 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128 +ConvolutionDepthWise convdw_153 1 1 184 185 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2 +HardSwish hswish_48 1 1 185 186 0=1.666667e-01 1=5.000000e-01 +Slice tensor_split_7 1 2 186 187 188 -23300=2,64,-233 1=0 +Split splitncnn_9 1 2 187 189 190 +Convolution conv_118 1 1 190 191 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +Convolution convsigmoid_2 1 1 189 192 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4 +Convolution convsigmoid_3 1 1 188 193 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4 +Concat cat_16 3 1 191 192 193 194 0=0 +Convolution conv_120 1 1 162 195 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384 +HardSwish hswish_49 1 1 195 196 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_154 1 1 196 197 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64 +ConvolutionDepthWise convdw_155 1 1 197 198 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2 +HardSwish hswish_50 1 1 198 199 0=1.666667e-01 1=5.000000e-01 +ConvolutionDepthWise convdw_156 1 1 199 200 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128 +ConvolutionDepthWise convdw_157 1 1 200 201 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2 +HardSwish hswish_51 1 1 201 202 0=1.666667e-01 1=5.000000e-01 +Slice tensor_split_8 1 2 202 203 204 -23300=2,64,-233 1=0 +Split splitncnn_10 1 2 203 205 206 +Convolution conv_122 1 1 206 207 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512 +Convolution convsigmoid_4 1 1 205 208 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4 +Convolution convsigmoid_5 1 1 204 209 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4 +Concat cat_17 3 1 207 208 209 210 0=0 +Reshape reshape_125 1 1 178 211 0=2704 1=21 +Reshape reshape_126 1 1 194 212 0=676 1=21 +Reshape reshape_127 1 1 210 213 0=169 1=21 +Concat cat_18 3 1 211 212 213 out0 0=1 diff --git a/wust_vision-main/model/reborn_number_classifier.engine b/wust_vision-main/model/reborn_number_classifier.engine new file mode 100644 index 0000000..d9957da Binary files /dev/null and b/wust_vision-main/model/reborn_number_classifier.engine differ diff --git a/wust_vision-main/model/reborn_number_classifier.onnx b/wust_vision-main/model/reborn_number_classifier.onnx new file mode 100644 index 0000000..de8fbb8 Binary files /dev/null and b/wust_vision-main/model/reborn_number_classifier.onnx differ diff --git a/wust_vision-main/model/tiny_resnet.onnx b/wust_vision-main/model/tiny_resnet.onnx new file mode 100644 index 0000000..1c2e6ec Binary files /dev/null and b/wust_vision-main/model/tiny_resnet.onnx differ diff --git a/wust_vision-main/model/yolox_armor3.onnx b/wust_vision-main/model/yolox_armor3.onnx new file mode 100644 index 0000000..cc30693 Binary files /dev/null and b/wust_vision-main/model/yolox_armor3.onnx differ diff --git a/wust_vision-main/read_shm_image_mmap_only.py b/wust_vision-main/read_shm_image_mmap_only.py new file mode 100644 index 0000000..6945c56 --- /dev/null +++ b/wust_vision-main/read_shm_image_mmap_only.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +read_shm_image_mmap_only.py + +仅使用 mmap 读取共享内存的 reader: +- 共享内存 /dev/shm/debug_frame 前 4 bytes = uint32 jpg_size (little endian) + 紧随其后的是 jpg_bytes (jpg_size bytes) +- 不再回退到文件读取;若 mmap 不可用则按 MMAP_RETRY_SECONDS 重试打开 +- 原子写入 out_path.tmp -> out_path +- 支持环境变量覆盖: SHM_PATH, SHM_SIZE, OUT_PATH, FPS, DEBUG, MMAP_RETRY_SECONDS +""" +import os +import sys +import time +import struct +import hashlib +import signal + +try: + import mmap as _mmap +except Exception: + _mmap = None + +from io import BytesIO + +# ---------- 配置(环境变量覆盖) ---------- +SHM_PATH = os.environ.get("SHM_PATH", "/dev/shm/debug_frame") +SHM_SIZE = int(os.environ.get("SHM_SIZE", str(2 * 1024 * 1024))) # bytes +OUT_PATH = os.environ.get("OUT_PATH", "/dev/shm/frame_preview.jpg") +FPS = float(os.environ.get("FPS", "10.0")) +DEBUG = os.environ.get("DEBUG", "0") != "0" +MMAP_RETRY_SECONDS = float(os.environ.get("MMAP_RETRY_SECONDS", "5.0")) +# ------------------------------------------------------------------- + +sleep_interval = 1.0 / max(1.0, FPS) + +def log(*a, **k): + if DEBUG: + print(*a, **k) + +def is_valid_image_bytes(b: bytes) -> bool: + """尝试用 Pillow 验证图像是否完整(若 Pillow 不存在,返回 True 以不阻塞)""" + try: + from PIL import Image + im = Image.open(BytesIO(b)) + im.verify() + return True + except ImportError: + log("Pillow not installed; skipping image verify") + return True + except Exception as e: + log("image verify failed:", e) + return False + +def open_mmap(): + """尝试打开并返回 (fd, mmap_obj) 或 (None, None)""" + if _mmap is None: + return None, None + try: + fd = os.open(SHM_PATH, os.O_RDONLY) + mm = _mmap.mmap(fd, SHM_SIZE, access=_mmap.ACCESS_READ) + log("mmap opened:", SHM_PATH, "size", SHM_SIZE) + return fd, mm + except Exception as e: + log("open_mmap failed:", e) + try: + if 'fd' in locals() and fd: + os.close(fd) + except: + pass + return None, None + +def close_mmap(fd, mm): + try: + if mm: + mm.close() + except: + pass + try: + if fd: + os.close(fd) + except: + pass + +def read_from_mmap(mm): + """ + 从 mmap 读取一帧: + 结构: [uint32 little-endian jpg_size][jpg_bytes...] + 返回 bytes 或 None + """ + try: + mm.seek(0) + hdr = mm.read(4) + if len(hdr) < 4: + return None + jpg_size = struct.unpack(" (SHM_SIZE - 4): + log("invalid jpg_size:", jpg_size) + return None + jpg_bytes = mm.read(jpg_size) + if len(jpg_bytes) != jpg_size: + log(f"mmap truncated: expected {jpg_size} got {len(jpg_bytes)}") + return None + # quick header check + if not (jpg_bytes.startswith(b"\xff\xd8\xff")): + log("jpg header mismatch") + return None + return jpg_bytes + except ValueError as e: + log("mmap read ValueError:", e) + return None + except Exception as e: + log("mmap read exception:", e) + raise + +def atomic_write(out_path, data): + tmp = out_path + ".tmp" + with open(tmp, "wb") as f: + f.write(data) + os.replace(tmp, out_path) + +def loop(): + if _mmap is None: + print("ERROR: mmap module not available in this Python environment. Exiting.", file=sys.stderr) + sys.exit(2) + + fd = None + mm = None + last_hash = None + last_mmap_try = 0.0 + + # 初始尝试打开 mmap(若失败则进入重试循环) + fd, mm = open_mmap() + + print(f"watching (mmap_only) '{SHM_PATH}' -> '{OUT_PATH}' @ {FPS} fps") + while True: + try: + jpg_bytes = None + + # 尝试 mmap 读取(若 mm 为 None 则尝试重建) + if mm is not None: + try: + jpg_bytes = read_from_mmap(mm) + if jpg_bytes is None: + # 写端可能尚未写入完整帧,短暂等待 + pass + else: + log("read frame from mmap, len=", len(jpg_bytes)) + except Exception as e: + log("mmap error, closing and will retry:", e) + close_mmap(fd, mm) + fd, mm = None, None + last_mmap_try = time.time() + + # 如果没有读到数据并且 mmap 不可用或不可读,则重试打开 mmap(按间隔) + if jpg_bytes is None: + now = time.time() + if mm is None and (now - last_mmap_try) >= MMAP_RETRY_SECONDS: + fd, mm = open_mmap() + last_mmap_try = now + time.sleep(sleep_interval) + continue + + # 校验完整性(PIL verify) + if not is_valid_image_bytes(jpg_bytes): + log("image bytes failed verify, skipping") + time.sleep(sleep_interval) + continue + + # 内容哈希,避免重复写磁盘 + h = hashlib.sha256(jpg_bytes).hexdigest() + if h != last_hash: + atomic_write(OUT_PATH, jpg_bytes) + last_hash = h + log("wrote", OUT_PATH, "len=", len(jpg_bytes)) + # else: skip (no change) + + time.sleep(sleep_interval) + + except KeyboardInterrupt: + print("terminated by user") + break + except Exception as e: + log("unexpected error:", e) + try: + if mm is not None: + close_mmap(fd, mm) + except: + pass + fd, mm = None, None + time.sleep(1.0) + + # 退出前清理 + close_mmap(fd, mm) + +def _handle_sigterm(signum, frame): + print("terminated", file=sys.stderr) + sys.exit(0) + +if __name__ == "__main__": + signal.signal(signal.SIGINT, _handle_sigterm) + signal.signal(signal.SIGTERM, _handle_sigterm) + loop() diff --git a/wust_vision-main/ros2/ros2.hpp b/wust_vision-main/ros2/ros2.hpp new file mode 100644 index 0000000..48802ff --- /dev/null +++ b/wust_vision-main/ros2/ros2.hpp @@ -0,0 +1,136 @@ +#pragma once + +#include "geometry_msgs/msg/transform_stamped.hpp" +#include "rclcpp/rclcpp.hpp" +#include "tf2_ros/buffer.h" +#include "tf2_ros/transform_broadcaster.h" +#include "tf2_ros/transform_listener.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class Ros2Node: public rclcpp::Node { +public: + explicit Ros2Node(const std::string& node_name = "vision"): + Node(node_name), + tf_buffer_(this->get_clock()), + tf_listener_(tf_buffer_), + tf_broadcaster_(std::make_shared(this)) {} + using Ptr = std::shared_ptr; + + static Ptr instance() { + static Ptr inst = std::make_shared(); + return inst; + } + ~Ros2Node() { + stop(); + } + + template + void add_subscription( + const std::string& topic_name, + std::function callback, + const rclcpp::QoS& qos = rclcpp::QoS(rclcpp::KeepLast(10)) + ) { + auto sub = this->create_subscription(topic_name, qos, callback); + std::lock_guard lock(map_mutex_); + subscriptions_[topic_name] = sub; + } + + template + void add_publisher( + const std::string& topic_name, + const rclcpp::QoS& qos = rclcpp::QoS(rclcpp::KeepLast(10)) + ) { + auto pub = this->create_publisher(topic_name, qos); + std::lock_guard lock(map_mutex_); + publishers_[topic_name] = pub; + } + + template + void publish(const std::string& topic_name, const MsgT& msg) { + std::lock_guard lock(map_mutex_); + auto it = publishers_.find(topic_name); + if (it != publishers_.end()) { + auto typed_pub = std::dynamic_pointer_cast>(it->second); + if (typed_pub) + typed_pub->publish(msg); + } + } + + template + void + add_timer(const std::chrono::duration& interval, std::function callback) { + auto timer = this->create_wall_timer(interval, callback); + std::lock_guard lock(map_mutex_); + timers_.push_back(timer); + } + + bool lookup_transform( + const std::string& target_frame, + const std::string& source_frame, + geometry_msgs::msg::TransformStamped& out_tf, + std::chrono::milliseconds timeout = std::chrono::milliseconds(200) + ) const { + try { + out_tf = + tf_buffer_.lookupTransform(target_frame, source_frame, tf2::TimePointZero, timeout); + return true; + } catch (const tf2::TransformException& ex) { + RCLCPP_WARN( + this->get_logger(), + "lookup_transform failed %s -> %s: %s", + source_frame.c_str(), + target_frame.c_str(), + ex.what() + ); + return false; + } + } + + void broadcast_tf(const geometry_msgs::msg::TransformStamped& tf_msg) { + tf_broadcaster_->sendTransform(tf_msg); + } + + template + void broadcast_tf_periodic( + const geometry_msgs::msg::TransformStamped& tf_msg, + const std::chrono::duration& interval + ) { + add_timer(interval, [this, tf_msg]() { broadcast_tf(tf_msg); }); + } + + void start() { + if (!spin_thread_.joinable()) { + auto node = shared_from_this(); + spin_thread_ = std::thread([node]() { rclcpp::spin(node); }); + } + } + + void stop() { + rclcpp::shutdown(); + if (spin_thread_.joinable()) { + spin_thread_.join(); + } + } + +private: + mutable std::mutex map_mutex_; + + std::unordered_map subscriptions_; + std::unordered_map publishers_; + std::vector timers_; + + std::thread spin_thread_; + + tf2_ros::Buffer tf_buffer_; + tf2_ros::TransformListener tf_listener_; + std::shared_ptr tf_broadcaster_; +}; \ No newline at end of file diff --git a/wust_vision-main/ros2/tf.hpp b/wust_vision-main/ros2/tf.hpp new file mode 100644 index 0000000..e98f8bf --- /dev/null +++ b/wust_vision-main/ros2/tf.hpp @@ -0,0 +1,88 @@ +#pragma once +#include "rclcpp/rclcpp.hpp" +#include +#include +#include +#include +#include +template +inline bool publisherSubscribed(const PublisherT& publisher) noexcept { + return publisher && publisher->get_subscription_count() > 0; +} +inline Eigen::Isometry3f tf2ToEigen(const geometry_msgs::msg::TransformStamped& tf) noexcept { + Eigen::Isometry3f T = Eigen::Isometry3f::Identity(); + + const auto& t = tf.transform.translation; + T.translation() = + Eigen::Vector3f(static_cast(t.x), static_cast(t.y), static_cast(t.z)); + + const auto& q = tf.transform.rotation; + + Eigen::Quaternionf Q( + static_cast(q.w), + static_cast(q.x), + static_cast(q.y), + static_cast(q.z) + ); + + Q.normalize(); + + T.linear() = Q.toRotationMatrix(); + + return T; +} +class TF { +public: + using Ptr = std::shared_ptr; + TF(rclcpp::Node& n) { + tf_buffer_ = std::make_unique(n.get_clock()); + tf_listener_ = std::make_unique(*tf_buffer_); + tf_broadcaster_ = std::make_unique(n); + node_ = &n; + } + static Ptr create(rclcpp::Node& n) { + return std::make_shared(n); + } + std::optional + getTransform(const std::string& target, const std::string& source, rclcpp::Time t) + const noexcept { + Eigen::Isometry3f T_out = Eigen::Isometry3f::Identity(); + try { + geometry_msgs::msg::TransformStamped tf_msg = + tf_buffer_->lookupTransform(target, source, t, rclcpp::Duration::from_seconds(0.1)); + + T_out = tf2ToEigen(tf_msg); + return T_out; + } catch (tf2::TransformException& ex) { + RCLCPP_WARN(rclcpp::get_logger("tf"), "TF lookup failed: %s", ex.what()); + return std::nullopt; + } + return T_out; + } + + void publishTransform( + const Eigen::Isometry3d& transform, + const std::string& parent_frame, + const std::string& child_frame, + const rclcpp::Time& stamp + ) const noexcept { + geometry_msgs::msg::TransformStamped tmsg; + tmsg.header.stamp = stamp; + tmsg.header.frame_id = parent_frame; + tmsg.child_frame_id = child_frame; + const Eigen::Vector3d tr = transform.translation(); + const Eigen::Quaterniond q(transform.rotation()); + tmsg.transform.translation.x = tr.x(); + tmsg.transform.translation.y = tr.y(); + tmsg.transform.translation.z = tr.z(); + tmsg.transform.rotation.x = q.x(); + tmsg.transform.rotation.y = q.y(); + tmsg.transform.rotation.z = q.z(); + tmsg.transform.rotation.w = q.w(); + tf_broadcaster_->sendTransform(tmsg); + } + rclcpp::Node* node_; + std::unique_ptr tf_buffer_; + std::unique_ptr tf_listener_; + std::unique_ptr tf_broadcaster_; +}; diff --git a/wust_vision-main/run.sh b/wust_vision-main/run.sh new file mode 100755 index 0000000..fde6bab --- /dev/null +++ b/wust_vision-main/run.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +WORK_DIR="$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +BUILD_DIR="$WORK_DIR/build" +CONFIG_DIR="$WORK_DIR/config" +BIN_DIR="$WORK_DIR/bin" +source "$WORK_DIR/env.bash" +export VISION_ROOT="$WORK_DIR" +export MVCAM_SDK_PATH=/opt/MVS +export MVCAM_COMMON_RUNENV=/opt/MVS/lib +export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol +export ALLUSERSPROFILE=/opt/MVS/MVFG +export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH +blue="\033[1;34m" +yellow="\033[1;33m" +reset="\033[0m" +red="\033[1;31m" + +if [ "$EUID" -eq 0 ]; then + USER_HOME=$(getent passwd $SUDO_USER | cut -d: -f6) + COPY_BASHRC="$WORK_DIR/user_bashrc_copy.bash" + + if [ -f "$USER_HOME/.bashrc" ]; then + # 复制 bashrc 到 WORK_DIR,并删除前10行 + tail -n +11 "$USER_HOME/.bashrc" > "$COPY_BASHRC" + # 设置权限,普通用户可读 + chmod 644 "$COPY_BASHRC" + chown $SUDO_USER:$SUDO_USER "$COPY_BASHRC" + + echo -e "${yellow}Copied ~/.bashrc to $COPY_BASHRC with first 10 lines removed${reset}" + + # 加载复制的 bashrc + source "$COPY_BASHRC" + echo -e "${yellow}Loaded bashrc from copy${reset}" + else + echo -e "${red}Original ~/.bashrc not found: $USER_HOME/.bashrc${reset}" + source "$COPY_BASHRC" + fi +else + # 普通用户直接加载原 bashrc + if [ -f "$HOME/.bashrc" ]; then + source "$HOME/.bashrc" + echo -e "${yellow}Loaded bashrc from $HOME/.bashrc${reset}" + fi +fi + + +chmod 777 /dev/shm/debug_frame + + +rm -rf "$BIN_DIR/config" +ln -sf "$CONFIG_DIR" "$BIN_DIR/config" +ln -sf "$WORK_DIR/env.bash" "$BUILD_DIR/env.bash" + +if [ "$1" == "rebuild" ]; then + echo -e "${yellow}<--- Rebuilding: This will REMOVE the entire build directory --->${reset}" + read -p "Are you sure you want to rebuild? [y/N]: " confirm + confirm=${confirm,,} + if [[ "$confirm" != "y" && "$confirm" != "yes" ]]; then + echo -e "${red}Rebuild cancelled.${reset}" + exit 0 + fi + echo -e "${yellow}<--- Removing build directory --->${reset}" + rm -rf "$BUILD_DIR" + mkdir -p "$BUILD_DIR" +else + mkdir -p "$BUILD_DIR" +fi + +if [[ "$1" == "build" || "$1" == "rebuild" || "$1" == "run" ]]; then + + echo -e "${yellow}<--- Start CMake (Ninja) --->${reset}" + cmake -S "$WORK_DIR" -B "$BUILD_DIR" \ + -G Ninja \ + + if [ $? -ne 0 ]; then + echo -e "${red}\n--- CMake Failed ---${reset}" + exit 1 + fi + SECONDS=0 + echo -e "${yellow}\n<--- Start Ninja Build --->${reset}" + ninja -C "$BUILD_DIR" + if [ $? -ne 0 ]; then + echo -e "${red}\n--- Ninja Build Failed ---${reset}" + exit 1 + fi + + build_time=$SECONDS + printf "${blue}\n<--- Build Time --->\n %02d:%02d (mm:ss)\n${reset}" \ + $((build_time / 60)) $((build_time % 60)) + echo -e "${yellow}\n<--- Total Lines --->${reset}" + total=$(find "$WORK_DIR" \ + -type d \( \ + -path "$BUILD_DIR" -o \ + -path "$WORK_DIR/model" -o \ + -path "$WORK_DIR/3rdparty" -o \ + -path "$WORK_DIR/.cache" \ + \) -prune -o \ + -type f \( \ + -name "*.cpp" -o -name "*.hpp" -o -name "*.c" -o -name "*.h" \ + -o -name "*.py" -o -name "*.html" -o -name "*.sh" -o -name "*.md" \ + -o -name "*.yaml" -o -name "*.json" -o -name "*.css" -o -name "*.js" \ + -o -name "*.cu" -o -name "*.txt" \ + \) -exec wc -l {} + | awk 'END{print $1}') + echo -e "${blue} $total${reset}" + + # Only build + if [ "$1" == "build" ] || [ "$1" == "rebuild" ]; then + echo -e "${yellow}\n<--- Only building --->${reset}" + echo -e "${yellow}<----- OVER ----->${reset}" + exit 0 + fi + + # Run mode + if [ "$1" == "run" ]; then + echo -e "${yellow}\n<--- Running WUST_VISION --->${reset}" + RUN_PROGRAM="$BIN_DIR/$2" + sudo pkill -9 $2 + ORIGINAL_ARGS=("$@") + shift 2 + + "$RUN_PROGRAM" "$@" + RET=$? + set -- "${ORIGINAL_ARGS[@]}" + + if [ $RET -ne 0 ]; then + echo -e "${red}\n--- Program crashed, running guard.sh ---${reset}" + + pkill "$2" + timeout=10 + while pgrep "$2" > /dev/null; do + sleep 0.5 + timeout=$((timeout - 1)) + if [ $timeout -le 0 ]; then + echo "$2 did not exit after 10 seconds, forcing kill" + pkill -9 "$2" + break + fi + done + + GUARD_SCRIPT="$CONFIG_DIR/guard.sh" + TARGET_PATH="$RUN_PROGRAM" + + if [ ! -f "$GUARD_SCRIPT" ]; then + echo -e "${red}guard.sh not found: $GUARD_SCRIPT${reset}" + exit 1 + fi + + echo -e "${yellow}Starting guard.sh ...${reset}" + exec "$GUARD_SCRIPT" "$TARGET_PATH" "$@" + fi + fi + + echo -e "${yellow}<----- OVER ----->${reset}" + +else + echo -e "${yellow}Warning:${reset} Invalid argument '$1'." + echo -e "${yellow}Usage:${reset} $0 {build|rebuild|run [args...]}" + echo -e "${yellow}No action performed.${reset}" + exit 0 +fi diff --git a/wust_vision-main/script/install_depences.sh b/wust_vision-main/script/install_depences.sh new file mode 100755 index 0000000..822a88e --- /dev/null +++ b/wust_vision-main/script/install_depences.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +sudo apt update && sudo apt install -y \ + rsync ninja-build \ + libfmt-dev libceres-dev libeigen3-dev \ + nlohmann-json3-dev libyaml-cpp-dev diff --git a/wust_vision-main/script/rsync.sh b/wust_vision-main/script/rsync.sh new file mode 100755 index 0000000..d2b9893 --- /dev/null +++ b/wust_vision-main/script/rsync.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +WORK_DIR="$(dirname "$SCRIPT_DIR")" +echo "脚本目录: $SCRIPT_DIR" +echo "wust_vision目录: $WORK_DIR" +# 用法说明 +if [ $# -ne 2 ]; then + echo "Usage: $0 " + echo "Example:" + echo " $0 nvidiaa 192.168.10.100" + exit 1 +fi + + +REMOTE_USER="$1" +REMOTE_IP="$2" +TARGET_PATH="/home/${REMOTE_USER}/wust_vision" +rsync -avz \ + --exclude='.cache/' \ + --exclude='.vscode/' \ + --exclude='.git/' \ + --exclude='bin/' \ + --exclude='build/' \ + --exclude='model/' \ + --exclude='config/' \ + --exclude='CMakeLists.txt' \ + "${WORK_DIR}/" \ + "${REMOTE_USER}@${REMOTE_IP}:${TARGET_PATH}/" diff --git a/wust_vision-main/script/setup_devenv.sh b/wust_vision-main/script/setup_devenv.sh new file mode 100755 index 0000000..d47a827 --- /dev/null +++ b/wust_vision-main/script/setup_devenv.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +set -euo pipefail + +BASHRC="$HOME/.bashrc" +MARKER_START="# >>> wust_vision dev >>>" +MARKER_END="# <<< wust_vision dev <<<" +BACKUP="$BASHRC.bak.$(date +%Y%m%d_%H%M%S)" + +# 备份 +cp "$BASHRC" "$BACKUP" +echo "💾 已备份 $BASHRC 到 $BACKUP" + +# 临时文件 +TMP=$(mktemp) + +# 写入 bashrc 除去旧 block +awk -v start="$MARKER_START" -v end="$MARKER_END" ' + $0 == start {inside=1; next} + inside && $0 == end {inside=0; next} + !inside {print} +' "$BASHRC" > "$TMP" + +# 使用 cat <> "$TMP" +# >>> wust_vision dev >>> +py() { + source ~/anaconda3/etc/profile.d/conda.sh + echo "已激活 conda, conda activate激活环境,conda deactivate退出环境, conda info --envs查看环境" +} + +buildme() { + colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=Release +} + +builddebug() { + colcon build --packages-select "$1" --cmake-args -DCMAKE_BUILD_TYPE=Debug +} + +killros() { + pkill -f ros && ros2 daemon stop && ros2 daemon start +} + +format() { + find . -path ./build -prune -o \ + -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.c' -o -name '*.cu' -o -name '*.cpp' \) \ + -exec clang-format -i {} + +} + +s() { + source install/setup.bash +} + +hik() { + export MVCAM_SDK_PATH=/opt/MVS + export MVCAM_COMMON_RUNENV=/opt/MVS/lib + export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol + export ALLUSERSPROFILE=/opt/MVS/MVFG + export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH +} +# <<< wust_vision dev <<< +EOF + +# 原子替换 +mv "$TMP" "$BASHRC" +echo "✅ 已安全替换或追加 block 到 $BASHRC" diff --git a/wust_vision-main/script/setup_service.sh b/wust_vision-main/script/setup_service.sh new file mode 100755 index 0000000..6499027 --- /dev/null +++ b/wust_vision-main/script/setup_service.sh @@ -0,0 +1,109 @@ +#!/usr/bin/env bash +# 自动管理 systemd service 文件 +# 工作目录 = 脚本所在路径的上一级目录 + +SERVICE_NAME="wust_vision" +SERVICE_FILE="/etc/systemd/system/${SERVICE_NAME}.service" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +WORK_DIR="$(dirname "$SCRIPT_DIR")" + +ACTION=$1 +echo "正在操作: standard" +echo "📂 脚本路径: $SCRIPT_DIR" +echo "🏗 工作区路径: $WORK_DIR" +echo "🧾 目标 Service 文件: $SERVICE_FILE" +echo + +uninstall_service() { + if [[ ! -f "$SERVICE_FILE" ]]; then + echo "⚠️ Service 文件不存在,无需卸载。" + exit 0 + fi + + echo "🛑 停止并卸载 Service..." + sudo systemctl stop "${SERVICE_NAME}.service" 2>/dev/null || true + sudo systemctl disable "${SERVICE_NAME}.service" 2>/dev/null || true + sudo rm -f "$SERVICE_FILE" + sudo systemctl daemon-reload + + echo "✅ Service 已成功卸载。" + exit 0 +} + +install_service() { + if [[ -f "$SERVICE_FILE" ]]; then + echo "⚠️ 检测到已存在的 Service 文件:$SERVICE_FILE" + read -p "是否覆盖?(y/N): " confirm + if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then + echo "🚫 已取消安装。" + exit 0 + else + echo "🧹 删除旧版本 Service..." + sudo systemctl stop ${SERVICE_NAME}.service 2>/dev/null || true + sudo systemctl disable ${SERVICE_NAME}.service 2>/dev/null || true + sudo rm -f "$SERVICE_FILE" + fi + fi + + echo "✏️ 正在生成新的 Service 文件..." + sudo tee "$SERVICE_FILE" > /dev/null </dev/null || true + sudo systemctl disable "${SERVICE_NAME}.service" 2>/dev/null || true + sudo rm -f "$SERVICE_FILE" + sudo systemctl daemon-reload + + echo "✅ Service 已成功卸载。" + exit 0 +} + +install_service() { + if [[ -f "$SERVICE_FILE" ]]; then + echo "⚠️ 检测到已存在的 Service 文件:$SERVICE_FILE" + read -p "是否覆盖?(y/N): " confirm + if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then + echo "🚫 已取消安装。" + exit 0 + else + echo "🧹 删除旧版本 Service..." + sudo systemctl stop ${SERVICE_NAME}.service 2>/dev/null || true + sudo systemctl disable ${SERVICE_NAME}.service 2>/dev/null || true + sudo rm -f "$SERVICE_FILE" + fi + fi + + echo "✏️ 正在生成新的 Service 文件..." + sudo tee "$SERVICE_FILE" > /dev/null < /dev/null <("INFO"); + std::string log_path_ = config_["logger"]["log_path"].as("wust_log"); + bool use_logcli = config_["logger"]["use_logcli"].as(); + bool use_logfile = config_["logger"]["use_logfile"].as(); + bool use_simplelog = config_["logger"]["use_simplelog"].as(); + wust_vl::initLogger(log_level_, log_path_, use_logcli, use_logfile, use_simplelog); + + YAML::Node camera_config = config_["camera"]; + camera_ = std::make_unique(); + camera_->init(camera_config); + camera_->setFrameCallback(std::bind(&vision::frameCallback, this, std::placeholders::_1)); + std::string camera_info_path = + utils::expandEnv(camera_config["camera_info_path"].as()); + YAML::Node config_camera_info = YAML::LoadFile(camera_info_path); + std::vector camera_k = + config_camera_info["camera_matrix"]["data"].as>(); + std::vector camera_d = + config_camera_info["distortion_coefficients"]["data"].as>(); + + assert(camera_k.size() == 9); + assert(camera_d.size() == 5); + + cv::Mat K(3, 3, CV_64F); + std::memcpy(K.data, camera_k.data(), 9 * sizeof(double)); + + cv::Mat D(1, 5, CV_64F); + std::memcpy(D.data, camera_d.data(), 5 * sizeof(double)); + + auto camera_info = std::make_pair(K.clone(), D.clone()); + camera_info_ = camera_info; + auto_guidance_ = auto_guidance::AutoGuidance::create(); + auto_guidance_->setDebug(debug_mode); + auto_guidance_->init(config_, camera_info); + + max_infer_running_ = config_["max_infer_running"].as(); + thread_pool_ = std::make_unique( + std::thread::hardware_concurrency() * 2 + ); + std::string device_name = config_["control"]["device_name"].as(); + + serial_ = std::make_shared(); + bool use_serial = config_["control"]["use_serial"].as(); + if (use_serial) { + wust_vl::common::drivers::SerialDriver::SerialPortConfig cfg { + /*baud*/ 115200, + /*csize*/ 8, + boost::asio::serial_port_base::parity::none, + boost::asio::serial_port_base::stop_bits::one, + boost::asio::serial_port_base::flow_control::none + }; + serial_->init_port(device_name, cfg); + serial_->set_receive_callback(std::bind( + &vision::serialCallback, + this, + std::placeholders::_1, + std::placeholders::_2 + + )); + std::cout << "Serial port opened" << std::endl; + serial_->set_error_callback([&](const boost::system::error_code& ec) { + WUST_ERROR("serial") << "serial error: " << ec.message(); + }); + } + + timer_ = std::make_unique(); + WUST_MAIN("vision") << "starting vision task"; + return true; + } + ~vision() { + run_flag_ = false; + camera_->stop(); + thread_pool_->waitUntilEmpty(); + if (debug_thread_.joinable()) { + debug_thread_.join(); + } + } + void start() { + run_flag_ = true; + camera_->start(); + auto_guidance_->start(); + if (serial_) { + serial_->start(); + } + if (timer_) { + auto timercallback = std::bind(&vision::timerCallback, this, std::placeholders::_1); + double rate_hz = static_cast(config_["control"]["control_rate"].as()); + timer_->start(rate_hz, timercallback); + } + if (debug_mode_) { + debug_thread_ = std::thread([this]() { this->debugThread(); }); + } + } + void serialCallback(const uint8_t* data, std::size_t len) {} + void frameCallback(wust_vl::video::ImageFrame& frame) { + if (!run_flag_ || infer_running_count_ >= max_infer_running_) { + return; + } + CommonFrame common_frame; + if (frame.src_img.empty()) { + return; + } + common_frame.img_frame = std::move(frame); + common_frame.expanded = cv::Rect( + 0, + 0, + common_frame.img_frame.src_img.cols, + common_frame.img_frame.src_img.rows + ); + common_frame.offset = cv::Point2f(0, 0); + thread_pool_->enqueue([this, frame = std::move(common_frame)]() mutable { + infer_running_count_++; + if (frame.img_frame.src_img.data == nullptr) { + return; + } + if (frame.img_frame.src_img.empty()) { + return; + } + if (auto_guidance_) { + auto_guidance_->pushInput(frame); + } + infer_running_count_--; + }); + } + void timerCallback(double dt_ms) { + if (!run_flag_) { + return; + } + auto target = auto_guidance_->getTarget(); + cx_norm_ = target.center().x / target.image_size_.width * 2.0 - 1.0; + SendDartCmdData send_data; + // send_data.cmd_ID = ID_ROBOT_CMD; + // send_data.time_stamp = std::chrono::duration_cast( + // std::chrono::steady_clock::now().time_since_epoch() + // ) + // .count(); + // send_data.appear = target.is_tracking_; + send_data.diff_center_norm = (target.is_tracking_) ? cx_norm_ : 0; + if (serial_) { + serial_->write(std::move(wust_vl::common::drivers::toVector(send_data))); + } + } + + void debugThread() { + using namespace std::chrono; + + double us_interval = 1e6 / static_cast(30.0); + auto kInterval = std::chrono::microseconds(static_cast(us_interval)); + while (run_flag_) { + auto start_time = steady_clock::now(); + try { + auto dbg = auto_guidance_->getDebug(); + + drawDebugOverlayShm(dbg, false); + debuglog(dbg.target); + } catch (std::exception& e) { + std::cout << "debug thread error: " << e.what() << std::endl; + } + + auto elapsed = steady_clock::now() - start_time; + if (elapsed < kInterval) { + std::this_thread::sleep_for(kInterval - elapsed); + } + } + } + void checkStateMatchMode() {} + YAML::Node config_; + std::unique_ptr thread_pool_; + std::unique_ptr auto_guidance_; + std::unique_ptr camera_; + std::unique_ptr timer_; + std::shared_ptr serial_; + std::atomic infer_running_count_ { 0 }; + bool run_flag_ = false; + int max_infer_running_; + bool debug_mode_ = false; + std::pair camera_info_; + std::thread debug_thread_; + double cx_norm_; +}; +} // namespace wust_vision +VISION_MAIN(wust_vision::vision) \ No newline at end of file diff --git a/wust_vision-main/src/hero.cpp b/wust_vision-main/src/hero.cpp new file mode 100644 index 0000000..156d913 --- /dev/null +++ b/wust_vision-main/src/hero.cpp @@ -0,0 +1,35 @@ +#include "ros2/ros2.hpp" +#include "tasks/auto_sniper/auto_sniper.hpp" +#include "tasks/type_common.hpp" +#include "tasks/utils/config.hpp" +#include "tasks/utils/main_base.hpp" +#include "tasks/vision_base.hpp" +ENABLE_BACKWARD() + +namespace wust_vision { +class vision: public VisionBase { +public: + vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {} + + bool init(bool debug_mode) { + VisionBase::init(debug_mode); + auto auto_aim = + auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode); + modules_.emplace(HeroMode::AttackMode::ARMOR, auto_aim); + modules_.emplace(HeroMode::AttackMode::UNKNOWN, auto_aim); + rclcpp::init(0, nullptr); + ros2_ = std::make_shared("vison_node"); + auto auto_sniper = auto_sniper::AutoSniper::create(*ros2_, motion_buffer_); + modules_.emplace(HeroMode::AttackMode::SNIPER, auto_sniper); + return true; + } + + void start() { + VisionBase::start(); + ros2_->start(); + } + + std::shared_ptr ros2_; +}; +} // namespace wust_vision +VISION_MAIN(wust_vision::vision) \ No newline at end of file diff --git a/wust_vision-main/src/sentry.cpp b/wust_vision-main/src/sentry.cpp new file mode 100644 index 0000000..a423dc6 --- /dev/null +++ b/wust_vision-main/src/sentry.cpp @@ -0,0 +1,302 @@ +#include "geometry_msgs/msg/twist.hpp" +#include "ros2/ros2.hpp" +#include "sentry_interfaces/msg/detail/mode__struct.hpp" +#include "sentry_interfaces/msg/detail/robo_state__struct.hpp" +#include "sentry_interfaces/msg/detail/target__struct.hpp" +#include "sentry_interfaces/msg/mode.hpp" +#include "sentry_interfaces/msg/robo_state.hpp" +#include "sentry_interfaces/msg/target.hpp" +#include "tasks/auto_aim/armor_control/very_aimer.hpp" +#include "tasks/auto_aim/armor_omni/armor_omni.hpp" +#include "tasks/auto_aim/auto_aim.hpp" +#include "tasks/packet_typedef.hpp" +#include "tasks/type_common.hpp" +#include "tasks/utils/config.hpp" +#include "tasks/utils/main_base.hpp" +#include "tasks/vision_base.hpp" +#include "visualization_msgs/msg/marker.hpp" +#include +#include +#include +#include +ENABLE_BACKWARD() +namespace wust_vision { +class vision: public VisionBase { +public: + static constexpr const char* TARGET_MARKER = "target_marker"; + + vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {} + ~vision() { + armor_omni_.reset(); + } + bool init(bool debug_mode) { + VisionBase::init(debug_mode); + auto auto_aim = + auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode); + modules_.emplace(InfantryMode::AttackMode::ARMOR, auto_aim); + modules_.emplace(InfantryMode::AttackMode::UNKNOWN, auto_aim); + auto auto_buff = + auto_buff::AutoBuff::create(auto_buff_config_, tf_config_, camera_info_, debug_mode); + modules_.emplace(InfantryMode::AttackMode::BIG_RUNE, auto_buff); + modules_.emplace(InfantryMode::AttackMode::SMALL_RUNE, auto_buff); + + serial_->set_receive_callback( + std::bind(&vision::serialCallback, this, std::placeholders::_1, std::placeholders::_2) + ); + timer_B_ = std::make_unique("10hz"); + big_yaw_motion_buffer_ = + std::make_shared>(); + // auto very_aimer_copy = std::make_shared(*auto_aim->getVeryAimer()); + auto_aim::ArmorOmni::Ctx omni_ctx = { + .car_motion_buffer = motion_buffer_, + .big_yaw_motion_buffer = big_yaw_motion_buffer_, + .very_aimer = auto_aim->getVeryAimer(), + }; + armor_omni_ = auto_aim::ArmorOmni::create(detect_color_, omni_ctx); + rclcpp::init(0, nullptr); + ros2_ = std::make_shared("vison_node"); + ros2_->add_subscription( + "cmd_vel", + std::bind(&vision::twistCb, this, std::placeholders::_1), + rclcpp::QoS(10) + ); + ros2_->add_subscription( + MODE_TOPIC, + std::bind(&vision::modeCb, this, std::placeholders::_1) + ); + ros2_->add_publisher(TARGET_MARKER); + ros2_->add_publisher(TARGET_TOPIC); + ros2_->add_publisher(ROBO_STATE_TOPIC); + return true; + } + + void start() { + VisionBase::start(); + if (timer_) { + const auto timercallback = + std::bind(&vision::timerCallback, this, std::placeholders::_1); + const double rate_hz = control_config_->control_rate_param.get(); + timer_->start(rate_hz, timercallback); + } + if (timer_B_) { + const auto timercallback = + std::bind(&vision::timerBCallback, this, std::placeholders::_1); + const double rate_hz = 10.0; + timer_B_->start(rate_hz, timercallback); + } + armor_omni_->start(); + ros2_->start(); + } + void timerCallback(double dt_ms) { + if (!run_flag_) { + return; + } + + GimbalCmd cmd; + try { + InfantryMode::AttackMode mode = InfantryMode::toAttackMode(attack_mode_); + auto module = modules_.at(mode); + if (!module) { + return; + } + cmd = module->solve(bullet_speed_); + } catch (const std::exception& e) { + std::cout << "solve error: " << e.what() << std::endl; + } + if (!cmd.isValid()) { + return; + } + if (!cmd.appear && armor_omni_) { + cmd = armor_omni_->solve(bullet_speed_); + } + last_cmd_ = cmd; + + double cmd_pitch = cmd.pitch; + double cmd_yaw = cmd.yaw; + SendRobotCmdData send_data; + send_data.cmd_ID = ID_ROBOT_CMD; + send_data.time_stamp = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch() + ) + .count(); + if (cmd.distance > 0.5) { + send_data.appear = cmd.appear; + } else { + send_data.appear = false; + } + + send_data.detect_color = detect_color_; + send_data.pitch = cmd_pitch + cmd.v_pitch * control_config_->pitch_ramp_param.get(); + send_data.yaw = cmd_yaw + cmd.v_yaw * control_config_->yaw_ramp_param.get(); + send_data.v_pitch = cmd.v_pitch; + send_data.v_yaw = cmd.v_yaw; + send_data.a_pitch = cmd.a_pitch; + send_data.a_yaw = cmd.a_yaw; + send_data.target_yaw = cmd.target_yaw; + send_data.target_pitch = cmd.target_pitch; + send_data.enable_pitch_diff = cmd.enable_pitch_diff; + send_data.enable_yaw_diff = cmd.enable_yaw_diff; + send_data.shoot_rate = shoot_config_->rate_param.get(); + if (serial_) { + serial_->write(std::move(wust_vl::common::drivers::toVector(send_data))); + } + } + void timerBCallback(double dt_ms) { + InfantryMode::AttackMode mode = InfantryMode::toAttackMode(attack_mode_); + do { + if (mode == InfantryMode::AttackMode::ARMOR) { + auto auto_aim = auto_aim::toAutoAim(modules_.at(mode)); + if (auto_aim) { + auto target = auto_aim->getTarget(); + if (armor_omni_) { + armor_omni_->setDetectColor(detect_color_); + armor_omni_->updateMainTracking(target.checkTargetAppear()); + } + + if (!target.checkTargetAppear()) { + break; + } + auto target_pos = target.target_state_.pos(); // world frame + + double big_yaw = 0.0; + + Eigen::Rotation2Dd rot(-big_yaw); + + Eigen::Vector2d pos_world(target_pos.x(), target_pos.y()); + Eigen::Vector2d pos_bigyaw = rot * pos_world; + publishMarker(pos_bigyaw); + sentry_interfaces::msg::Target target_msg; + target_msg.pos.x = pos_bigyaw.x(); + target_msg.pos.y = pos_bigyaw.y(); + target_msg.pos.z = 0.0; + target_msg.color = detect_color_; + target_msg.id = 0; + target_msg.header.frame_id = "gimbal_yaw"; + target_msg.header.stamp = ros2_->now(); + ros2_->publish(TARGET_TOPIC, target_msg); + } + } + + } while (0); + } + + void serialCallback(const uint8_t* data, std::size_t len) { + if (len < 1) + return; + + uint8_t cmd = data[0]; + + try { + if (cmd == ID_AIM_INFO) { + if (len != sizeof(ReceiveAimINFO)) + return; + + const std::vector buf(data, data + len); + + auto aim_data = wust_vl::common::drivers::fromVector(buf); + + processAimData(aim_data); + } + + else if (cmd == ID_REFEREE_INFO) + { + if (len != sizeof(ReceiveReferee)) + return; + + const std::vector buf(data, data + len); + + auto referee_data = wust_vl::common::drivers::fromVector(buf); + + processRefereeData(referee_data); + } + + else + { + std::cerr << "Unknown cmd_ID: " << int(cmd) << std::endl; + } + + } catch (const std::exception& e) { + std::cerr << "serialCallback exception: " << e.what() << std::endl; + } + } + void processRefereeData(const ReceiveReferee& ref) { + if (debug_mode_) { + updateSerialLog(ref); + flushSerialLog(); + } + const auto now = std::chrono::steady_clock::now(); + double big_yaw_rad = ref.big_yaw_in_world / 180.0 * M_PI; + if (big_yaw_motion_buffer_) { + BigYaw big_yaw { .big_yaw = big_yaw_rad }; + big_yaw_motion_buffer_->push(big_yaw, now); + } + sentry_interfaces::msg::RoboState robo_state; + robo_state.game_time = ref.game_time; + robo_state.cur_hp = ref.cur_health; + robo_state.max_hp = ref.max_health; + robo_state.cur_bullet = ref.cur_bullet; + robo_state.center_state = ref.center_state; + ros2_->publish(ROBO_STATE_TOPIC, robo_state); + } + void publishMarker(const Eigen::Vector2d& pos) { + visualization_msgs::msg::Marker marker; + + marker.header.frame_id = "gimbal_yaw"; + marker.header.stamp = ros2_->now(); + + marker.ns = "target"; + marker.id = 0; + marker.type = visualization_msgs::msg::Marker::SPHERE; + marker.action = visualization_msgs::msg::Marker::ADD; + + marker.pose.position.x = pos.x(); + marker.pose.position.y = pos.y(); + marker.pose.position.z = 0.0; + + marker.pose.orientation.w = 1.0; + + marker.scale.x = 0.15; + marker.scale.y = 0.15; + marker.scale.z = 0.15; + + marker.color.a = 1.0; + marker.color.r = 1.0; + marker.color.g = 0.0; + marker.color.b = 0.0; + + marker.lifetime = rclcpp::Duration(0, 0); + + ros2_->publish(TARGET_MARKER, marker); + } + void modeCb(const sentry_interfaces::msg::Mode::SharedPtr msg) { + last_mode_ = msg->mode; + } + int last_mode_ = 0; + void twistCb(const geometry_msgs::msg::Twist::SharedPtr msg) { + NavRobotCmdData send_data; + send_data.cmd_ID = ID_NAV_CMD; + send_data.packet_type = ID_NAV_CONTROL; + send_data.time_stamp = + static_cast(std::chrono::duration_cast( + wust_vl::common::utils::time_utils::now().time_since_epoch() + ) + .count()); + send_data.vx = -msg->linear.y; + send_data.vy = msg->linear.x; + send_data.wz = msg->angular.z; + + if (serial_) { + serial_->write(std::move(wust_vl::common::drivers::toVector(send_data))); + } + last_cmd_time_ = wust_vl::common::utils::time_utils::now(); + } + std::shared_ptr ros2_; + std::unique_ptr timer_B_; + auto_aim::ArmorOmni::Ptr armor_omni_; + std::chrono::steady_clock::time_point last_cmd_time_; + std::shared_ptr> + big_yaw_motion_buffer_; + bool use_sim_ = false; +}; +} // namespace wust_vision +VISION_MAIN(wust_vision::vision) diff --git a/wust_vision-main/src/sim.cpp b/wust_vision-main/src/sim.cpp new file mode 100644 index 0000000..b53775a --- /dev/null +++ b/wust_vision-main/src/sim.cpp @@ -0,0 +1,488 @@ + +// #include "ros2/ros2.hpp" +// #include "sensor_msgs/msg/camera_info.hpp" +// #include "sensor_msgs/msg/image.hpp" +// #include "tasks/utils/config.hpp" +// #include "tasks/utils/main_base.hpp" +// #include "tasks/vision_base.hpp" +// ENABLE_BACKWARD() +// namespace wust_vision { +// class vision { +// public: +// vision() { +// common_config_ = COMMON_CONFIG; +// camera_config_ = CAMERA_CONFIG; +// auto_aim_config_ = AUTO_AIM_CONFIG; +// auto_buff_config_ = AUTO_BUFF_CONFIG; +// }; +// ~vision() { +// run_flag_ = false; +// has_camera_info_ = true; +// if (debug_thread_.joinable()) { +// debug_thread_.join(); +// } + +// WUST_MAIN("main") << "vision stop already!"; +// } +// bool init(bool debug_mode) { +// try { +// rclcpp::init(0, nullptr); +// ros2_ = std::make_shared("vison_node"); +// ros2_->add_subscription( +// "image_raw", +// std::bind(&vision::imageCallback, this, std::placeholders::_1), +// rclcpp::SensorDataQoS() +// ); + +// ros2_->add_subscription( +// "camera_info", +// [this](sensor_msgs::msg::CameraInfo::ConstSharedPtr _camera_info) { +// if (has_camera_info_) +// return; +// std::cout << "camera info received" << std::endl; +// auto& msg = *_camera_info; + +// cv::Mat K(3, 3, CV_64F); +// std::memcpy(K.data, msg.k.data(), 9 * sizeof(double)); + +// cv::Mat D(1, msg.d.size(), CV_64F); +// std::memcpy(D.data, msg.d.data(), msg.d.size() * sizeof(double)); + +// camera_info_ = std::make_pair(K.clone(), D.clone()); +// has_camera_info_ = true; +// } +// ); +// ros2_->start(); +// while (rclcpp::ok() && !has_camera_info_) { +// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); +// WUST_INFO("sim") << "Waiting for camera info..."; +// } + +// const char* v = std::getenv("VISION_ROOT"); +// if (v) +// std::cout << "[env] VISION_ROOT = " << v << "\n"; +// else +// std::cout << "[env] VISION_ROOT not set in this process\n"; +// debug_mode_ = debug_mode; + +// control_config_ = ControlConfig::create(this); +// shoot_config_ = ShootConfig::create(); +// logger_config_ = LoggerConfig::create(); +// tf_config_ = TFConfig::create(); +// max_infer_running_config_ = MaxInferRunningConfig::create(); +// common_config_parameter_.registerGroup(*control_config_); +// common_config_parameter_.registerGroup(*shoot_config_); +// common_config_parameter_.registerGroup(*logger_config_); +// common_config_parameter_.registerGroup(*tf_config_); +// common_config_parameter_.registerGroup(*max_infer_running_config_); +// common_config_parameter_.loadFromFile(common_config_); +// auto config = common_config_parameter_.getConfig(); +// debug_fps_ = config["debug_fps"].as(); +// attack_mode_ = config["attack_mode"].as(); +// detect_color_ = config["detect_color"].as(); + +// wust_vl::common::utils::ParameterManager::instance().registerParameter( +// common_config_parameter_ +// ); +// auto_aim_ = std::make_unique( +// auto_aim_config_, +// tf_config_, +// camera_info_, +// debug_mode +// ); +// auto_buff_ = std::make_unique( +// auto_buff_config_, +// tf_config_, +// camera_info_, +// debug_mode +// ); +// thread_pool_ = std::make_unique( +// std::thread::hardware_concurrency() * 2 +// ); +// motion_buffer_ = +// std::make_shared>(); + +// timer_ = std::make_unique(); +// } catch (std::exception& e) { +// std::cerr << "init exception: " << e.what() << std::endl; +// } +// return true; +// } +// void imageCallback(const sensor_msgs::msg::Image::ConstSharedPtr img_msg) { +// if (!run_flag_) { +// return; +// } +// cv::Mat image( +// img_msg->height, +// img_msg->width, +// CV_8UC3, +// const_cast(img_msg->data.data()), // raw pointer +// img_msg->step +// ); +// CommonFrame common_frame; +// common_frame.img_frame.src_img = std::move(image); +// common_frame.detect_color = detect_color_; +// common_frame.img_frame.timestamp = std::chrono::steady_clock::now(); +// common_frame.img_frame.pixel_format = wust_vl::video::PixelFormat::BGR; +// common_frame.expanded = cv::Rect( +// 0, +// 0, +// common_frame.img_frame.src_img.cols, +// common_frame.img_frame.src_img.rows +// ); +// common_frame.any_ctx = VisionCtx { .motion_buffer = motion_buffer_, +// .communication_delay_μs = +// control_config_->communication_delay_us_param.get(), +// .attack_mode = toAttackMode(attack_mode_) }; +// common_frame.offset = cv::Point2f(0, 0); +// thread_pool_->enqueue([this, frame = std::move(common_frame)]() mutable { +// infer_running_count_++; +// if (frame.img_frame.src_img.empty()) { +// return; +// } +// AttackMode mode = toAttackMode(attack_mode_); +// switch (mode) { +// case AttackMode::ARMOR: { +// auto_aim_->pushInput(frame); +// } break; +// case AttackMode::SMALL_RUNE: { +// auto_buff_->pushInput(frame); +// } break; +// case AttackMode::BIG_RUNE: { +// auto_buff_->pushInput(frame); +// } break; +// case AttackMode::UNKNOWN: { +// auto_aim_->pushInput(frame); +// } break; +// } +// infer_running_count_--; +// }); +// } + +// void start() { +// run_flag_ = true; +// auto_aim_->start(); +// auto_buff_->start(); +// if (timer_) { +// auto timercallback = std::bind(&vision::timerCallback, this, std::placeholders::_1); +// double rate_hz = control_config_->control_rate_param.get(); +// timer_->start(rate_hz, timercallback); +// } + +// if (debug_mode_) { +// debug_thread_ = std::thread([this]() { this->debugThread(); }); +// } +// } +// void timerCallback(double dt_ms) { +// if (!run_flag_) { +// return; +// } +// geometry_msgs::msg::TransformStamped tf; +// if (!ros2_->lookup_transform("odom", "gimbal_link", tf)) { +// RCLCPP_WARN(ros2_->get_logger(), "TF lookup failed"); +// } else { +// Eigen::Quaterniond q( +// tf.transform.rotation.w, +// tf.transform.rotation.x, +// tf.transform.rotation.y, +// tf.transform.rotation.z +// ); + +// // Convert to rotation matrix +// Eigen::Matrix3d R = q.toRotationMatrix(); + +// // Euler angles (ZYX order -> yaw pitch roll) +// Eigen::Vector3d euler = R.eulerAngles(2, 1, 0); // yaw, pitch, roll + +// double yaw = euler[0]; +// double pitch = euler[1]; +// double roll = euler[2]; +// if (motion_buffer_) { +// Motion motion { yaw, -pitch, roll, 0.0, 0, 0, 0, 0, 0 }; +// motion_buffer_->push(motion, std::chrono::steady_clock::now()); +// } +// } +// GimbalCmd cmd; +// try { +// AttackMode mode = toAttackMode(attack_mode_); +// switch (mode) { +// case AttackMode::ARMOR: { +// cmd = auto_aim_->solve(shoot_config_->bullet_speed_param.get()); +// } break; +// case AttackMode::SMALL_RUNE: +// case AttackMode::BIG_RUNE: { +// cmd = auto_buff_->solve(shoot_config_->bullet_speed_param.get()); +// } break; +// case AttackMode::UNKNOWN: { +// cmd = auto_aim_->solve(shoot_config_->bullet_speed_param.get()); +// } break; +// } +// } catch (const std::exception& e) { +// std::cout << "auto_aim_solve error: " << e.what() << std::endl; +// } + +// last_cmd_ = cmd; +// } + +// void debugThread() { +// const double us_interval = 1e6 / static_cast(debug_fps_); +// const auto kInterval = std::chrono::microseconds(static_cast(us_interval)); +// while (run_flag_) { +// const auto start_time = std::chrono::steady_clock::now(); +// do { +// try { +// if (!auto_aim_ || !auto_buff_) { +// break; +// } +// auto dbg_armor = auto_aim_->getDebugFrame(); +// auto dbg_rune = auto_buff_->getDebugFrame(); +// AttackMode mode = toAttackMode(attack_mode_); +// switch (mode) { +// case AttackMode::UNKNOWN: +// case AttackMode::ARMOR: { +// drawDebugOverlayShm(dbg_armor, camera_info_, false); +// } break; +// case AttackMode::SMALL_RUNE: +// case AttackMode::BIG_RUNE: { +// drawDebugOverlayShm(dbg_rune, camera_info_, false); +// } break; +// } +// std::pair gimbal_py; +// if (motion_buffer_) { +// auto last_att = motion_buffer_->get_last(); +// if (last_att) { +// gimbal_py.first = last_att->data.pitch; +// gimbal_py.second = last_att->data.yaw; +// } +// } + +// debuglog(dbg_armor, dbg_rune, last_cmd_, gimbal_py); +// utils::XSecOnce( +// [this]() { +// wust_vl::common::utils::ParameterManager::instance() +// .allReloadFromOldPath(); +// }, +// 1.0 +// ); + +// } catch (std::exception& e) { +// std::cout << "debug thread error: " << e.what() << std::endl; +// } +// } while (0); + +// const auto elapsed = std::chrono::steady_clock::now() - start_time; +// if (elapsed < kInterval) { +// std::this_thread::sleep_for(kInterval - elapsed); +// } +// } +// } +// void checkStateMatchMode() { +// AttackMode mode = toAttackMode(attack_mode_); +// switch (mode) { +// case AttackMode::ARMOR: { +// if (!auto_aim_->isActive()) { +// auto_aim_->processingUp(); +// } +// if (auto_buff_->isActive()) { +// auto_buff_->processingWait(); +// } + +// } break; +// case AttackMode::SMALL_RUNE: +// case AttackMode::BIG_RUNE: { +// if (auto_aim_->isActive()) { +// auto_aim_->processingWait(); +// } +// if (!auto_buff_->isActive()) { +// auto_buff_->processingUp(); +// } +// } break; +// case AttackMode::UNKNOWN: { +// if (!auto_aim_->isActive()) { +// auto_aim_->processingUp(); +// } +// if (auto_buff_->isActive()) { +// auto_buff_->processingWait(); +// } +// } break; +// } +// } +// struct ControlConfig: wust_vl::common::utils::ParamGroup { +// public: +// static constexpr const char* kKey = "control"; +// static constexpr const char* Logger = "Config: common::control"; +// const char* key() const override { +// return kKey; +// } +// using Ptr = std::shared_ptr; +// ControlConfig(vision* b) { +// base = b; +// communication_delay_us_param.onChange([this](int o, int n) { +// if (isBaseACtive()) { +// WUST_DEBUG(Logger) +// << "communication_delay_μs from: " << o << " to: " << n << " us"; +// } +// }); +// } +// static Ptr create(vision* b) { +// return std::make_shared(b); +// } +// GEN_PARAM(double, communication_delay_us); +// GEN_PARAM(double, yaw_ramp); +// GEN_PARAM(double, pitch_ramp); +// GEN_PARAM(double, control_rate); +// vision* base; +// bool first_load = false; +// bool isBaseACtive() { +// return base != nullptr; +// } +// void loadSelf(const YAML::Node& node) override { +// if (!isBaseACtive()) +// return; +// if (!first_load) { +// communication_delay_us_param.set(node["communication_delay_us"].as()); +// yaw_ramp_param.set(node["yaw_ramp"].as()); +// pitch_ramp_param.set(node["pitch_ramp"].as()); +// control_rate_param.set(node["control_rate"].as()); +// first_load = true; +// } else { +// communication_delay_us_param.load(node); +// yaw_ramp_param.load(node); +// pitch_ramp_param.load(node); +// control_rate_param.load(node); +// } +// } +// }; +// ControlConfig::Ptr control_config_; +// struct ShootConfig: wust_vl::common::utils::ParamGroup { +// public: +// static constexpr const char* kKey = "shoot"; +// static constexpr const char* Logger = "Config: common::shoot"; +// const char* key() const override { +// return kKey; +// } +// using Ptr = std::shared_ptr; +// ShootConfig() { +// rate_param.onChange([this](int o, int n) { +// WUST_DEBUG(Logger) << "shoot_rate from: " << o << " to: " << n << " HZ"; +// }); +// } + +// static Ptr create() { +// return std::make_shared(); +// } +// GEN_PARAM(int, rate); +// GEN_PARAM(double, bullet_speed); +// bool first_load = false; + +// void loadSelf(const YAML::Node& node) override { +// if (!first_load) { +// rate_param.set(node["rate"].as()); +// bullet_speed_param.set(node["bullet_speed"].as()); +// first_load = true; +// } else { +// rate_param.load(node); +// } +// } +// }; +// ShootConfig::Ptr shoot_config_; + +// LoggerConfig::Ptr logger_config_; + +// TFConfig::Ptr tf_config_; + +// MaxInferRunningConfig::Ptr max_infer_running_config_; + +// int attack_mode_; +// int debug_fps_; +// bool detect_color_; + +// wust_vl::common::utils::Parameter common_config_parameter_; +// std::unique_ptr thread_pool_; +// std::unique_ptr auto_aim_; +// std::unique_ptr auto_buff_; +// std::unique_ptr timer_; +// std::shared_ptr> motion_buffer_; +// std::thread debug_thread_; +// GimbalCmd last_cmd_; +// std::pair camera_info_; +// bool run_flag_ = false; +// bool debug_mode_ = false; +// std::atomic infer_running_count_ { 0 }; +// std::string common_config_; +// std::string camera_config_; +// std::string auto_aim_config_; +// std::string auto_buff_config_; +// bool has_camera_info_ = false; +// std::shared_ptr ros2_; +// }; +// } // namespace wust_vision +// // VISION_MAIN(wust_vision::vision) +// int main(int argc, char** argv) { +// wust_vision::printBanner(); +// bool debug = false; +// if (argc > 1) { +// std::string firstArg = argv[1]; +// debug = (firstArg == "true" || firstArg == "1"); +// std::cout << "debug: " << firstArg << std::endl; +// } +// std::set_terminate([]() { +// std::cerr << "Uncaught exception, terminating program.\n"; +// if (auto e = std::current_exception()) { +// try { +// std::rethrow_exception(e); +// } catch (const std::exception& ex) { +// std::cerr << "Exception: " << ex.what() << std::endl; +// } catch (...) { +// std::cerr << "Unknown exception" << std::endl; +// } +// } +// std::abort(); +// }); + +// try { +// int exit_code = 0; + +// { +// wust_vision::vision v; +// v.init(debug); +// v.start(); + +// wust_vl::common::utils::SignalHandler sig; +// sig.start([&] { rclcpp::shutdown(); }); + +// bool exit_flag = false; + +// while (!sig.shouldExit() && !exit_flag) { +// wust_vl::common::concurrency::ThreadManager::instance().printStatus(); +// auto all_status = +// wust_vl::common::concurrency::ThreadManager::instance().getAllThreadStatuses(); +// v.checkStateMatchMode(); + +// for (auto& status: all_status) { +// if (status.second +// == wust_vl::common::concurrency::MonitoredThread::Status::Hung) { +// std::cerr << status.first << " is Hunging! Exiting program..." << std::endl; +// exit_flag = true; +// exit_code = -1; +// sig.requestExit(); +// break; +// } +// } +// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); +// } +// } + +// std::cout << "Exiting program..." << std::endl; + +// return exit_code; + +// } catch (const std::exception& e) { +// std::cerr << "Caught exception in main: " << e.what() << "\n"; +// throw; +// return -1; +// } catch (...) { +// std::cerr << "Unknown exception caught in main!\n"; +// return -1; +// } +// } \ No newline at end of file diff --git a/wust_vision-main/src/standard.cpp b/wust_vision-main/src/standard.cpp new file mode 100644 index 0000000..235c9f9 --- /dev/null +++ b/wust_vision-main/src/standard.cpp @@ -0,0 +1,27 @@ +#include "tasks/imodule.hpp" +#include "tasks/utils/config.hpp" +#include "tasks/utils/main_base.hpp" +#include "tasks/vision_base.hpp" +ENABLE_BACKWARD() +namespace wust_vision { +class vision: public VisionBase { +public: + vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {} + bool init(bool debug_mode) { + if (!VisionBase::init(debug_mode)) { + return false; + } + auto auto_aim = + auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode); + modules_.emplace(InfantryMode::AttackMode::ARMOR, auto_aim); + modules_.emplace(InfantryMode::AttackMode::UNKNOWN, auto_aim); + auto auto_buff = + auto_buff::AutoBuff::create(auto_buff_config_, tf_config_, camera_info_, debug_mode); + modules_.emplace(InfantryMode::AttackMode::BIG_RUNE, auto_buff); + modules_.emplace(InfantryMode::AttackMode::SMALL_RUNE, auto_buff); + return true; + } +}; +} // namespace wust_vision + +VISION_MAIN(wust_vision::vision) \ No newline at end of file diff --git a/wust_vision-main/static/css/style.css b/wust_vision-main/static/css/style.css new file mode 100644 index 0000000..fb75152 --- /dev/null +++ b/wust_vision-main/static/css/style.css @@ -0,0 +1,424 @@ +:root { + --primary-color: #00bfff; + --primary-dark: #004f8c; + --accent-color: #00d4ff; + --bg-dark: #121720; + --text-light: #a0d6f5; + --shadow-glow: #00d4ffcc; +} + +/* 基础样式 */ +html, body { + height: 100%; + margin: 0; + padding: 0; + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + background: linear-gradient(135deg, var(--bg-dark) 0%, #0f1526 100%); + color: var(--text-light); + -webkit-user-select: text; + user-select: text; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + text-shadow: 0 0 3px rgba(0, 120, 255, 0.4); +} + +/* 头部区域 */ +#header { + display: flex; + align-items: center; + padding: 12px 24px; + background: rgba(0,31,63,0.85); + -webkit-backdrop-filter: saturate(180%) blur(6px); + backdrop-filter: saturate(180%) blur(6px); + background-image: linear-gradient(90deg, var(--primary-dark), #003366); + border-bottom: 3px solid var(--primary-color); + box-shadow: 0 0 14px var(--shadow-glow); + height: 100px; + box-sizing: border-box; +} + +#header img { + width: 52px; + height: 52px; + margin-right: 20px; + border-radius: 12px; + box-shadow: 0 0 12px var(--primary-color); + transition: transform 0.3s ease; + cursor: pointer; +} +#header img:hover { + animation: glowPulse 2s infinite ease-in-out; + transform: scale(1.1); +} + +/* 标签按钮栏 */ +#tab-buttons { + padding: 10px 24px; + background: #101425; + border-bottom: 2px solid #004466; + box-shadow: inset 0 -2px 8px #00334daa; + height: 44px; + box-sizing: border-box; + display: flex; + align-items: center; +} + +#tab-buttons button { + background: linear-gradient(145deg, #0088cc, #005f7f); + border: none; + color: #ccf8ff; + padding: 8px 18px; + margin-right: 14px; + border-radius: 10px; + font-weight: 600; + cursor: pointer; + box-shadow: + inset 0 -2px 4px #004a6d99, + 0 0 12px var(--primary-color); + transition: all 0.3s ease; + -webkit-user-select: text; + user-select: text; + text-shadow: 0 0 4px var(--primary-color); +} +#tab-buttons button:hover, #tab-buttons button:focus { + background: linear-gradient(145deg, var(--primary-color), #0077aa); + box-shadow: + inset 0 -2px 6px #006db6cc, + 0 0 22px var(--accent-color); + outline: none; +} +#tab-buttons button:active { + transform: scale(0.97); + box-shadow: + inset 0 2px 4px #002d43cc, + 0 0 16px #0088ccdd; +} + +/* 视频区域 */ +#video-tab { + height: calc(100vh - 60px - 44px); + width: 100vw; + box-sizing: border-box; + background: #0f131f; + padding: 12px 20px; + overflow: hidden; + box-shadow: inset 0 0 40px #003554aa; +} + +/* 主网格布局 */ +#main-grid { + display: grid; + grid-template-columns: 1fr 1fr; + grid-template-rows: 1fr 1fr; + height: 100%; + width: 100%; + gap: 6px; + box-sizing: border-box; + background: #121b2b; + border-radius: 14px; + box-shadow: + 0 0 12px #007acc88, + inset 0 0 20px #002b4daa; +} + +/* 面板通用样式 */ +.panel { + background: linear-gradient(145deg, #1b2237, #121a2e); + padding: 16px; + box-sizing: border-box; + height: 100%; + width: 100%; + display: flex; + flex-direction: column; + border-radius: 14px; + border: 1.8px solid transparent; + background-clip: padding-box; + position: relative; + overflow: hidden; + cursor: default; + box-shadow: + 0 0 15px var(--primary-color), + inset 0 0 25px var(--primary-dark); +} + +.panel::before { + content: ""; + position: absolute; + inset: 0; + border-radius: 14px; + padding: 1.8px; + pointer-events: none; + background: linear-gradient(60deg, var(--primary-color), var(--primary-dark), var(--accent-color), #005fbc); + -webkit-mask: + linear-gradient(#fff 0 0) content-box, + linear-gradient(#fff 0 0); + mask: + linear-gradient(#fff 0 0) content-box, + linear-gradient(#fff 0 0); + -webkit-mask-composite: destination-out; + mask-composite: exclude; + animation: pulseBorder 3s ease-in infinite; +} + +/* 边框发光动画 */ +@keyframes pulseBorder { + 0% { + filter: drop-shadow(0 0 5px var(--primary-color)); + } + 100% { + filter: drop-shadow(0 0 15px var(--accent-color)); + } +} + +/* 四个区域微调 */ +#video-panel { padding-right: 10px; } +#json-info { padding-left: 10px; overflow-y: auto; } +#main-chart-panel { padding-right: 10px; overflow-y: auto; } +#individual-chart-panel { padding-left: 10px; overflow-y: auto; } + +/* 视频容器 */ +.video-container { + flex-grow: 1; + display: flex; + justify-content: center; + align-items: center; + overflow: hidden; + border-radius: 12px; + box-shadow: inset 0 0 40px #005a9e88; + background: #12203f; + position: relative; +} + +.video-container img { + max-width: 100%; + max-height: 100%; + border-radius: 12px; + border: 2px solid #006db6; + box-shadow: + 0 0 20px var(--primary-color), + inset 0 0 30px #004975cc; + transition: transform 0.4s ease, box-shadow 0.4s ease; + cursor: pointer; +} +.video-container img:hover { + transform: scale(1.05); + box-shadow: + 0 0 30px var(--accent-color), + inset 0 0 40px #0069c9cc; +} + +/* JSON信息容器 */ +#json-info { + color: #a0cbdc; + font-family: "Source Code Pro", monospace, monospace; + overflow-y: auto; + text-shadow: 0 0 3px #005e8a; + -webkit-user-select: text; + user-select: text; +} + +/* JSON子容器 */ +.json-container { + background: linear-gradient(135deg, #12273e 0%, #0c1c33 100%); + border-radius: 12px; + padding: 14px 18px; + margin-bottom: 14px; + height: 48%; + overflow-y: auto; + box-shadow: inset 0 0 12px #004573; + border: 1.2px solid #0068ac; +} + +/* 主图表和独立图表容器撑满 */ +#main-chart-panel, +#individual-chart-panel { + height: 100%; + overflow-y: auto; + display: flex; + flex-direction: column; +} + +/* 主图表样式 */ +#mainChart { + flex-grow: 1; + width: 100% !important; + border-radius: 14px; + background: linear-gradient(135deg, #1a243a, #0f1526); + padding: 12px; + box-sizing: border-box; + margin-bottom: 14px; + box-shadow: + 0 0 12px #0099ffbb, + inset 0 0 18px #004575aa; + border: 2px solid #0088ee; + transition: box-shadow 0.3s ease; + cursor: default; +} +#mainChart:hover { + box-shadow: + 0 0 24px var(--primary-color), + inset 0 0 30px #0077ccbb; +} + +/* 独立图表容器 */ +#individualCharts { + flex-grow: 1; + overflow-y: auto; + display: flex; + flex-wrap: wrap; + gap: 18px; + padding-bottom: 6px; +} + +/* 交互控件 */ +.chart-controls { + margin-bottom: 10px; + color: #a0d4f7; + text-shadow: 0 0 3px #005e8a; + -webkit-user-select: text; + user-select: text; +} + +/* 按钮、下拉框等统一风格 */ +button, select, input[type="checkbox"] { + background: linear-gradient(145deg, #00bcd4, #0088aa); + color: #001a25; + border: none; + padding: 8px 14px; + border-radius: 10px; + margin-right: 10px; + cursor: pointer; + font-weight: 700; + box-shadow: 0 0 14px #00cfffcc; + transition: all 0.25s ease; + -webkit-user-select: text; + user-select: text; + text-shadow: 0 0 6px #00dfffcc; +} +button:hover, select:hover { + background: linear-gradient(145deg, #00e3ff, #0099cc); + box-shadow: 0 0 24px #00eaffee; +} +input[type="checkbox"] { + transform: scale(1.3); + cursor: pointer; +} + +/* 标签文字 */ +label { + color: #9bc8f7; + margin-right: 12px; + -webkit-user-select: text; + user-select: text; + text-shadow: 0 0 3px #004477; +} + +/* 配置面板 */ +#config-tab { + padding: 12px; + background: linear-gradient(135deg, #132237, #0b1528); + border-radius: 14px; + border: 1.5px solid #005f8f; + box-shadow: 0 0 12px #0077bbaa; + color: #a0cde8; +} + +/* 响应式调整:小屏幕时改为单列布局 */ +@media (max-width: 1024px) { + #video-tab { + height: auto; + } + #main-grid { + grid-template-columns: 1fr; + grid-template-rows: auto auto auto auto; + height: auto; + gap: 18px; + } + .panel { + height: auto; + } +} + +/* 头部图片发光动画 */ +@keyframes glowPulse { + 0%, 100% { + box-shadow: 0 0 10px var(--primary-color); + } + 50% { + box-shadow: 0 0 20px var(--accent-color); + } +} + +/* 全屏模式下,仅显示视频容器,隐藏其他所有元素 */ +.fullscreen-mode #header, +.fullscreen-mode #tab-buttons, +.fullscreen-mode #json-info, +.fullscreen-mode #main-chart-panel, +.fullscreen-mode #individual-chart-panel, +.fullscreen-mode .fullscreen-btn { + display: none !important; +} + +.fullscreen-mode #main-grid { + padding: 0 !important; + margin: 0 !important; + gap: 0 !important; + background: black !important; + box-shadow: none !important; + border-radius: 0 !important; +} + +.fullscreen-mode #video-panel { + grid-column: 1 / -1; + grid-row: 1 / -1; + padding: 0 !important; +} + +/* 全屏时视频容器铺满全屏 */ +.fullscreen-mode .video-container { + position: fixed; + top: 0; left: 0; + width: 100vw !important; + height: 100vh !important; + margin: 0; padding: 0; + background: black !important; + border-radius: 0 !important; + box-shadow: none !important; + overflow: hidden; + z-index: 9999; + display: flex !important; + justify-content: center; + align-items: center; +} + +/* 图片铺满容器,保持比例,可能裁剪 */ +.fullscreen-mode .video-container img { + height: 90% !important; + width: auto !important; + max-width: 100vw !important; + object-fit: contain !important; + border: none !important; + box-shadow: none !important; + border-radius: 0 !important; +} + + +/* 全屏按钮样式 */ +.fullscreen-btn { + position: absolute; + bottom: 12px; + right: 14px; + padding: 8px 12px; + font-size: 14px; + background: rgba(0, 136, 204, 0.85); + color: white; + border: none; + border-radius: 8px; + cursor: pointer; + box-shadow: 0 0 10px #00bfff88; + z-index: 10; + transition: background 0.2s ease; +} +.fullscreen-btn:hover { + background: rgba(0, 180, 255, 0.95); +} diff --git a/wust_vision-main/static/js/chart_logic.js b/wust_vision-main/static/js/chart_logic.js new file mode 100644 index 0000000..8ff08e5 --- /dev/null +++ b/wust_vision-main/static/js/chart_logic.js @@ -0,0 +1,249 @@ +const UPDATE_HZ = 100; +const UPDATE_INTERVAL_MS = 1000 / UPDATE_HZ; + +let updateTimer = null; + +function startUpdateLoop() { + if (updateTimer) return; + updateTimer = setInterval(fetchDataAndUpdateCharts, UPDATE_INTERVAL_MS); +} +startUpdateLoop(); + +const chartMap = { + raw_yaw: { label: "Raw Yaw" }, + raw_pitch: { label: "Raw Pitch" }, + yaw: { label: "Yaw" }, + pitch: { label: "Pitch" }, + armor_dis: { label: "Armor Distance" }, + armor_x: { label: "Armor X" }, + armor_y: { label: "Armor Y" }, + armor_z: { label: "Armor Z" }, + armor_yaw: { label: "Armor Yaw" }, + ypd_p: { label: "Ypd Pitch" }, + ypd_y: { label: "Ypd Yaw" }, + rune_obs: { label: "Rune Obs" }, + rune_pre: { label: "Rune Pre" }, + rune_obsv: { label: "Rune ObsV" }, + rune_fitv: { label: "Rune FitV" }, + gimbal_yaw: { label: "Gimbal Yaw" }, + gimbal_pitch: { label: "Gimbal Pitch" }, + target_v_yaw: { label: "Target V Yaw" }, + control_v_yaw: { label: "Control V Yaw" }, + control_v_pitch: { label: "Control V Pitch" }, + yaw_diff: { label: "Yaw Diff" }, + fire: { label: "Fire" }, + rune_dis: { label: "Rune Distance" }, + fly_time: { label: "Fly Time" }, + control_a_yaw: { label: "Control A Yaw" }, + control_a_pitch: { label: "Control A Pitch" }, +}; + +const mainCtx = document.getElementById("mainChart").getContext("2d"); +let individualCharts = {}; +let individualRanges = {}; + +const commonChartOptions = { + animation: false, + responsive: false, + interaction: { + mode: "nearest", + axis: "x", + intersect: false, + }, + elements: { + line: { + tension: 0, + }, + point: { + radius: 0, + hoverRadius: 0, + }, + }, + plugins: { + tooltip: { + enabled: true, + mode: "nearest", + intersect: false, + backgroundColor: "rgba(0, 188, 212, 0.85)", + padding: 8, + cornerRadius: 6, + titleFont: { + size: 12, + family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif", + }, + bodyFont: { + size: 16, + family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif", + weight: "bold", + }, + callbacks: { + title: () => "", + label: (context) => `值: ${context.parsed.y.toFixed(3)}`, + }, + }, + legend: { + labels: { + font: { + size: 14, + family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif", + }, + color: "#00bcd4", + }, + }, + }, + scales: { + x: { + display: false, // ✅ 关键:彻底关闭 X 轴显示 + }, + y: { + title: { display: true, text: "Value" }, + ticks: { + color: "#00bcd4", + font: { + size: 14, + family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif", + }, + }, + grid: { color: "#2f3241" }, + }, + }, +}; + +const mainChart = new Chart(mainCtx, { + type: "line", + data: { labels: [], datasets: [] }, + options: commonChartOptions, +}); + +function updateMainRange() { + const maxPts = parseInt(document.getElementById("mainMaxPts").value) || 100; + mainChart._maxPoints = maxPts; +} + +function updateCharts() { + const showMulti = document.getElementById("multiLineChart").checked; + const selected = Array.from( + document.querySelectorAll( + '.chart-select-controls input[type="checkbox"]:checked' + ) + ) + .map((cb) => cb.dataset.key) + .filter(Boolean); + + const container = document.getElementById("individualCharts"); + container.innerHTML = ""; + individualCharts = {}; + individualRanges = {}; + + if (showMulti) { + mainChart.data.datasets = selected.map((key) => ({ + label: chartMap[key]?.label || key, + data: [], + fill: false, + })); + } else { + mainChart.data.datasets = []; + } + mainChart.update(); + + selected.forEach((key) => { + const box = document.createElement("div"); + box.className = "chart-box"; + const title = document.createElement("h4"); + title.textContent = chartMap[key]?.label || key; + box.appendChild(title); + + const rdiv = document.createElement("div"); + rdiv.className = "range-controls"; + rdiv.innerHTML = ` + + min: + max: + + `; + box.appendChild(rdiv); + + const canvas = document.createElement("canvas"); + canvas.width = 440; + canvas.height = 280; + box.appendChild(canvas); + container.appendChild(box); + + const ctx = canvas.getContext("2d"); + const chart = new Chart(ctx, { + type: "line", + data: { + labels: [], + datasets: [ + { + label: chartMap[key]?.label || key, + data: [], + fill: false, + }, + ], + }, + options: commonChartOptions, + }); + individualCharts[key] = chart; + + const applyBtn = rdiv.querySelector(".applyRange"); + applyBtn.addEventListener("click", () => { + const enabled = rdiv.querySelector(".childEnable").checked; + const minVal = parseFloat(rdiv.querySelector(".childMin").value); + const maxVal = parseFloat(rdiv.querySelector(".childMax").value); + chart.options.scales.y.min = enabled ? minVal : undefined; + chart.options.scales.y.max = enabled ? maxVal : undefined; + chart.update("none"); + }); + }); +} + +async function fetchDataAndUpdateCharts() { + try { + const res = await fetch("/data"); + const json = await res.json(); + const time = json.time; + if (!time) return; + + const maxPts = mainChart._maxPoints || 100; + const start = time.length > maxPts ? time.length - maxPts : 0; + const slicedTime = time.slice(start); + + if (document.getElementById("multiLineChart").checked) { + mainChart.data.labels = slicedTime; + const keys = Object.keys(individualCharts); + mainChart.data.datasets.forEach((ds, i) => { + const key = keys[i]; + ds.data = json[key]?.slice(start) || []; + }); + + const allValues = mainChart.data.datasets.flatMap(ds => ds.data); + if (allValues.length > 0) { + const minVal = Math.min(...allValues); + const maxVal = Math.max(...allValues); + const padding = (maxVal - minVal) * 0.1 || 1; // 至少留1的余量 + mainChart.options.scales.y.min = minVal - padding; + mainChart.options.scales.y.max = maxVal + padding; + } + mainChart.update(); + } + + Object.entries(individualCharts).forEach(([key, ch]) => { + ch.data.labels = slicedTime; + ch.data.datasets[0].data = json[key]?.slice(start) || []; + + // 每个子图也加 margin,避免线条贴边 + const arr = ch.data.datasets[0].data; + if (arr.length > 0) { + const minVal = Math.min(...arr); + const maxVal = Math.max(...arr); + const padding = (maxVal - minVal) * 0.1 || 1; + ch.options.scales.y.min = minVal - padding; + ch.options.scales.y.max = maxVal + padding; + } + ch.update(); + }); + } catch (e) { + console.error("fetch error:", e); + } +} diff --git a/wust_vision-main/static/js/json_view.js b/wust_vision-main/static/js/json_view.js new file mode 100644 index 0000000..e77ff0f --- /dev/null +++ b/wust_vision-main/static/js/json_view.js @@ -0,0 +1,61 @@ +function jsonToHtml(data, container) { + const title = container.querySelector("h3"); + container.innerHTML = ""; + if (title) container.appendChild(title); + const contentDiv = document.createElement("div"); + container.appendChild(contentDiv); + + function buildTree(d, parent) { + if (typeof d !== "object" || d === null) { + parent.textContent = String(d); + return; + } + const ul = document.createElement("ul"); + ul.className = "json-tree"; + const entries = Array.isArray(d) ? d.map((v, i) => [i, v]) : Object.entries(d); + entries.forEach(([k, v]) => { + const li = document.createElement("li"); + if (typeof v === "object" && v !== null) { + const details = document.createElement("details"); + details.open = true; + const summary = document.createElement("summary"); + summary.textContent = k; + details.appendChild(summary); + const childUl = document.createElement("ul"); + childUl.className = "json-tree"; + buildTree(v, childUl); + details.appendChild(childUl); + li.appendChild(details); + } else { + li.textContent = `${k}: ${v}`; + } + ul.appendChild(li); + }); + parent.appendChild(ul); + } + buildTree(data, contentDiv); +} + +let lastSerialData = null, lastTargetData = null; + +async function fetchAndDisplayJsonWithTree(id, url) { + const parent = document.getElementById(id + "-container"); + const cont = document.getElementById(id); + try { + parent.classList.add("json-updating"); + const res = await fetch(url); + if (!res.ok) throw new Error(res.statusText); + const data = await res.json(); + + const prev = id === "json-serial" ? lastSerialData : lastTargetData; + if (JSON.stringify(data) !== JSON.stringify(prev)) { + jsonToHtml(data, cont); + if (id === "json-serial") lastSerialData = data; + else lastTargetData = data; + } + } catch (e) { + console.warn(`请求失败(${url}): ${e.message}`); + } finally { + parent.classList.remove("json-updating"); + } +} diff --git a/wust_vision-main/static/js/main.js b/wust_vision-main/static/js/main.js new file mode 100644 index 0000000..c29652a --- /dev/null +++ b/wust_vision-main/static/js/main.js @@ -0,0 +1,35 @@ +function showTab(tab) { + document.getElementById("video-tab").style.display = tab === "video" ? "flex" : "none"; +} + +document.addEventListener("DOMContentLoaded", () => { + updateCharts(); + updateMainRange(); + setInterval(() => { + fetchDataAndUpdateCharts(); + fetchAndDisplayJsonWithTree("json-target", "/target_log"); + fetchAndDisplayJsonWithTree("json-serial", "/serial_log"); + }, 200); +}); +function toggleFullscreen() { + const container = document.querySelector('.video-container'); + + if (!document.fullscreenElement) { + container.requestFullscreen().then(() => { + document.body.classList.add('fullscreen-mode'); + }).catch((err) => { + alert(`无法进入全屏模式: ${err.message}`); + }); + } else { + document.exitFullscreen().then(() => { + document.body.classList.remove('fullscreen-mode'); + }); + } +} + +// 监听用户按 ESC 或其他方式退出全屏 +document.addEventListener('fullscreenchange', () => { + if (!document.fullscreenElement) { + document.body.classList.remove('fullscreen-mode'); + } +}); diff --git a/wust_vision-main/static/logo.JPG b/wust_vision-main/static/logo.JPG new file mode 100644 index 0000000..d2b61ab Binary files /dev/null and b/wust_vision-main/static/logo.JPG differ diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/CMakeLists.txt b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/CMakeLists.txt new file mode 100644 index 0000000..b8654af --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/CMakeLists.txt @@ -0,0 +1,43 @@ +add_library(tinympcstatic STATIC + admm.cpp + tiny_api.cpp + codegen.cpp + rho_benchmark.cpp +) + +set_property(TARGET tinympcstatic PROPERTY POSITION_INDEPENDENT_CODE ON) + +# target_link_libraries(tinympcstatic PUBLIC Eigen) +target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) +target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include/Eigen) + + + +if(USING_CODEGEN) # Defined in top-level CMakeLists.txt + +# Files that are needed for embedded code generation +list( APPEND EMBEDDED_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/admm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/admm.hpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.hpp" + "${CMAKE_CURRENT_SOURCE_DIR}/types.hpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tiny_api_constants.hpp" ) + + +foreach( f ${EMBEDDED_FILES} ) + get_filename_component( fname ${f} NAME ) + + set( dest_file "${EMBEDDED_BUILD_TINYMPC_DIR}/${fname}" ) + list( APPEND EMBEDDED_BUILD_TINYMPC_FILES "${dest_file}" ) + + add_custom_command(OUTPUT ${dest_file} + COMMAND ${CMAKE_COMMAND} -E copy "${f}" "${dest_file}" + DEPENDS ${f} + COMMENT "Copying ${fname}") +endforeach() + +add_custom_target( copy_codegen_tinympc_files DEPENDS ${EMBEDDED_BUILD_TINYMPC_FILES} ) +add_dependencies( copy_codegen_files copy_codegen_tinympc_files ) + +endif(USING_CODEGEN) diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.cpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.cpp new file mode 100644 index 0000000..c9909b8 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.cpp @@ -0,0 +1,399 @@ +#include + +#include "admm.hpp" +#include "rho_benchmark.hpp" + +#define DEBUG_MODULE "TINYALG" + +extern "C" { + +/** + * Update linear terms from Riccati backward pass +*/ +void backward_pass_grad(TinySolver* solver) { + for (int i = solver->work->N - 2; i >= 0; i--) { + (solver->work->d.col(i)).noalias() = solver->cache->Quu_inv + * (solver->work->Bdyn.transpose() * solver->work->p.col(i + 1) + solver->work->r.col(i) + + solver->cache->BPf); + (solver->work->p.col(i)).noalias() = solver->work->q.col(i) + + solver->cache->AmBKt.lazyProduct(solver->work->p.col(i + 1)) + - (solver->cache->Kinf.transpose()).lazyProduct(solver->work->r.col(i)) + + solver->cache->APf; + } +} + +/** + * Use LQR feedback policy to roll out trajectory +*/ +void forward_pass(TinySolver* solver) { + for (int i = 0; i < solver->work->N - 1; i++) { + (solver->work->u.col(i)).noalias() = + -solver->cache->Kinf.lazyProduct(solver->work->x.col(i)) - solver->work->d.col(i); + (solver->work->x.col(i + 1)).noalias() = + solver->work->Adyn.lazyProduct(solver->work->x.col(i)) + + solver->work->Bdyn.lazyProduct(solver->work->u.col(i)) + solver->work->fdyn; + } +} + +/** + * Project a vector s onto the second order cone defined by mu + * @param s, mu + * @return projection onto cone if s is outside cone. Return s if s is inside cone. +*/ +tinyVector project_soc(tinyVector s, float mu) { + tinytype u0 = s(Eigen::placeholders::last) * mu; + tinyVector u1 = s.head(s.rows() - 1); + float a = u1.norm(); + tinyVector cone_origin(s.rows()); + cone_origin.setZero(); + + if (a <= -u0) { // below cone + return cone_origin; + } else if (a <= u0) { // in cone + return s; + } else if (a >= abs(u0)) { // outside cone + Matrix u2(u1.size() + 1); + u2 << u1, a / mu; + return 0.5 * (1 + u0 / a) * u2; + } else { + return cone_origin; + } +} + +/** + * Project a vector z onto a hyperplane defined by a^T z = b + * Implements equation (21): ΠH(z) = z - (⟨z, a⟩ − b)/||a||² * a + * @param z Vector to project + * @param a Normal vector of the hyperplane + * @param b Offset of the hyperplane + * @return Projection of z onto the hyperplane + */ +tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b) { + tinytype dist = (a.dot(z) - b) / a.squaredNorm(); + return z - dist * a; +} + +/** + * Project slack (auxiliary) variables into their feasible domain, defined by + * projection functions related to each constraint + * TODO: pass in meta information with each constraint assigning it to a + * projection function +*/ +void update_slack(TinySolver* solver) { + // Update bound constraint slack variables for state + solver->work->vnew = solver->work->x + solver->work->g; + + // Update bound constraint slack variables for input + solver->work->znew = solver->work->u + solver->work->y; + + // Box constraints on state + if (solver->settings->en_state_bound) { + solver->work->vnew = + solver->work->x_max.cwiseMin(solver->work->x_min.cwiseMax(solver->work->vnew)); + } + + // Box constraints on input + if (solver->settings->en_input_bound) { + solver->work->znew = + solver->work->u_max.cwiseMin(solver->work->u_min.cwiseMax(solver->work->znew)); + } + + // Update second order cone slack variables for state + if (solver->settings->en_state_soc && solver->work->numStateCones > 0) { + solver->work->vcnew = solver->work->x + solver->work->gc; + } + + // Update second order cone slack variables for input + if (solver->settings->en_input_soc && solver->work->numInputCones > 0) { + solver->work->zcnew = solver->work->u + solver->work->yc; + } + + // Cone constraints on state + if (solver->settings->en_state_soc) { + for (int i = 0; i < solver->work->N; i++) { + for (int k = 0; k < solver->work->numStateCones; k++) { + int start = solver->work->Acx(k); + int num_xs = solver->work->qcx(k); + tinytype mu = solver->work->cx(k); + tinyVector col = solver->work->vcnew.block(start, i, num_xs, 1); + solver->work->vcnew.block(start, i, num_xs, 1) = project_soc(col, mu); + } + } + } + + // Cone constraints on input + if (solver->settings->en_input_soc) { + for (int i = 0; i < solver->work->N - 1; i++) { + for (int k = 0; k < solver->work->numInputCones; k++) { + int start = solver->work->Acu(k); + int num_us = solver->work->qcu(k); + tinytype mu = solver->work->cu(k); + tinyVector col = solver->work->zcnew.block(start, i, num_us, 1); + solver->work->zcnew.block(start, i, num_us, 1) = project_soc(col, mu); + } + } + } + + // Update linear constraint slack variables for state + if (solver->settings->en_state_linear) { + solver->work->vlnew = solver->work->x + solver->work->gl; + } + + // Update linear constraint slack variables for input + if (solver->settings->en_input_linear) { + solver->work->zlnew = solver->work->u + solver->work->yl; + } + + // Linear constraints on state + if (solver->settings->en_state_linear) { + for (int i = 0; i < solver->work->N; i++) { + for (int k = 0; k < solver->work->numStateLinear; k++) { + tinyVector a = solver->work->Alin_x.row(k); + tinytype b = solver->work->blin_x(k); + tinytype constraint_value = a.dot(solver->work->vlnew.col(i)); + if (constraint_value > b) { // Only project if constraint is violated + solver->work->vlnew.col(i) = + project_hyperplane(solver->work->vlnew.col(i), a, b); + } + } + } + } + + // Linear constraints on input + if (solver->settings->en_input_linear) { + for (int i = 0; i < solver->work->N - 1; i++) { + for (int k = 0; k < solver->work->numInputLinear; k++) { + tinyVector a = solver->work->Alin_u.row(k); + tinytype b = solver->work->blin_u(k); + tinytype constraint_value = a.dot(solver->work->zlnew.col(i)); + if (constraint_value > b) { // Only project if constraint is violated + solver->work->zlnew.col(i) = + project_hyperplane(solver->work->zlnew.col(i), a, b); + } + } + } + } +} + +/** + * Update next iteration of dual variables by performing the augmented + * lagrangian multiplier update +*/ +void update_dual(TinySolver* solver) { + // Update bound constraint dual variables for state + solver->work->g = solver->work->g + solver->work->x - solver->work->vnew; + + // Update bound constraint dual variables for input + solver->work->y = solver->work->y + solver->work->u - solver->work->znew; + + // Update second order cone dual variables for state + if (solver->settings->en_state_soc && solver->work->numStateCones > 0) { + solver->work->gc = solver->work->gc + solver->work->x - solver->work->vcnew; + } + + // Update second order cone dual variables for input + if (solver->settings->en_input_soc && solver->work->numInputCones > 0) { + solver->work->yc = solver->work->yc + solver->work->u - solver->work->zcnew; + } + + // Update linear constraint dual variables for state + if (solver->settings->en_state_linear) { + solver->work->gl = solver->work->gl + solver->work->x - solver->work->vlnew; + } + + // Update linear constraint dual variables for input + if (solver->settings->en_input_linear) { + solver->work->yl = solver->work->yl + solver->work->u - solver->work->zlnew; + } +} + +/** + * Update linear control cost terms in the Riccati feedback using the changing + * slack and dual variables from ADMM +*/ +void update_linear_cost(TinySolver* solver) { + // Update state cost terms + solver->work->q = -(solver->work->Xref.array().colwise() * solver->work->Q.array()); + (solver->work->q).noalias() -= solver->cache->rho * (solver->work->vnew - solver->work->g); + if (solver->settings->en_state_soc && solver->work->numStateCones > 0) { + (solver->work->q).noalias() -= + solver->cache->rho * (solver->work->vcnew - solver->work->gc); + } + if (solver->settings->en_state_linear) { + (solver->work->q).noalias() -= + solver->cache->rho * (solver->work->vlnew - solver->work->gl); + } + + // Update input cost terms + solver->work->r = -(solver->work->Uref.array().colwise() * solver->work->R.array()); + (solver->work->r).noalias() -= solver->cache->rho * (solver->work->znew - solver->work->y); + if (solver->settings->en_input_soc && solver->work->numInputCones > 0) { + (solver->work->r).noalias() -= + solver->cache->rho * (solver->work->zcnew - solver->work->yc); + } + if (solver->settings->en_input_linear) { + (solver->work->r).noalias() -= + solver->cache->rho * (solver->work->zlnew - solver->work->yl); + } + + // Update terminal cost + solver->work->p.col(solver->work->N - 1) = + -(solver->work->Xref.col(solver->work->N - 1).transpose().lazyProduct(solver->cache->Pinf)); + (solver->work->p.col(solver->work->N - 1)).noalias() -= solver->cache->rho + * (solver->work->vnew.col(solver->work->N - 1) - solver->work->g.col(solver->work->N - 1)); + + if (solver->settings->en_state_soc && solver->work->numStateCones > 0) { + solver->work->p.col(solver->work->N - 1) -= solver->cache->rho + * (solver->work->vcnew.col(solver->work->N - 1) + - solver->work->gc.col(solver->work->N - 1)); + } + if (solver->settings->en_state_linear) { + solver->work->p.col(solver->work->N - 1) -= solver->cache->rho + * (solver->work->vlnew.col(solver->work->N - 1) + - solver->work->gl.col(solver->work->N - 1)); + } +} + +/** + * Check for termination condition by evaluating whether the largest absolute + * primal and dual residuals for states and inputs are below threhold. +*/ +bool termination_condition(TinySolver* solver) { + if (solver->work->iter % solver->settings->check_termination == 0) { + solver->work->primal_residual_state = + (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff(); + solver->work->dual_residual_state = + ((solver->work->v - solver->work->vnew).cwiseAbs().maxCoeff()) * solver->cache->rho; + solver->work->primal_residual_input = + (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff(); + solver->work->dual_residual_input = + ((solver->work->z - solver->work->znew).cwiseAbs().maxCoeff()) * solver->cache->rho; + + if (solver->work->primal_residual_state < solver->settings->abs_pri_tol + && solver->work->primal_residual_input < solver->settings->abs_pri_tol + && solver->work->dual_residual_state < solver->settings->abs_dua_tol + && solver->work->dual_residual_input < solver->settings->abs_dua_tol) + { + return true; + } + } + return false; +} + +int solve(TinySolver* solver) { + // Initialize variables + solver->solution->solved = 0; + solver->solution->iter = 0; + solver->work->status = 11; // TINY_UNSOLVED + solver->work->iter = 0; + + // Setup for adaptive rho + RhoAdapter adapter; + adapter.rho_min = solver->settings->adaptive_rho_min; + adapter.rho_max = solver->settings->adaptive_rho_max; + adapter.clip = solver->settings->adaptive_rho_enable_clipping; + + RhoBenchmarkResult rho_result; + + // Store previous values for residuals + tinyMatrix v_prev = solver->work->vnew; + tinyMatrix z_prev = solver->work->znew; + + // Initialize SOC slack variables if needed + if (solver->settings->en_state_soc && solver->work->numStateCones > 0) { + solver->work->vcnew = solver->work->x; + } + + if (solver->settings->en_input_soc && solver->work->numInputCones > 0) { + solver->work->zcnew = solver->work->u; + } + + // Initialize linear constraint slack variables if needed + if (solver->settings->en_state_linear) { + solver->work->vlnew = solver->work->x; + } + + if (solver->settings->en_input_linear) { + solver->work->zlnew = solver->work->u; + } + + for (int i = 0; i < solver->settings->max_iter; i++) { + // Solve linear system with Riccati and roll out to get new trajectory + backward_pass_grad(solver); + + forward_pass(solver); + + // Project slack variables into feasible domain + update_slack(solver); + + // Compute next iteration of dual variables + update_dual(solver); + + // Update linear control cost terms using reference trajectory, duals, and slack variables + update_linear_cost(solver); + + solver->work->iter += 1; + + // Handle adaptive rho if enabled + if (solver->settings->adaptive_rho) { + // Calculate residuals for adaptive rho + tinytype pri_res_input = (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff(); + tinytype pri_res_state = (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff(); + tinytype dua_res_input = + solver->cache->rho * (solver->work->znew - z_prev).cwiseAbs().maxCoeff(); + tinytype dua_res_state = + solver->cache->rho * (solver->work->vnew - v_prev).cwiseAbs().maxCoeff(); + + // Update rho every 5 iterations + if (i > 0 && i % 5 == 0) { + benchmark_rho_adaptation( + &adapter, + solver->work->x, + solver->work->u, + solver->work->vnew, + solver->work->znew, + solver->work->g, + solver->work->y, + solver->cache, + solver->work, + solver->work->N, + &rho_result + ); + + // Update matrices using Taylor expansion + update_matrices_with_derivatives(solver->cache, rho_result.final_rho); + } + } + + // Store previous values for next iteration + z_prev = solver->work->znew; + v_prev = solver->work->vnew; + + // Check for whether cost is minimized by calculating residuals + if (termination_condition(solver)) { + solver->work->status = 1; // TINY_SOLVED + + // Save solution + solver->solution->iter = solver->work->iter; + solver->solution->solved = 1; + solver->solution->x = solver->work->vnew; + solver->solution->u = solver->work->znew; + + // std::cout << "Solver converged in " << solver->work->iter << " iterations" << std::endl; + + return 0; + } + + // Save previous slack variables + solver->work->v = solver->work->vnew; + solver->work->z = solver->work->znew; + } + + solver->solution->iter = solver->work->iter; + solver->solution->solved = 0; + solver->solution->x = solver->work->vnew; + solver->solution->u = solver->work->znew; + return 1; +} + +} /* extern "C" */ diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.hpp new file mode 100644 index 0000000..ca2d1d1 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "types.hpp" + +#ifdef __cplusplus +extern "C" { +#endif + +int solve(TinySolver* solver); + +void update_primal(TinySolver* solver); +void backward_pass_grad(TinySolver* solver); +void forward_pass(TinySolver* solver); +void update_slack(TinySolver* solver); +void update_dual(TinySolver* solver); +void update_linear_cost(TinySolver* solver); +bool termination_condition(TinySolver* solver); + +/** + * Project a vector s onto the second order cone defined by mu + * @param s, mu + * @return projection onto cone if s is outside cone. Return s if s is inside cone. +*/ +tinyVector project_soc(tinyVector s, float mu); + +/** + * Project a vector z onto a hyperplane defined by a^T z = b + * Implements equation (21): ΠH(z) = z - (⟨z, a⟩ − b)/||a||² * a + * @param z Vector to project + * @param a Normal vector of the hyperplane + * @param b Offset of the hyperplane + * @return Projection of z onto the hyperplane + */ +tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.cpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.cpp new file mode 100644 index 0000000..18c3e6a --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.cpp @@ -0,0 +1,466 @@ +#include +#include +#include +#include +#include +#include +#include +//#include +#include "error.hpp" + +#include +#include + +// #include "types.hpp" +#include "codegen.hpp" + +#ifdef __MINGW32__ + #include +inline int mkdir(const char* pathname, int flags) { + return _mkdir(pathname); +} +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/* Define the maximum allowed length of the path (directory + filename + extension) */ +#define PATH_LENGTH 2048 + +using namespace Eigen; + +static void print_matrix(FILE* f, MatrixXd mat, int num_elements) { + // Check if matrix is uninitialized or too small + if (mat.size() == 0 || mat.size() < num_elements) { + // Print zeros for all elements + for (int i = 0; i < num_elements; i++) { + fprintf(f, "(tinytype)0.0000000000000000"); + if (i < num_elements - 1) + fprintf(f, ","); + } + return; + } + + // Matrix is properly initialized and has enough elements + for (int i = 0; i < num_elements; i++) { + fprintf(f, "(tinytype)%.16f", mat.reshaped()[i]); + if (i < num_elements - 1) + fprintf(f, ","); + } +} + +static void create_directory(const char* dir, int verbose) { + // Attempt to create directory + if (mkdir(dir, S_IRWXU | S_IRWXG | S_IROTH)) { + if (errno == EEXIST) { // Skip if directory already exists + if (verbose) + std::cout << dir << " already exists, skipping." << std::endl; + } else { + ERROR_MSG(EXIT_FAILURE, "Failed to create directory %s", dir); + } + } +} + +// TODO: Make this fail if tiny_setup has not already been called +int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose) { + if (!solver) { + std::cout << "Error in tiny_codegen: solver is nullptr" << std::endl; + return 1; + } + int status = 0; + status |= codegen_create_directories(output_dir, verbose); + status |= codegen_data_header(output_dir, verbose); + status |= codegen_data_source(solver, output_dir, verbose); + status |= codegen_example(output_dir, verbose); + + return status; +} + +int tiny_codegen_with_sensitivity( + TinySolver* solver, + const char* output_dir, + tinyMatrix* dK, + tinyMatrix* dP, + tinyMatrix* dC1, + tinyMatrix* dC2, + int verbose +) { + if (!solver) { + std::cout << "Error in tiny_codegen_with_sensitivity: solver is nullptr" << std::endl; + return 1; + } + + // Only store sensitivity matrices if adaptive rho is enabled + if (solver->settings->adaptive_rho) { + // Store the sensitivity matrices in the solver's cache + solver->cache->dKinf_drho = *dK; + solver->cache->dPinf_drho = *dP; + solver->cache->dC1_drho = *dC1; + solver->cache->dC2_drho = *dC2; + } + + // Call the regular codegen function which will now include the sensitivity matrices if adaptive_rho is enabled + return tiny_codegen(solver, output_dir, verbose); +} + +// Create code generation folder structure in whichever directory the executable calling tiny_codegen was called +int codegen_create_directories(const char* output_dir, int verbose) { + // Create output folder (root folder for code generation) + create_directory(output_dir, verbose); + + // Create src folder + char src_dir[PATH_LENGTH]; + sprintf(src_dir, "%s/src/", output_dir); + create_directory(src_dir, verbose); + + // Create tinympc folder + char tinympc_dir[PATH_LENGTH]; + sprintf(tinympc_dir, "%s/tinympc/", output_dir); + create_directory(tinympc_dir, verbose); + + // // Create include folder + // char inc_dir[PATH_LENGTH]; + // sprintf(inc_dir, "%s/include/", output_dir); + // create_directory(inc_dir, verbose); + + return EXIT_SUCCESS; +} + +// Create inc/tiny_data.hpp file +int codegen_data_header(const char* output_dir, int verbose) { + char data_hpp_fname[PATH_LENGTH]; + FILE* data_hpp_f; + + sprintf(data_hpp_fname, "%s/tinympc/tiny_data.hpp", output_dir); + + // Open data header file + data_hpp_f = fopen(data_hpp_fname, "w+"); + if (data_hpp_f == NULL) + ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_hpp_fname); + + // Preamble + time_t start_time; + time(&start_time); + fprintf(data_hpp_f, "/*\n"); + fprintf(data_hpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time)); + fprintf(data_hpp_f, " */\n\n"); + + fprintf(data_hpp_f, "#pragma once\n\n"); + + fprintf(data_hpp_f, "#include \"types.hpp\"\n\n"); + + fprintf(data_hpp_f, "#ifdef __cplusplus\n"); + fprintf(data_hpp_f, "extern \"C\" {\n"); + fprintf(data_hpp_f, "#endif\n\n"); + + fprintf(data_hpp_f, "extern TinySolver tiny_solver;\n\n"); + + fprintf(data_hpp_f, "#ifdef __cplusplus\n"); + fprintf(data_hpp_f, "}\n"); + fprintf(data_hpp_f, "#endif\n"); + + // Close codegen data header file + fclose(data_hpp_f); + + if (verbose) { + printf("Data header generated in %s\n", data_hpp_fname); + } + return 0; +} + +// Create src/tiny_data.cpp file +int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose) { + char data_cpp_fname[PATH_LENGTH]; + FILE* data_cpp_f; + + int nx = solver->work->nx; + int nu = solver->work->nu; + int N = solver->work->N; + + sprintf(data_cpp_fname, "%s/src/tiny_data.cpp", output_dir); + + // Open data source file + data_cpp_f = fopen(data_cpp_fname, "w+"); + if (data_cpp_f == NULL) + ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_cpp_fname); + + // Preamble + time_t start_time; + time(&start_time); + fprintf(data_cpp_f, "/*\n"); + fprintf(data_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time)); + fprintf(data_cpp_f, " */\n\n"); + + // Open extern C + fprintf(data_cpp_f, "#include \"tinympc/tiny_data.hpp\"\n\n"); + fprintf(data_cpp_f, "#ifdef __cplusplus\n"); + fprintf(data_cpp_f, "extern \"C\" {\n"); + fprintf(data_cpp_f, "#endif\n\n"); + + // Solution + fprintf(data_cpp_f, "/* Solution */\n"); + fprintf(data_cpp_f, "TinySolution solution = {\n"); + + fprintf(data_cpp_f, "\t%d,\t\t// iter\n", solver->solution->iter); + fprintf(data_cpp_f, "\t%d,\t\t// solved\n", solver->solution->solved); + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// x\n"); // x solution + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// x\n"); // u solution + + fprintf(data_cpp_f, "};\n\n"); + + // Cache + fprintf(data_cpp_f, "/* Matrices that must be recomputed with changes in time step, rho */\n"); + fprintf(data_cpp_f, "TinyCache cache = {\n"); + + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// rho (step size/penalty)\n", solver->cache->rho); + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx); + print_matrix(data_cpp_f, solver->cache->Kinf, nu * nx); + fprintf(data_cpp_f, ").finished(),\t// Kinf\n"); // Kinf + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->Pinf, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// Pinf\n"); // Pinf + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nu); + print_matrix(data_cpp_f, solver->cache->Quu_inv, nu * nu); + fprintf(data_cpp_f, ").finished(),\t// Quu_inv\n"); // Quu_inv + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->AmBKt, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// AmBKt\n"); // AmBKt + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->C1, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// C1\n"); // C1 + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->C2, nx * nx); + fprintf(data_cpp_f, ").finished()"); // C2, no comma if no sensitivity matrices + + // Only print sensitivity matrices if adaptive rho is enabled + if (solver->settings->adaptive_rho) { + fprintf(data_cpp_f, ",\t// C2\n"); // Add comma and comment for C2 if we have more matrices + + // Add sensitivity matrices within the struct initialization + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx); + print_matrix(data_cpp_f, solver->cache->dKinf_drho, nu * nx); + fprintf(data_cpp_f, ").finished(),\t// dKinf_drho\n"); // dKinf_drho + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->dPinf_drho, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// dPinf_drho\n"); // dPinf_drho + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->dC1_drho, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// dC1_drho\n"); // dC1_drho + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->cache->dC2_drho, nx * nx); + fprintf(data_cpp_f, ").finished()\t// dC2_drho\n"); // dC2_drho + } else { + fprintf(data_cpp_f, "\t// C2\n"); // Just add comment for C2 + } + + fprintf(data_cpp_f, "};\n\n"); + + // Settings + fprintf(data_cpp_f, "/* User settings */\n"); + fprintf(data_cpp_f, "TinySettings settings = {\n"); + + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// primal tolerance\n", solver->settings->abs_pri_tol); + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// dual tolerance\n", solver->settings->abs_dua_tol); + fprintf(data_cpp_f, "\t%d,\t\t// max iterations\n", solver->settings->max_iter); + fprintf( + data_cpp_f, + "\t%d,\t\t// iterations per termination check\n", + solver->settings->check_termination + ); + fprintf(data_cpp_f, "\t%d,\t\t// enable state constraints\n", solver->settings->en_state_bound); + fprintf(data_cpp_f, "\t%d\t\t// enable input constraints\n", solver->settings->en_input_bound); + + fprintf(data_cpp_f, "};\n\n"); + + // Workspace + fprintf(data_cpp_f, "/* Problem variables */\n"); + fprintf(data_cpp_f, "TinyWorkspace work = {\n"); + + fprintf(data_cpp_f, "\t%d,\t// Number of states\n", nx); + fprintf(data_cpp_f, "\t%d,\t// Number of control inputs\n", nu); + fprintf(data_cpp_f, "\t%d,\t// Number of knotpoints in the horizon\n", N); + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// x\n"); // x + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// u\n"); // u + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// q\n"); // q + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// r\n"); // r + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// p\n"); // p + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// d\n"); // d + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// v\n"); // v + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// vnew\n"); // vnew + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// z\n"); // z + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// znew\n"); // znew + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// g\n"); // g + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// y\n"); // y + + fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nx); + print_matrix(data_cpp_f, solver->work->Q, nx); + fprintf(data_cpp_f, ").finished(),\t// Q\n"); // Q + fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu); + print_matrix(data_cpp_f, solver->work->R, nu); + fprintf(data_cpp_f, ").finished(),\t// R\n"); // R + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx); + print_matrix(data_cpp_f, solver->work->Adyn, nx * nx); + fprintf(data_cpp_f, ").finished(),\t// Adyn\n"); // Adyn + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nu); + print_matrix(data_cpp_f, solver->work->Bdyn, nx * nu); + fprintf(data_cpp_f, ").finished(),\t// Bdyn\n"); // Bdyn + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, solver->work->x_min, nx * N); + fprintf(data_cpp_f, ").finished(),\t// x_min\n"); // x_min + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, solver->work->x_max, nx * N); + fprintf(data_cpp_f, ").finished(),\t// x_max\n"); // x_max + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, solver->work->u_min, nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// u_min\n"); // u_min + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, solver->work->u_max, nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// u_max\n"); // u_max + + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N); + print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N); + fprintf(data_cpp_f, ").finished(),\t// Xref\n"); // Xref + fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1)); + fprintf(data_cpp_f, ").finished(),\t// Uref\n"); // Uref + + fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu); + print_matrix(data_cpp_f, MatrixXd::Zero(nu, 1), nu); + fprintf(data_cpp_f, ").finished(),\t// Qu\n"); // Qu + + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state primal residual\n", 0.0); + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input primal residual\n", 0.0); + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state dual residual\n", 0.0); + fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input dual residual\n", 0.0); + fprintf(data_cpp_f, "\t%d,\t// solve status\n", 0); + fprintf(data_cpp_f, "\t%d,\t// solve iteration\n", 0); + + fprintf(data_cpp_f, "};\n\n"); + + // Write solver struct definition to workspace file + fprintf(data_cpp_f, "TinySolver tiny_solver = {&solution, &settings, &cache, &work};\n\n"); + + // Close extern C + fprintf(data_cpp_f, "#ifdef __cplusplus\n"); + fprintf(data_cpp_f, "}\n"); + fprintf(data_cpp_f, "#endif\n\n"); + + // Close codegen data file + fclose(data_cpp_f); + if (verbose) { + printf("Data generated in %s\n", data_cpp_fname); + } + return 0; +} + +int codegen_example(const char* output_dir, int verbose) { + char example_cpp_fname[PATH_LENGTH]; + FILE* example_cpp_f; + + sprintf(example_cpp_fname, "%s/src/tiny_main.cpp", output_dir); + + // Open example file + example_cpp_f = fopen(example_cpp_fname, "w+"); + if (example_cpp_f == NULL) + ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", example_cpp_fname); + + // Preamble + time_t start_time; + time(&start_time); + fprintf(example_cpp_f, "/*\n"); + fprintf(example_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time)); + fprintf(example_cpp_f, " */\n\n"); + + fprintf(example_cpp_f, "#include \n\n"); + + fprintf(example_cpp_f, "#include \n"); + fprintf(example_cpp_f, "#include \n\n"); + + fprintf(example_cpp_f, "using namespace Eigen;\n"); + fprintf(example_cpp_f, "IOFormat TinyFmt(4, 0, \", \", \"\\n\", \"[\", \"]\");\n\n"); + + fprintf(example_cpp_f, "#ifdef __cplusplus\n"); + fprintf(example_cpp_f, "extern \"C\" {\n"); + fprintf(example_cpp_f, "#endif\n\n"); + + fprintf(example_cpp_f, "int main()\n"); + fprintf(example_cpp_f, "{\n"); + fprintf(example_cpp_f, "\tint exitflag = 1;\n"); + fprintf(example_cpp_f, "\t// Double check some data\n"); + fprintf(example_cpp_f, "\tstd::cout << \"rho: \" << tiny_solver.cache->rho << std::endl;\n"); + fprintf( + example_cpp_f, + "\tstd::cout << \"\\nmax iters: \" << tiny_solver.settings->max_iter << std::endl;\n" + ); + fprintf( + example_cpp_f, + "\tstd::cout << \"\\nState transition matrix:\\n\" << tiny_solver.work->Adyn.format(TinyFmt) << std::endl;\n" + ); + fprintf( + example_cpp_f, + "\tstd::cout << \"\\nInput/control matrix:\\n\" << tiny_solver.work->Bdyn.format(TinyFmt) << std::endl;\n\n" + ); + + fprintf( + example_cpp_f, + "\t// Visit https://tinympc.org/ to see how to set the initial condition and update the reference trajectory.\n\n" + ); + + fprintf(example_cpp_f, "\tstd::cout << \"\\nSolving...\\n\" << std::endl;\n\n"); + fprintf(example_cpp_f, "\texitflag = tiny_solve(&tiny_solver);\n\n"); + fprintf(example_cpp_f, "\tif (exitflag == 0) printf(\"Hooray! Solved with no error!\\n\");\n"); + fprintf(example_cpp_f, "\telse printf(\"Oops! Something went wrong!\\n\");\n"); + + fprintf(example_cpp_f, "\treturn 0;\n"); + fprintf(example_cpp_f, "}\n\n"); + + fprintf(example_cpp_f, "#ifdef __cplusplus\n"); + fprintf(example_cpp_f, "} /* extern \"C\" */\n"); + fprintf(example_cpp_f, "#endif\n"); + + // Close codegen example main file + fclose(example_cpp_f); + if (verbose) { + printf("Example tinympc main generated in %s\n", example_cpp_fname); + } + return 0; +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.hpp new file mode 100644 index 0000000..89318e7 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/codegen.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "types.hpp" + +#ifdef __cplusplus +extern "C" { +#endif + +int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose); + +int tiny_codegen_with_sensitivity( + TinySolver* solver, + const char* output_dir, + tinyMatrix* dK, + tinyMatrix* dP, + tinyMatrix* dC1, + tinyMatrix* dC2, + int verbose +); + +int codegen_create_directories(const char* output_dir, int verbose); +int codegen_data_header(const char* output_dir, int verbose); +int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose); +int codegen_example(const char* output_dir, int verbose); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/error.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/error.hpp new file mode 100644 index 0000000..67b4b71 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/error.hpp @@ -0,0 +1,29 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include + +// #if defined(__linux__) || defined(__unix__)// Check if Linux +// #include +// #define ERROR_MSG(exit_code, format, ...) error(exit_code, errno, format, ##__VA_ARGS__) + +// #elif defined(__APPLE__) || defined(__MACH__) // Check if macOS +#define ERROR_MSG(exit_code, format, ...) \ + { \ + fprintf(stderr, format ": %s\n", ##__VA_ARGS__, strerror(errno)); \ + exit(exit_code); \ + } + +// #else +// #error "Unsupported operating system" +// #endif + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.cpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.cpp new file mode 100644 index 0000000..eb8cb28 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.cpp @@ -0,0 +1,252 @@ +#include "rho_benchmark.hpp" +#include +#include +#include +#ifdef ARDUINO + #include +#else +// For non-Arduino platforms +uint32_t micros() { + return 0; // Replace with appropriate timing function +} +#endif + +void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N) { + // Calculate dimensions + int x_decision_size = nx * N + nu * (N - 1); + int constraint_rows = (nx + nu) * (N - 1); + + // Pre-allocate matrices + adapter->A_matrix = tinyMatrix::Zero(constraint_rows, x_decision_size); + adapter->z_vector = tinyMatrix::Zero(constraint_rows, 1); + adapter->y_vector = tinyMatrix::Zero(constraint_rows, 1); + adapter->x_decision = tinyMatrix::Zero(x_decision_size, 1); + + // Pre-compute P matrix structure + adapter->P_matrix = tinyMatrix::Zero(x_decision_size, x_decision_size); + adapter->q_vector = tinyMatrix::Zero(x_decision_size, 1); + + // Pre-allocate residual computation matrices + adapter->Ax_vector = tinyMatrix::Zero(constraint_rows, 1); + adapter->r_prim_vector = tinyMatrix::Zero(constraint_rows, 1); + adapter->r_dual_vector = tinyMatrix::Zero(x_decision_size, 1); + adapter->Px_vector = tinyMatrix::Zero(x_decision_size, 1); + adapter->ATy_vector = tinyMatrix::Zero(x_decision_size, 1); + + // Store dimensions + adapter->format_nx = nx; + adapter->format_nu = nu; + adapter->format_N = N; + + adapter->matrices_initialized = true; +} + +void format_matrices( + RhoAdapter* adapter, + const tinyMatrix& x_prev, + const tinyMatrix& u_prev, + const tinyMatrix& v_prev, + const tinyMatrix& z_prev, + const tinyMatrix& g_prev, + const tinyMatrix& y_prev, + TinyCache* cache, + TinyWorkspace* work, + int N +) { + if (!adapter->matrices_initialized) { + initialize_format_matrices(adapter, x_prev.rows(), u_prev.rows(), N); + } + + int nx = adapter->format_nx; + int nu = adapter->format_nu; + + // Fill x_decision + int x_idx = 0; + for (int i = 0; i < N; i++) { + adapter->x_decision.block(x_idx, 0, nx, 1) = x_prev.col(i); + x_idx += nx; + if (i < N - 1) { + adapter->x_decision.block(x_idx, 0, nu, 1) = u_prev.col(i); + x_idx += nu; + } + } + + // Clear A matrix for reuse + adapter->A_matrix.setZero(); + + // Fill A matrix with dynamics and input constraints + for (int i = 0; i < N - 1; i++) { + // Input constraints + int row_start = i * nu; + int col_start = i * (nx + nu) + nx; + adapter->A_matrix.block(row_start, col_start, nu, nu) = tinyMatrix::Identity(nu, nu); + + // Dynamics constraints + row_start = (N - 1) * nu + i * nx; + col_start = i * (nx + nu); + adapter->A_matrix.block(row_start, col_start, nx, nx) = work->Adyn; + adapter->A_matrix.block(row_start, col_start + nx, nx, nu) = work->Bdyn; + + int next_state_idx = col_start + nx + nu; + if (next_state_idx < adapter->A_matrix.cols()) { + adapter->A_matrix.block(row_start, next_state_idx, nx, nx) = + -tinyMatrix::Identity(nx, nx); + } + } + + // Fill z and y vectors + for (int i = 0; i < N - 1; i++) { + adapter->z_vector.block(i * nu, 0, nu, 1) = z_prev.col(i); + adapter->z_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = v_prev.col(i + 1); + + adapter->y_vector.block(i * nu, 0, nu, 1) = y_prev.col(i); + adapter->y_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = g_prev.col(i + 1); + } + + // Build P matrix (cost matrix) + adapter->P_matrix.setZero(); + + // Fill diagonal blocks + x_idx = 0; + for (int i = 0; i < N; i++) { + // State cost + if (i == N - 1) { + adapter->P_matrix.block(x_idx, x_idx, nx, nx) = cache->Pinf; + } else { + adapter->P_matrix.block(x_idx, x_idx, nx, nx) = work->Q.asDiagonal(); + } + x_idx += nx; + + // Input cost + if (i < N - 1) { + adapter->P_matrix.block(x_idx, x_idx, nu, nu) = work->R.asDiagonal(); + x_idx += nu; + } + } + + // Create q vector (linear cost vector) + x_idx = 0; + for (int i = 0; i < N; i++) { + // For simplicity, we'll use zero reference for now + // In a real implementation, you'd use your reference trajectory + tinyMatrix x_ref = tinyMatrix::Zero(nx, 1); + tinyMatrix delta_x = x_prev.col(i) - x_ref; + adapter->q_vector.block(x_idx, 0, nx, 1) = work->Q.asDiagonal() * delta_x; + x_idx += nx; + + if (i < N - 1) { + // For simplicity, we'll use zero reference for now + tinyMatrix u_ref = tinyMatrix::Zero(nu, 1); + tinyMatrix delta_u = u_prev.col(i) - u_ref; + adapter->q_vector.block(x_idx, 0, nu, 1) = work->R.asDiagonal() * delta_u; + x_idx += nu; + } + } +} + +void compute_residuals( + RhoAdapter* adapter, + tinytype* pri_res, + tinytype* dual_res, + tinytype* pri_norm, + tinytype* dual_norm +) { + // Compute Ax + adapter->Ax_vector = adapter->A_matrix * adapter->x_decision; + + // Compute primal residual + adapter->r_prim_vector = adapter->Ax_vector - adapter->z_vector; + *pri_res = adapter->r_prim_vector.cwiseAbs().maxCoeff(); + *pri_norm = + std::max(adapter->Ax_vector.cwiseAbs().maxCoeff(), adapter->z_vector.cwiseAbs().maxCoeff()); + + // Compute dual residual components + adapter->Px_vector = adapter->P_matrix * adapter->x_decision; + adapter->ATy_vector = adapter->A_matrix.transpose() * adapter->y_vector; + + // Compute full dual residual + adapter->r_dual_vector = adapter->Px_vector + adapter->q_vector + adapter->ATy_vector; + *dual_res = adapter->r_dual_vector.cwiseAbs().maxCoeff(); + + // Compute normalization + *dual_norm = std::max( + std::max( + adapter->Px_vector.cwiseAbs().maxCoeff(), + adapter->ATy_vector.cwiseAbs().maxCoeff() + ), + adapter->q_vector.cwiseAbs().maxCoeff() + ); +} + +tinytype predict_rho( + RhoAdapter* adapter, + tinytype pri_res, + tinytype dual_res, + tinytype pri_norm, + tinytype dual_norm, + tinytype current_rho +) { + const tinytype eps = 1e-10; + + tinytype normalized_pri = pri_res / (pri_norm + eps); + tinytype normalized_dual = dual_res / (dual_norm + eps); + + tinytype ratio = normalized_pri / (normalized_dual + eps); + + tinytype new_rho = current_rho * std::sqrt(ratio); + + if (adapter->clip) { + new_rho = std::min(std::max(new_rho, adapter->rho_min), adapter->rho_max); + } + + return new_rho; +} + +void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho) { + tinytype delta_rho = new_rho - cache->rho; + + cache->Kinf = cache->Kinf + delta_rho * cache->dKinf_drho; + cache->Pinf = cache->Pinf + delta_rho * cache->dPinf_drho; + cache->C1 = cache->C1 + delta_rho * cache->dC1_drho; + cache->C2 = cache->C2 + delta_rho * cache->dC2_drho; + + cache->rho = new_rho; +} + +void benchmark_rho_adaptation( + RhoAdapter* adapter, + const tinyMatrix& x_prev, + const tinyMatrix& u_prev, + const tinyMatrix& v_prev, + const tinyMatrix& z_prev, + const tinyMatrix& g_prev, + const tinyMatrix& y_prev, + TinyCache* cache, + TinyWorkspace* work, + int N, + RhoBenchmarkResult* result +) { + uint32_t start_time = micros(); + + // Format matrices + format_matrices(adapter, x_prev, u_prev, v_prev, z_prev, g_prev, y_prev, cache, work, N); + + // Compute residuals + tinytype pri_res, dual_res, pri_norm, dual_norm; + compute_residuals(adapter, &pri_res, &dual_res, &pri_norm, &dual_norm); + + // Predict new rho + tinytype new_rho = predict_rho(adapter, pri_res, dual_res, pri_norm, dual_norm, cache->rho); + + // Update matrices + update_matrices_with_derivatives(cache, new_rho); + + // Store results + result->time_us = micros() - start_time; + result->initial_rho = cache->rho; + result->final_rho = new_rho; + result->pri_res = pri_res; + result->dual_res = dual_res; + result->pri_norm = pri_norm; + result->dual_norm = dual_norm; +} \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.hpp new file mode 100644 index 0000000..e15a0cc --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/rho_benchmark.hpp @@ -0,0 +1,94 @@ +#pragma once +#include "types.hpp" +#include + +struct RhoAdapter { + tinytype rho_min; + tinytype rho_max; + bool clip; + bool matrices_initialized; + + // Pre-allocated matrices for formatting + tinyMatrix A_matrix; + tinyMatrix z_vector; + tinyMatrix y_vector; + tinyMatrix x_decision; + tinyMatrix P_matrix; + tinyMatrix q_vector; + + // Pre-allocated matrices for residual computation + tinyMatrix Ax_vector; + tinyMatrix r_prim_vector; + tinyMatrix r_dual_vector; + tinyMatrix Px_vector; + tinyMatrix ATy_vector; + + // Dimensions + int format_nx; + int format_nu; + int format_N; +}; + +struct RhoBenchmarkResult { + uint32_t time_us; + tinytype initial_rho; + tinytype final_rho; + tinytype pri_res; + tinytype dual_res; + tinytype pri_norm; + tinytype dual_norm; +}; + +// Initialize matrices for formatting +void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N); + +// Format matrices for residual computation +void format_matrices( + RhoAdapter* adapter, + const tinyMatrix& x_prev, + const tinyMatrix& u_prev, + const tinyMatrix& v_prev, + const tinyMatrix& z_prev, + const tinyMatrix& g_prev, + const tinyMatrix& y_prev, + TinyCache* cache, + TinyWorkspace* work, + int N +); + +// Compute residuals +void compute_residuals( + RhoAdapter* adapter, + tinytype* pri_res, + tinytype* dual_res, + tinytype* pri_norm, + tinytype* dual_norm +); + +// Predict new rho value +tinytype predict_rho( + RhoAdapter* adapter, + tinytype pri_res, + tinytype dual_res, + tinytype pri_norm, + tinytype dual_norm, + tinytype current_rho +); + +// Update matrices using derivatives +void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho); + +// Main benchmark function +void benchmark_rho_adaptation( + RhoAdapter* adapter, + const tinyMatrix& x_prev, + const tinyMatrix& u_prev, + const tinyMatrix& v_prev, + const tinyMatrix& z_prev, + const tinyMatrix& g_prev, + const tinyMatrix& y_prev, + TinyCache* cache, + TinyWorkspace* work, + int N, + RhoBenchmarkResult* result +); \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.cpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.cpp new file mode 100644 index 0000000..e9b7613 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.cpp @@ -0,0 +1,876 @@ +#include "tiny_api.hpp" +#include "tiny_api_constants.hpp" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +using namespace Eigen; +IOFormat TinyApiFmt(4, 0, ", ", "\n", "[", "]"); + +static int +check_dimension(std::string matrix_name, std::string rows_or_columns, int actual, int expected) { + if (actual != expected) { + std::cout << matrix_name << " has " << actual << " " << rows_or_columns << ". Expected " + << expected << "." << std::endl; + return 1; + } + return 0; +} + +int tiny_setup( + TinySolver** solverp, + tinyMatrix Adyn, + tinyMatrix Bdyn, + tinyMatrix fdyn, + tinyMatrix Q, + tinyMatrix R, + tinytype rho, + int nx, + int nu, + int N, + int verbose +) { + TinySolution* solution = new TinySolution(); + TinyCache* cache = new TinyCache(); + TinySettings* settings = new TinySettings(); + TinyWorkspace* work = new TinyWorkspace(); + TinySolver* solver = new TinySolver(); + + solver->solution = solution; + solver->cache = cache; + solver->settings = settings; + solver->work = work; + + *solverp = solver; + + // Initialize solution + solution->iter = 0; + solution->solved = 0; + solution->x = tinyMatrix::Zero(nx, N); + solution->u = tinyMatrix::Zero(nu, N - 1); + + // Initialize settings + tiny_set_default_settings(settings); + + // Initialize workspace + work->nx = nx; + work->nu = nu; + work->N = N; + + // Make sure arguments are the correct shapes + int status = 0; + status |= check_dimension("State transition matrix (A)", "rows", Adyn.rows(), nx); + status |= check_dimension("State transition matrix (A)", "columns", Adyn.cols(), nx); + status |= check_dimension("Input matrix (B)", "rows", Bdyn.rows(), nx); + status |= check_dimension("Input matrix (B)", "columns", Bdyn.cols(), nu); + status |= check_dimension("Affine vector (f)", "rows", fdyn.rows(), nx); + status |= check_dimension("Affine vector (f)", "columns", fdyn.cols(), 1); + status |= check_dimension("State stage cost (Q)", "rows", Q.rows(), nx); + status |= check_dimension("State stage cost (Q)", "columns", Q.cols(), nx); + status |= check_dimension("State input cost (R)", "rows", R.rows(), nu); + status |= check_dimension("State input cost (R)", "columns", R.cols(), nu); + if (status) { + return status; + } + + work->x = tinyMatrix::Zero(nx, N); + work->u = tinyMatrix::Zero(nu, N - 1); + + work->q = tinyMatrix::Zero(nx, N); + work->r = tinyMatrix::Zero(nu, N - 1); + + work->p = tinyMatrix::Zero(nx, N); + work->d = tinyMatrix::Zero(nu, N - 1); + + // Bound constraint slack variables + work->v = tinyMatrix::Zero(nx, N); + work->vnew = tinyMatrix::Zero(nx, N); + work->z = tinyMatrix::Zero(nu, N - 1); + work->znew = tinyMatrix::Zero(nu, N - 1); + + // Bound constraint dual variables + work->g = tinyMatrix::Zero(nx, N); + work->y = tinyMatrix::Zero(nu, N - 1); + + // Cone constraint slack variables + work->vc = tinyMatrix::Zero(nx, N); + work->vcnew = tinyMatrix::Zero(nx, N); + work->zc = tinyMatrix::Zero(nu, N - 1); + work->zcnew = tinyMatrix::Zero(nu, N - 1); + + // Cone constraint dual variables + work->gc = tinyMatrix::Zero(nx, N); + work->yc = tinyMatrix::Zero(nu, N - 1); + + // Linear constraint slack variables + work->vl = tinyMatrix::Zero(nx, N); + work->vlnew = tinyMatrix::Zero(nx, N); + work->zl = tinyMatrix::Zero(nu, N - 1); + work->zlnew = tinyMatrix::Zero(nu, N - 1); + + // Linear constraint dual variables + work->gl = tinyMatrix::Zero(nx, N); + work->yl = tinyMatrix::Zero(nu, N - 1); + + work->Q = (Q + rho * tinyMatrix::Identity(nx, nx)).diagonal(); + work->R = (R + rho * tinyMatrix::Identity(nu, nu)).diagonal(); + work->Adyn = Adyn; // State transition matrix + work->Bdyn = Bdyn; // Input matrix + work->fdyn = fdyn; // Affine offset vector + + work->Xref = tinyMatrix::Zero(nx, N); + work->Uref = tinyMatrix::Zero(nu, N - 1); + + work->Qu = tinyVector::Zero(nu); + + work->primal_residual_state = 0; + work->primal_residual_input = 0; + work->dual_residual_state = 0; + work->dual_residual_input = 0; + work->status = 0; + work->iter = 0; + + // Initialize cache + status = tiny_precompute_and_set_cache( + cache, + Adyn, + Bdyn, + fdyn, + work->Q.asDiagonal(), + work->R.asDiagonal(), + nx, + nu, + rho, + verbose + ); + if (status) { + return status; + } + + // Initialize sensitivity matrices for adaptive rho + if (solver->settings->adaptive_rho) { + tiny_initialize_sensitivity_matrices(solver); + } + + return 0; +} + +int tiny_set_bound_constraints( + TinySolver* solver, + tinyMatrix x_min, + tinyMatrix x_max, + tinyMatrix u_min, + tinyMatrix u_max +) { + if (!solver) { + std::cout << "Error in tiny_set_bound_constraints: solver is nullptr" << std::endl; + return 1; + } + + // Make sure all bound constraint matrix sizes are self-consistent + int status = 0; + status |= check_dimension("Lower state bounds (x_min)", "rows", x_min.rows(), solver->work->nx); + status |= check_dimension("Lower state bounds (x_min)", "cols", x_min.cols(), solver->work->N); + status |= check_dimension("Lower state bounds (x_max)", "rows", x_max.rows(), solver->work->nx); + status |= check_dimension("Lower state bounds (x_max)", "cols", x_max.cols(), solver->work->N); + status |= check_dimension("Lower input bounds (u_min)", "rows", u_min.rows(), solver->work->nu); + status |= + check_dimension("Lower input bounds (u_min)", "cols", u_min.cols(), solver->work->N - 1); + status |= check_dimension("Lower input bounds (u_max)", "rows", u_max.rows(), solver->work->nu); + status |= + check_dimension("Lower input bounds (u_max)", "cols", u_max.cols(), solver->work->N - 1); + + solver->work->x_min = x_min; + solver->work->x_max = x_max; + solver->work->u_min = u_min; + solver->work->u_max = u_max; + + return 0; +} + +int tiny_set_cone_constraints( + TinySolver* solver, + VectorXi Acx, + VectorXi qcx, + tinyVector cx, + VectorXi Acu, + VectorXi qcu, + tinyVector cu +) { + if (!solver) { + std::cout << "Error in tiny_set_cone_constraints: solver is nullptr" << std::endl; + return 1; + } + + // Make sure all cone constraint vector sizes are self-consistent + int num_state_cones = Acx.rows(); + int num_input_cones = Acu.rows(); + int status = 0; + status |= check_dimension("Cone state size (qcx)", "rows", qcx.rows(), num_state_cones); + status |= check_dimension("Cone mu value for state (cx)", "rows", cx.rows(), num_state_cones); + status |= check_dimension("Cone input size (qcu)", "rows", qcu.rows(), num_input_cones); + status |= check_dimension("Cone mu value for input (cu)", "rows", cu.rows(), num_input_cones); + if (status) { + return status; + } + + solver->work->numStateCones = num_state_cones; + solver->work->numInputCones = num_input_cones; + + solver->work->Acx = Acx; + solver->work->qcx = qcx; + solver->work->cx = cx; + + solver->work->Acu = Acu; + solver->work->qcu = qcu; + solver->work->cu = cu; + + return 0; +} + +int tiny_set_linear_constraints( + TinySolver* solver, + tinyMatrix Alin_x, + tinyVector blin_x, + tinyMatrix Alin_u, + tinyVector blin_u +) { + if (!solver) { + std::cout << "Error in tiny_set_linear_constraints: solver is nullptr" << std::endl; + return 1; + } + + // Make sure all linear constraint matrix sizes are self-consistent + int num_state_linear = Alin_x.rows(); + int num_input_linear = Alin_u.rows(); + int status = 0; + + // Check state constraint dimensions + if (num_state_linear > 0) { + status |= check_dimension( + "State linear constraint matrix (Alin_x)", + "rows", + Alin_x.rows(), + num_state_linear + ); + status |= check_dimension( + "State linear constraint matrix (Alin_x)", + "columns", + Alin_x.cols(), + solver->work->nx + ); + status |= check_dimension( + "State linear constraint vector (blin_x)", + "rows", + blin_x.rows(), + num_state_linear + ); + status |= + check_dimension("State linear constraint vector (blin_x)", "columns", blin_x.cols(), 1); + } + + // Check input constraint dimensions + if (num_input_linear > 0) { + status |= check_dimension( + "Input linear constraint matrix (Alin_u)", + "rows", + Alin_u.rows(), + num_input_linear + ); + status |= check_dimension( + "Input linear constraint matrix (Alin_u)", + "columns", + Alin_u.cols(), + solver->work->nu + ); + status |= check_dimension( + "Input linear constraint vector (blin_u)", + "rows", + blin_u.rows(), + num_input_linear + ); + status |= + check_dimension("Input linear constraint vector (blin_u)", "columns", blin_u.cols(), 1); + } + + if (status) { + return status; + } + + solver->work->numStateLinear = num_state_linear; + solver->work->numInputLinear = num_input_linear; + + solver->work->Alin_x = Alin_x; + solver->work->blin_x = blin_x; + solver->work->Alin_u = Alin_u; + solver->work->blin_u = blin_u; + + return 0; +} + +int tiny_precompute_and_set_cache( + TinyCache* cache, + tinyMatrix Adyn, + tinyMatrix Bdyn, + tinyMatrix fdyn, + tinyMatrix Q, + tinyMatrix R, + int nx, + int nu, + tinytype rho, + int verbose +) { + if (!cache) { + std::cout << "Error in tiny_precompute_and_set_cache: cache is nullptr" << std::endl; + return 1; + } + + // Update by adding rho * identity matrix to Q, R + tinyMatrix Q1 = Q + rho * tinyMatrix::Identity(nx, nx); + tinyMatrix R1 = R + rho * tinyMatrix::Identity(nu, nu); + + // Printing + if (verbose) { + std::cout << "A = " << Adyn.format(TinyApiFmt) << std::endl; + std::cout << "B = " << Bdyn.format(TinyApiFmt) << std::endl; + std::cout << "Q = " << Q1.format(TinyApiFmt) << std::endl; + std::cout << "R = " << R1.format(TinyApiFmt) << std::endl; + std::cout << "rho = " << rho << std::endl; + } + + // Riccati recursion to get Kinf, Pinf + tinyMatrix Ktp1 = tinyMatrix::Zero(nu, nx); + tinyMatrix Ptp1 = rho * tinyMatrix::Ones(nx, 1).array().matrix().asDiagonal(); + tinyMatrix Kinf = tinyMatrix::Zero(nu, nx); + tinyMatrix Pinf = tinyMatrix::Zero(nx, nx); + + for (int i = 0; i < 1000; i++) { + Kinf = (R1 + Bdyn.transpose() * Ptp1 * Bdyn).inverse() * Bdyn.transpose() * Ptp1 * Adyn; + Pinf = Q1 + Adyn.transpose() * Ptp1 * (Adyn - Bdyn * Kinf); + // if Kinf converges, break + if ((Kinf - Ktp1).cwiseAbs().maxCoeff() < 1e-5) { + if (verbose) { + std::cout << "Kinf converged after " << i + 1 << " iterations" << std::endl; + } + break; + } + Ktp1 = Kinf; + Ptp1 = Pinf; + } + + // Compute cached matrices + tinyMatrix Quu_inv = (R1 + Bdyn.transpose() * Pinf * Bdyn).inverse(); + tinyMatrix AmBKt = (Adyn - Bdyn * Kinf).transpose(); + + // Precomputation for affine term + tinyVector APf = AmBKt * Pinf * fdyn; + tinyVector BPf = Bdyn.transpose() * Pinf * fdyn; + + if (verbose) { + std::cout << "Kinf = " << Kinf.format(TinyApiFmt) << std::endl; + std::cout << "Pinf = " << Pinf.format(TinyApiFmt) << std::endl; + std::cout << "Quu_inv = " << Quu_inv.format(TinyApiFmt) << std::endl; + std::cout << "AmBKt = " << AmBKt.format(TinyApiFmt) << std::endl; + std::cout << "APf = " << APf.format(TinyApiFmt) << std::endl; + std::cout << "BPf = " << BPf.format(TinyApiFmt) << std::endl; + + std::cout << "\nPrecomputation finished!\n" << std::endl; + } + + cache->rho = rho; + cache->Kinf = Kinf; + cache->Pinf = Pinf; + cache->Quu_inv = Quu_inv; + cache->AmBKt = AmBKt; + cache->C1 = Quu_inv; + cache->C2 = AmBKt; + cache->APf = APf; + cache->BPf = BPf; + + return 0; // return success +} + +int tiny_solve(TinySolver* solver) { + return solve(solver); +} + +int tiny_update_settings( + TinySettings* settings, + tinytype abs_pri_tol, + tinytype abs_dua_tol, + int max_iter, + int check_termination, + int en_state_bound, + int en_input_bound, + int en_state_soc, + int en_input_soc, + int en_state_linear, + int en_input_linear +) { + if (!settings) { + std::cout << "Error in tiny_update_settings: settings is nullptr" << std::endl; + return 1; + } + settings->abs_pri_tol = abs_pri_tol; + settings->abs_dua_tol = abs_dua_tol; + settings->max_iter = max_iter; + settings->check_termination = check_termination; + settings->en_state_bound = en_state_bound; + settings->en_input_bound = en_input_bound; + settings->en_state_soc = en_state_soc; + settings->en_input_soc = en_input_soc; + settings->en_state_linear = en_state_linear; + settings->en_input_linear = en_input_linear; + return 0; +} + +int tiny_set_default_settings(TinySettings* settings) { + if (!settings) { + std::cout << "Error in tiny_set_default_settings: settings is nullptr" << std::endl; + return 1; + } + settings->abs_pri_tol = TINY_DEFAULT_ABS_PRI_TOL; + settings->abs_dua_tol = TINY_DEFAULT_ABS_DUA_TOL; + settings->max_iter = TINY_DEFAULT_MAX_ITER; + settings->check_termination = TINY_DEFAULT_CHECK_TERMINATION; + + // Turn off constraints until they are set by tiny_set_bound_constraints or tiny_set_cone_constraints + settings->en_state_bound = TINY_DEFAULT_EN_STATE_BOUND; + settings->en_input_bound = TINY_DEFAULT_EN_INPUT_BOUND; + settings->en_state_soc = TINY_DEFAULT_EN_STATE_SOC; + settings->en_input_soc = TINY_DEFAULT_EN_INPUT_SOC; + settings->en_state_linear = TINY_DEFAULT_EN_STATE_LINEAR; + settings->en_input_linear = TINY_DEFAULT_EN_INPUT_LINEAR; + + // Initialize adaptive rho settings + // NOTE : Adaptive rho currently supports only quadrotor system + settings->adaptive_rho = 0; // Disabled by default + settings->adaptive_rho_min = 1.0; + settings->adaptive_rho_max = 100.0; + settings->adaptive_rho_enable_clipping = 1; + + return 0; +} + +int tiny_set_x0(TinySolver* solver, tinyVector x0) { + if (!solver) { + std::cout << "Error in tiny_set_x0: solver is nullptr" << std::endl; + return 1; + } + if (x0.rows() != solver->work->nx) { + perror("Error in tiny_set_x0: x0 is not the correct length"); + } + solver->work->x.col(0) = x0; + return 0; +} + +int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref) { + if (!solver) { + std::cout << "Error in tiny_set_x_ref: solver is nullptr" << std::endl; + return 1; + } + int status = 0; + status |= check_dimension( + "State reference trajectory (x_ref)", + "rows", + x_ref.rows(), + solver->work->nx + ); + status |= check_dimension( + "State reference trajectory (x_ref)", + "columns", + x_ref.cols(), + solver->work->N + ); + solver->work->Xref = x_ref; + return 0; +} + +int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref) { + if (!solver) { + std::cout << "Error in tiny_set_u_ref: solver is nullptr" << std::endl; + return 1; + } + int status = 0; + status |= check_dimension( + "Control/input reference trajectory (u_ref)", + "rows", + u_ref.rows(), + solver->work->nu + ); + status |= check_dimension( + "Control/input reference trajectory (u_ref)", + "columns", + u_ref.cols(), + solver->work->N - 1 + ); + solver->work->Uref = u_ref; + return 0; +} + +void tiny_initialize_sensitivity_matrices(TinySolver* solver) { + int nu = solver->work->nu; + int nx = solver->work->nx; + // Initialize matrices with zeros + solver->cache->dKinf_drho = tinyMatrix::Zero(nu, nx); + solver->cache->dPinf_drho = tinyMatrix::Zero(nx, nx); + solver->cache->dC1_drho = tinyMatrix::Zero(nu, nu); + solver->cache->dC2_drho = tinyMatrix::Zero(nx, nx); + + const float dKinf_drho[4][12] = { { 0.0001, + -0.0001, + -0.0025, + 0.0003, + 0.0007, + 0.0050, + 0.0001, + -0.0001, + -0.0008, + 0.0000, + 0.0001, + 0.0008 }, + { -0.0001, + -0.0000, + -0.0025, + -0.0001, + -0.0006, + -0.0050, + -0.0001, + 0.0000, + -0.0008, + -0.0000, + -0.0001, + -0.0008 }, + { 0.0000, + 0.0000, + -0.0025, + 0.0001, + 0.0004, + 0.0050, + 0.0000, + 0.0000, + -0.0008, + 0.0000, + 0.0000, + 0.0008 }, + { -0.0000, + 0.0001, + -0.0025, + -0.0003, + -0.0004, + -0.0050, + -0.0000, + 0.0001, + -0.0008, + -0.0000, + -0.0000, + -0.0008 } }; + + const float dPinf_drho[12][12] = { { 0.0494, + -0.0045, + -0.0000, + 0.0110, + 0.1300, + -0.0283, + 0.0280, + -0.0026, + -0.0000, + 0.0004, + 0.0070, + -0.0094 }, + { -0.0045, + 0.0491, + 0.0000, + -0.1320, + -0.0111, + 0.0114, + -0.0026, + 0.0279, + 0.0000, + -0.0076, + -0.0004, + 0.0038 }, + { -0.0000, + 0.0000, + 2.4450, + 0.0000, + -0.0000, + -0.0000, + -0.0000, + 0.0000, + 1.2593, + 0.0000, + 0.0000, + 0.0000 }, + { 0.0110, + -0.1320, + 0.0000, + 0.3913, + 0.0592, + 0.3108, + 0.0080, + -0.0776, + 0.0000, + 0.0254, + 0.0068, + 0.0750 }, + { 0.1300, + -0.0111, + -0.0000, + 0.0592, + 0.4420, + 0.7771, + 0.0797, + -0.0081, + -0.0000, + 0.0068, + 0.0350, + 0.1875 }, + { -0.0283, + 0.0114, + -0.0000, + 0.3108, + 0.7771, + 10.0441, + 0.0272, + -0.0109, + 0.0000, + 0.0655, + 0.1639, + 2.6362 }, + { 0.0280, + -0.0026, + -0.0000, + 0.0080, + 0.0797, + 0.0272, + 0.0163, + -0.0016, + -0.0000, + 0.0005, + 0.0047, + 0.0032 }, + { -0.0026, + 0.0279, + 0.0000, + -0.0776, + -0.0081, + -0.0109, + -0.0016, + 0.0161, + 0.0000, + -0.0046, + -0.0005, + -0.0013 }, + { -0.0000, + 0.0000, + 1.2593, + 0.0000, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.9232, + 0.0000, + 0.0000, + 0.0000 }, + { 0.0004, + -0.0076, + 0.0000, + 0.0254, + 0.0068, + 0.0655, + 0.0005, + -0.0046, + 0.0000, + 0.0022, + 0.0017, + 0.0244 }, + { 0.0070, + -0.0004, + 0.0000, + 0.0068, + 0.0350, + 0.1639, + 0.0047, + -0.0005, + 0.0000, + 0.0017, + 0.0054, + 0.0610 }, + { -0.0094, + 0.0038, + 0.0000, + 0.0750, + 0.1875, + 2.6362, + 0.0032, + -0.0013, + 0.0000, + 0.0244, + 0.0610, + 0.9869 } }; + + const float dC1_drho[4][4] = { { -0.0000, 0.0000, -0.0000, 0.0000 }, + { 0.0000, -0.0000, 0.0000, -0.0000 }, + { -0.0000, 0.0000, -0.0000, 0.0000 }, + { 0.0000, -0.0000, 0.0000, -0.0000 } }; + + const float dC2_drho[12][12] = { { 0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + -0.0000 }, + { -0.0000, + 0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000 }, + { -0.0000, + 0.0000, + 0.0001, + 0.0000, + -0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + -0.0000, + -0.0000 }, + { 0.0000, + -0.0000, + -0.0000, + 0.0001, + 0.0000, + -0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0000, + -0.0000 }, + { 0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0001, + -0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0000, + -0.0000 }, + { -0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0001, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000 }, + { 0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + -0.0000 }, + { -0.0000, + 0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000, + -0.0000, + 0.0000, + 0.0000, + -0.0000, + -0.0000, + 0.0000 }, + { -0.0000, + 0.0000, + 0.0021, + 0.0000, + -0.0000, + -0.0000, + -0.0000, + 0.0000, + 0.0006, + 0.0000, + -0.0000, + -0.0000 }, + { 0.0002, + -0.0027, + -0.0000, + 0.0068, + 0.0005, + -0.0005, + 0.0001, + -0.0015, + -0.0000, + 0.0004, + 0.0000, + -0.0001 }, + { 0.0027, + -0.0002, + 0.0000, + 0.0005, + 0.0066, + -0.0011, + 0.0015, + -0.0001, + 0.0000, + 0.0000, + 0.0004, + -0.0002 }, + { -0.0001, + 0.0001, + 0.0000, + -0.0000, + 0.0000, + 0.0041, + -0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0006 } }; + + // Map arrays to Eigen matrices + solver->cache->dKinf_drho = Map>(dKinf_drho[0]).cast(); + solver->cache->dPinf_drho = Map>(dPinf_drho[0]).cast(); + solver->cache->dC1_drho = Map>(dC1_drho[0]).cast(); + solver->cache->dC2_drho = Map>(dC2_drho[0]).cast(); +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.hpp new file mode 100644 index 0000000..329fd90 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include "admm.hpp" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +int tiny_setup( + TinySolver** solverp, + tinyMatrix Adyn, + tinyMatrix Bdyn, + tinyMatrix fdyn, + tinyMatrix Q, + tinyMatrix R, + tinytype rho, + int nx, + int nu, + int N, + int verbose +); +int tiny_set_bound_constraints( + TinySolver* solver, + tinyMatrix x_min, + tinyMatrix x_max, + tinyMatrix u_min, + tinyMatrix u_max +); +int tiny_set_cone_constraints( + TinySolver* solver, + VectorXi Acu, + VectorXi qcu, + tinyVector cu, + VectorXi Acx, + VectorXi qcx, + tinyVector cx +); +int tiny_set_linear_constraints( + TinySolver* solver, + tinyMatrix Alin_x, + tinyVector blin_x, + tinyMatrix Alin_u, + tinyVector blin_u +); +int tiny_precompute_and_set_cache( + TinyCache* cache, + tinyMatrix Adyn, + tinyMatrix Bdyn, + tinyMatrix fdyn, + tinyMatrix Q, + tinyMatrix R, + int nx, + int nu, + tinytype rho, + int verbose +); + +void compute_sensitivity_matrices( + TinyCache* cache, + tinyMatrix Adyn, + tinyMatrix Bdyn, + tinyMatrix Q, + tinyMatrix R, + int nx, + int nu, + tinytype rho, + int verbose +); + +int tiny_update_matrices_with_derivatives(TinyCache* cache, tinytype delta_rho); +int tiny_solve(TinySolver* solver); + +int tiny_update_settings( + TinySettings* settings, + tinytype abs_pri_tol, + tinytype abs_dua_tol, + int max_iter, + int check_termination, + int en_state_bound, + int en_input_bound, + int en_state_soc, + int en_input_soc, + int en_state_linear, + int en_input_linear +); +int tiny_set_default_settings(TinySettings* settings); + +int tiny_set_x0(TinySolver* solver, tinyVector x0); +int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref); +int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref); + +/** + * Initialize sensitivity matrices for adaptive rho + * + * @param solver Pointer to solver + */ +void tiny_initialize_sensitivity_matrices(TinySolver* solver); + +int tiny_setup_state_soc_constraints( + TinySolver* solver, + tinyVector Acx, + tinyVector qcx, + tinyVector cx, + int numStateCones +); + +int tiny_setup_input_soc_constraints( + TinySolver* solver, + tinyVector Acu, + tinyVector qcu, + tinyVector cu, + int numInputCones +); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api_constants.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api_constants.hpp new file mode 100644 index 0000000..20deac1 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/tiny_api_constants.hpp @@ -0,0 +1,13 @@ +#pragma once + +// Default settings +#define TINY_DEFAULT_ABS_PRI_TOL (1e-03) +#define TINY_DEFAULT_ABS_DUA_TOL (1e-03) +#define TINY_DEFAULT_MAX_ITER (1000) +#define TINY_DEFAULT_CHECK_TERMINATION (1) +#define TINY_DEFAULT_EN_STATE_BOUND (1) +#define TINY_DEFAULT_EN_INPUT_BOUND (1) +#define TINY_DEFAULT_EN_STATE_SOC (0) +#define TINY_DEFAULT_EN_INPUT_SOC (0) +#define TINY_DEFAULT_EN_STATE_LINEAR (0) +#define TINY_DEFAULT_EN_INPUT_LINEAR (0) diff --git a/wust_vision-main/tasks/auto_aim/armor_control/tinympc/types.hpp b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/types.hpp new file mode 100644 index 0000000..aa5fddd --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/tinympc/types.hpp @@ -0,0 +1,197 @@ +#pragma once + +#include +// #include +// #include + +using namespace Eigen; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef double tinytype; // should be double if you want to generate code +typedef Matrix tinyMatrix; +typedef Matrix tinyVector; + +// typedef Matrix tiny_VectorNx; +// typedef Matrix tiny_VectorNu; +// typedef Matrix tiny_MatrixNxNx; +// typedef Matrix tiny_MatrixNxNu; +// typedef Matrix tiny_MatrixNuNx; +// typedef Matrix tiny_MatrixNuNu; + +// typedef Matrix tiny_MatrixNxNh; // Nu x Nh +// typedef Matrix tiny_MatrixNuNhm1; // Nu x Nh-1 + +/** + * Solution + */ +typedef struct { + int iter; + int solved; + tinyMatrix x; // nx x N + tinyMatrix u; // nu x N-1 +} TinySolution; + +/** +* Matrices that must be recomputed with changes in time step, rho +*/ +typedef struct { + tinytype rho; + tinyMatrix Kinf; // nu x nx + tinyMatrix Pinf; // nx x nx + tinyMatrix Quu_inv; // nu x nu + tinyMatrix AmBKt; // nx x nx + tinyVector APf; // nx x 1 + tinyVector BPf; // nu x 1 + tinyMatrix C1; // From adaptive rho + tinyMatrix C2; // From adaptive rho + + // Sensitivity matrices for adaptive rho + tinyMatrix dKinf_drho; + tinyMatrix dPinf_drho; + tinyMatrix dC1_drho; + tinyMatrix dC2_drho; +} TinyCache; +/** +* User settings +*/ +typedef struct { + tinytype abs_pri_tol; + tinytype abs_dua_tol; + int max_iter; + int check_termination; + int en_state_bound; + int en_input_bound; + int en_state_soc; + int en_input_soc; + int en_state_linear; + int en_input_linear; + + // Add adaptive rho parameters + int adaptive_rho; // Enable/disable adaptive rho (1/0) + tinytype adaptive_rho_min; // Minimum value for rho + tinytype adaptive_rho_max; // Maximum value for rho + int adaptive_rho_enable_clipping; // Enable/disable clipping of rho (1/0) +} TinySettings; + +/** + * Problem variables + */ +typedef struct { + int nx; // Number of states + int nu; // Number of control inputs + int N; // Number of knotpoints in the horizon + + // State and input + tinyMatrix x; // nx x N + tinyMatrix u; // nu x N-1 + + // Linear control cost terms + tinyMatrix q; // nx x N + tinyMatrix r; // nu x N-1 + + // Linear Riccati backward pass terms + tinyMatrix p; // nx x N + tinyMatrix d; // nu x N-1 + + // Bound constraint variables + // Slack variables + tinyMatrix v; // nx x N + tinyMatrix vnew; // nx x N + tinyMatrix z; // nu x N-1 + tinyMatrix znew; // nu x N-1 + + // Dual variables + tinyMatrix g; // nx x N + tinyMatrix y; // nu x N-1 + + // State and input bounds + tinyMatrix x_min; // nx x N + tinyMatrix x_max; // nx x N + tinyMatrix u_min; // nu x N-1 + tinyMatrix u_max; // nu x N-1 + + // Cone constraint variables + // Variables to keep track of general cone information + int numStateCones; // Number of cone constraints on states at each time step + int numInputCones; // Number of cone constraints on inputs at each time step + tinyVector cx; // One coefficient for each state cone + tinyVector cu; // One coefficient for each input cone + VectorXi Acx; // Start indices for each state cone + VectorXi Acu; // Start indices for each input cone + VectorXi qcx; // Dimension for each state cone + VectorXi qcu; // Dimension for each input cone + + // Slack variables + tinyMatrix vc; // nx x N + tinyMatrix vcnew; // nx x N + tinyMatrix zc; // nu x N-1 + tinyMatrix zcnew; // nu x N-1 + + // Dual variables + tinyMatrix gc; // nx x N + tinyMatrix yc; // nu x N-1 + + // Linear constraint variables + // Variables to keep track of general linear constraint information + int numStateLinear; // Number of linear constraints on states at each time step + int numInputLinear; // Number of linear constraints on inputs at each time step + + // Constraint matrices and vectors + tinyMatrix Alin_x; // Normal vectors for state linear constraints (numStateLinear x nx) + tinyVector blin_x; // Offset values for state linear constraints (numStateLinear x 1) + tinyMatrix Alin_u; // Normal vectors for input linear constraints (numInputLinear x nu) + tinyVector blin_u; // Offset values for input linear constraints (numInputLinear x 1) + + // Slack variables for linear constraints + tinyMatrix vl; // nx x N + tinyMatrix vlnew; // nx x N + tinyMatrix zl; // nu x N-1 + tinyMatrix zlnew; // nu x N-1 + + // Dual variables for linear constraints + tinyMatrix gl; // nx x N + tinyMatrix yl; // nu x N-1 + + // Q, R, A, B, f given by user + tinyVector Q; // nx x 1 + tinyVector R; // nu x 1 + tinyMatrix Adyn; // nx x nx (state transition matrix) + tinyMatrix Bdyn; // nx x nu (control matrix) + tinyVector fdyn; // nx x 1 (affine vector) + + // Reference trajectory to track for one horizon + tinyMatrix Xref; // nx x N + tinyMatrix Uref; // nu x N-1 + + // Temporaries + tinyVector Qu; // nu x 1 + + // Variables for keeping track of solve status + tinytype primal_residual_state; + tinytype primal_residual_input; + tinytype dual_residual_state; + tinytype dual_residual_input; + int status; + int iter; +} TinyWorkspace; + +/** + * Main TinyMPC solver structure that holds all information. + */ +typedef struct { + TinySolution* solution; // Solution + TinySettings* settings; // Problem settings + TinyCache* cache; // Problem cache + TinyWorkspace* work; // Solver workspace +} TinySolver; + +// Add at the top with other definitions +#define BENCH_NX 12 +#define BENCH_NU 4 + +#ifdef __cplusplus +} +#endif diff --git a/wust_vision-main/tasks/auto_aim/armor_control/traj.hpp b/wust_vision-main/tasks/auto_aim/armor_control/traj.hpp new file mode 100644 index 0000000..c0f2392 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/traj.hpp @@ -0,0 +1,111 @@ +#pragma once + +#include +#include +#include +#include +#include +namespace wust_vision { +namespace auto_aim { + + template + concept HasStaticLerp = requires(const T& a, const T& b, double t) { + { + T::lerp(a, b, t) + } -> std::same_as; + }; + + template + class Trajectory { + public: + void reserve(size_t n) { + cp_vec.reserve(n); + dt_vec.reserve(n > 0 ? n - 1 : 0); + prefix_time.reserve(n); + } + + void clear() { + cp_vec.clear(); + dt_vec.clear(); + prefix_time.clear(); + total_duration_ = 0.0; + } + + void push_back(const PointT& p, double dt = 0.0) { + if (cp_vec.empty()) { + cp_vec.push_back(p); + prefix_time.push_back(0.0); + total_duration_ = 0.0; + return; + } + + assert(dt >= 0.0); + + cp_vec.push_back(p); + dt_vec.push_back(dt); + + total_duration_ += dt; + prefix_time.push_back(total_duration_); + } + + void set(const std::vector& c, const std::vector& t) { + assert(!c.empty()); + assert(c.size() == t.size() + 1); + + cp_vec = c; + dt_vec = t; + + prefix_time.resize(cp_vec.size()); + prefix_time[0] = 0.0; + + for (size_t i = 0; i < dt_vec.size(); ++i) + prefix_time[i + 1] = prefix_time[i] + dt_vec[i]; + + total_duration_ = prefix_time.back(); + } + double getPrefixTimeAtIdx(int i) const { + return prefix_time[i]; + } + PointT getStateAtIdx(int i) const { + return cp_vec[i]; + } + + PointT getStateAtTime(double t) const { + if (cp_vec.empty()) + return PointT {}; + + if (t <= 0.0) + return cp_vec.front(); + + if (t >= total_duration_) + return cp_vec.back(); + + auto it = std::lower_bound(prefix_time.begin(), prefix_time.end(), t); + size_t i1 = std::distance(prefix_time.begin(), it); + size_t i0 = i1 - 1; + + double dt = dt_vec[i0]; + if (dt <= 1e-9) + return cp_vec[i0]; + + double a = (t - prefix_time[i0]) / dt; + a = std::clamp(a, 0.0, 1.0); + + return PointT::lerp(cp_vec[i0], cp_vec[i1], a); + } + + double getTotalDuration() const { + return total_duration_; + } + size_t size() const { + return cp_vec.size(); + } + + std::vector cp_vec; + std::vector dt_vec; + std::vector prefix_time; + double total_duration_ { 0.0 }; + }; + +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.cpp b/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.cpp new file mode 100644 index 0000000..51759a1 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.cpp @@ -0,0 +1,1355 @@ +#include "very_aimer.hpp" +#include "tasks/auto_aim/armor_control/tinympc/tiny_api.hpp" +#include "tasks/auto_aim/armor_control/tinympc/types.hpp" +#include "tasks/auto_aim/armor_tracker/target.hpp" +#include "tasks/auto_aim/auto_aim_fsm.hpp" +#include "tasks/auto_aim/type.hpp" +#include "tasks/type_common.hpp" +#include "traj.hpp" +#include "wust_vl/common/utils/manual_compensator.hpp" +#include "wust_vl/common/utils/parameter.hpp" +#include +namespace wust_vision::auto_aim { + +struct VeryAimer::Impl { + [[nodiscard]] static inline double lerpAngle(double a0, double a1, double t) noexcept { + double d = std::remainder(a1 - a0, 2.0 * M_PI); + return a0 + t * d; + } + [[nodiscard]] inline static double rad2deg(double r) noexcept { + static constexpr double rad_deg = 180.0 / M_PI; + return r * rad_deg; + } + struct ControlPoint { + double yaw; + double pitch; + int id_in_target; + Eigen::Vector4d xyza; + }; + struct AimPoint { + Eigen::Vector3d pos; + double d_angle; + static inline AimPoint lerp(const AimPoint& p0, const AimPoint& p1, double a) noexcept { + AimPoint r; + r.pos = (1.0 - a) * p0.pos + a * p1.pos; + r.d_angle = lerpAngle(p0.d_angle, p1.d_angle, a); + return r; + } + }; + struct GimbalState { + struct State { + double p; + double v; + double a; + }; + State yaw_state; + State pitch_state; + int aim_id = 0; + GimbalState() {} + GimbalState(const GimbalState::State& y, const GimbalState::State& p): + yaw_state(y), + pitch_state(p) {} + static GimbalState lerp(const GimbalState& s0, const GimbalState& s1, double a) noexcept { + GimbalState r; + r.aim_id = (a > 0.5) ? s0.aim_id : s1.aim_id; + r.yaw_state = GimbalState::State { .p = lerpAngle(s0.yaw_state.p, s1.yaw_state.p, a), + .v = std::lerp(s0.yaw_state.v, s1.yaw_state.v, a), + .a = std::lerp(s0.yaw_state.a, s1.yaw_state.a, a) }; + r.pitch_state = + GimbalState::State { .p = lerpAngle(s0.pitch_state.p, s1.pitch_state.p, a), + .v = std::lerp(s0.pitch_state.v, s1.pitch_state.v, a), + .a = std::lerp(s0.pitch_state.a, s1.pitch_state.a, a) }; + + return r; + } + }; + + struct QuinticSegment { + double T = 0.0; + Eigen::Matrix c; + GimbalState::State head; + GimbalState::State tail; + + static inline Eigen::Matrix solve1dFullPivLu( + double p0, + double v0, + double a0, + double p1, + double v1, + double a1, + double T + ) noexcept { + Eigen::Matrix A; + Eigen::Matrix b; + + double T2 = T * T, T3 = T2 * T, T4 = T3 * T, T5 = T4 * T; + + // Rows: p(0)=p0, p'(0)=v0, p''(0)=a0, + // p(T)=p1, p'(T)=v1, p''(T)=a1 + A << 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, T, T2, T3, T4, T5, 0, 1, + 2 * T, 3 * T2, 4 * T3, 5 * T4, 0, 0, 2, 6 * T, 12 * T2, 20 * T3; + + b << p0, v0, a0, p1, v1, a1; + return A.fullPivLu().solve(b); + } + static inline Eigen::Matrix solve1dClosedForm( + double p0, + double v0, + double a0, + double p1, + double v1, + double a1, + double T + ) noexcept { + Eigen::Matrix c; + double T2 = T * T; + double T3 = T2 * T; + double T4 = T3 * T; + double T5 = T4 * T; + + // known low-order coefficients + double c0 = p0; + double c1 = v0; + double c2 = a0 * 0.5; + + // closed-form for c3, c4, c5 (derived from boundary conditions at t=T) + double c3 = + (-3.0 * T2 * a0 + T2 * a1 - 12.0 * T * v0 - 8.0 * T * v1 - 20.0 * p0 + 20.0 * p1) + / (2.0 * T3); + double c4 = + (1.5 * T2 * a0 - T2 * a1 + 8.0 * T * v0 + 7.0 * T * v1 + 15.0 * p0 - 15.0 * p1) + / T4; + double c5 = (-T2 * a0 + T2 * a1 - 6.0 * T * v0 - 6.0 * T * v1 - 12.0 * p0 + 12.0 * p1) + / (2.0 * T5); + + c << c0, c1, c2, c3, c4, c5; + return c; + } + + [[nodiscard]] static inline QuinticSegment + build(const GimbalState::State& s0, const GimbalState::State& s1, double T) noexcept { + QuinticSegment seg; + seg.head = s0; + seg.tail = s1; + seg.T = T; + seg.c = solve1dClosedForm(s0.p, s0.v, s0.a, s1.p, s1.v, s1.a, T); + return seg; + } + + static inline double evalAcc(const Eigen::Matrix& c, double t) noexcept { + return 2 * c[2] + 6 * c[3] * t + 12 * c[4] * t * t + 20 * c[5] * t * t * t; + } + static inline double maxAbsAcc(const Eigen::Matrix& c, double T) noexcept { + if (T <= 0.0) + return 0.0; + + auto acc = [&](double t) { + double t2 = t * t; + return 2 * c[2] + 6 * c[3] * t + 12 * c[4] * t2 + 20 * c[5] * t2 * t; + }; + + double max_acc = std::max(std::abs(acc(0.0)), std::abs(acc(T))); + + // jerk = 6c3 + 24c4 t + 60c5 t^2 + double A = 60.0 * c[5]; + double B = 24.0 * c[4]; + double C = 6.0 * c[3]; + + const double eps = 1e-9; + + if (std::abs(A) < eps) { + if (std::abs(B) > eps) { + double t = -C / B; + if (t > 0.0 && t < T) + max_acc = std::max(max_acc, std::abs(acc(t))); + } + } else { + double D = B * B - 4 * A * C; + if (D >= 0.0) { + double sqrtD = std::sqrt(D); + double inv2A = 1.0 / (2 * A); + + double t1 = (-B + sqrtD) * inv2A; + double t2 = (-B - sqrtD) * inv2A; + + if (t1 > 0.0 && t1 < T) + max_acc = std::max(max_acc, std::abs(acc(t1))); + if (t2 > 0.0 && t2 < T) + max_acc = std::max(max_acc, std::abs(acc(t2))); + } + } + + return std::isfinite(max_acc) ? max_acc : 0.0; + } + + [[nodiscard]] double inline duration() const noexcept { + return T; + } + + [[nodiscard]] double inline MaxAcc(void) const noexcept { + return QuinticSegment::maxAbsAcc(c, T); + } + [[nodiscard]] GimbalState::State inline eval(double t) const noexcept { + GimbalState::State s; + if (T <= 0.0) + return s; + t = std::clamp(t, 0.0, T); + double t2 = t * t, t3 = t2 * t, t4 = t3 * t, t5 = t4 * t; + s.p = c[0] + c[1] * t + c[2] * t2 + c[3] * t3 + c[4] * t4 + c[5] * t5; + s.v = c[1] + 2 * c[2] * t + 3 * c[3] * t2 + 4 * c[4] * t3 + 5 * c[5] * t4; + s.a = evalAcc(c, t); + return s; + } + }; + + class LimitTrajectory: public Trajectory { + public: + struct Traj { + std::vector segs; + std::vector seg_dt; + std::vector seg_prefix_time; + }; + + Traj yaw_traj; + Traj pitch_traj; + static inline double angleDiff(double a, double b) noexcept { + double d = a - b; + while (d > M_PI) + d -= 2 * M_PI; + while (d < -M_PI) + d += 2 * M_PI; + return d; + } + + static inline double unwrapAngle(double prev, double curr) noexcept { + return prev + angleDiff(curr, prev); + } + + void unwrapStates(std::vector& s) const noexcept { + for (size_t i = 1; i < s.size(); ++i) { + s[i].yaw_state.p = unwrapAngle(s[i - 1].yaw_state.p, s[i].yaw_state.p); + s[i].pitch_state.p = unwrapAngle(s[i - 1].pitch_state.p, s[i].pitch_state.p); + } + } + + [[nodiscard]] std::pair, std::vector> + computeNodeStates(const std::vector& gp, const std::vector& dt) + const noexcept { + const size_t N = gp.size(); + std::vector yaw(N), pitch(N); + for (size_t i = 0; i < N; ++i) { + yaw[i] = gp[i].yaw_state; + pitch[i] = gp[i].pitch_state; + } + if (N < 2) + return { yaw, pitch }; + auto compute_va = [&](std::vector& s) { + // 边界 + s.front().v = s.back().v = 0.0; + s.front().a = s.back().a = 0.0; + + for (size_t i = 1; i + 1 < N; ++i) { + const double dt0 = dt[i - 1]; + const double dt1 = dt[i]; + const double denom = dt0 + dt1; + + if (denom < 1e-6) { + s[i].v = s[i].a = 0.0; + continue; + } + + const double w0 = dt1 / denom; + const double w1 = dt0 / denom; + + s[i].v = w0 * (s[i].p - s[i - 1].p) / dt0 + w1 * (s[i + 1].p - s[i].p) / dt1; + + s[i].a = + 2.0 * ((s[i + 1].p - s[i].p) / dt1 - (s[i].p - s[i - 1].p) / dt0) / denom; + } + }; + compute_va(yaw); + compute_va(pitch); + + return { yaw, pitch }; + } + + void limitTraj( + Traj& traj, + const std::vector& s, + int near_change_idx, + double max_acc + ) const noexcept { + traj.segs.clear(); + traj.seg_dt.clear(); + traj.seg_prefix_time.clear(); + + const int N = static_cast(s.size()); + if (N <= 1) + return; + + const auto& _prefix_time = prefix_time; + const auto& _dt_vec = dt_vec; + + auto buildSeg = [&](int l, int r) -> QuinticSegment { + double dur = _prefix_time[r] - _prefix_time[l]; + return QuinticSegment::build(s[l], s[r], dur); + }; + + std::optional> interval; + if (near_change_idx >= 0) { + int l = std::clamp(near_change_idx, 0, N - 1); + int r = std::clamp(near_change_idx + 1, 0, N - 1); + if (l < r) + interval.emplace(l, r); + } + + if (!interval) { + traj.segs.reserve(N - 1); + for (int i = 0; i < N - 1; ++i) + traj.segs.push_back(QuinticSegment::build(s[i], s[i + 1], _dt_vec[i])); + + traj.seg_dt.reserve(traj.segs.size()); + for (const auto& seg: traj.segs) + traj.seg_dt.push_back(seg.duration()); + + traj.seg_prefix_time.resize(traj.segs.size() + 1); + traj.seg_prefix_time[0] = 0.0; + for (size_t i = 0; i < traj.seg_dt.size(); ++i) + traj.seg_prefix_time[i + 1] = traj.seg_prefix_time[i] + traj.seg_dt[i]; + return; + } + + { + int& l = interval->first; + int& r = interval->second; + QuinticSegment seg = buildSeg(l, r); + + auto try_candidate = [&](int nl, int nr) -> bool { + nl = std::max(0, nl); + nr = std::min(N - 1, nr); + if (nl == l && nr == r) + return false; + + QuinticSegment cand = buildSeg(nl, nr); + if (cand.MaxAcc() <= seg.MaxAcc()) { + l = nl; + r = nr; + seg = std::move(cand); + return true; + } + return false; + }; + + while (seg.MaxAcc() > max_acc) { + bool expanded = false; + + if (l > 0 || r < N - 1) + expanded = try_candidate(l - 1, r + 1); + + if (!expanded && l > 0) + expanded = try_candidate(l - 1, r); + + if (!expanded && r < N - 1) + expanded = try_candidate(l, r + 1); + + if (!expanded && (l > 0 || r < N - 1)) { + int nl = std::max(0, l - 1); + int nr = std::min(N - 1, r + 1); + QuinticSegment forceSeg = buildSeg(nl, nr); + + if (forceSeg.MaxAcc() < seg.MaxAcc() || (nl == 0 && nr == N - 1)) { + l = nl; + r = nr; + seg = std::move(forceSeg); + expanded = true; + } + } + + if (!expanded) + break; + if (l == 0 && r == N - 1 && seg.MaxAcc() > max_acc) + break; + } + } + + traj.segs.reserve(N - 1); + for (int i = 0; i < N - 1; ++i) { + if (interval && i == interval->first) { + traj.segs.push_back(buildSeg(interval->first, interval->second)); + i = interval->second - 1; // skip covered indices + } else { + traj.segs.push_back(QuinticSegment::build(s[i], s[i + 1], _dt_vec[i])); + } + } + + traj.seg_dt.reserve(traj.segs.size()); + for (const auto& seg: traj.segs) + traj.seg_dt.push_back(seg.duration()); + + traj.seg_prefix_time.resize(traj.segs.size() + 1); + traj.seg_prefix_time[0] = 0.0; + for (size_t i = 0; i < traj.seg_dt.size(); ++i) + traj.seg_prefix_time[i + 1] = traj.seg_prefix_time[i] + traj.seg_dt[i]; + } + + void buildLimit(double max_yaw_acc, double max_pitch_acc) noexcept { + unwrapStates(cp_vec); + auto [yaw_states, pitch_states] = computeNodeStates(cp_vec, dt_vec); + int best_idx = -1; + double best_dist = 1e100; + const size_t offset = 2; + const int N = static_cast(cp_vec.size()); + if (N < 2) + return; + const double mid_time = 0.5 * total_duration_; + for (size_t i = offset; i + offset + 1 < cp_vec.size(); ++i) { + if (cp_vec[i].aim_id == cp_vec[i + 1].aim_id) + continue; + + const double seg_mid = 0.5 * (prefix_time[i] + prefix_time[i + 1]); + const double dist = std::abs(seg_mid - mid_time); + + if (dist < best_dist) { + best_dist = dist; + best_idx = static_cast(i); + } + } + limitTraj(yaw_traj, yaw_states, best_idx, max_yaw_acc); + limitTraj(pitch_traj, pitch_states, best_idx, max_pitch_acc); + } + void simpleTraj(Traj& traj, const std::vector& s) const noexcept { + traj.segs.clear(); + traj.seg_dt.clear(); + traj.seg_prefix_time.clear(); + + const int N = static_cast(s.size()); + if (N <= 1) + return; + for (int i = 0; i < N - 1; ++i) { + traj.segs.push_back(QuinticSegment::build(s[i], s[i + 1], dt_vec[i])); + } + traj.seg_dt.reserve(traj.segs.size()); + for (const auto& seg: traj.segs) + traj.seg_dt.push_back(seg.duration()); + + traj.seg_prefix_time.resize(traj.segs.size() + 1); + traj.seg_prefix_time[0] = 0.0; + for (size_t i = 0; i < traj.seg_dt.size(); ++i) + traj.seg_prefix_time[i + 1] = traj.seg_prefix_time[i] + traj.seg_dt[i]; + } + void buildSimple() { + unwrapStates(cp_vec); + auto [yaw_states, pitch_states] = computeNodeStates(cp_vec, dt_vec); + simpleTraj(yaw_traj, yaw_states); + simpleTraj(pitch_traj, pitch_states); + } + [[nodiscard]] inline GimbalState::State + getStateAtTime(double t, const Traj& traj) const noexcept { + if (traj.segs.empty()) + return {}; + + if (t <= 0.0) + return traj.segs.front().eval(0.0); + + if (t >= total_duration_) + return traj.segs.back().eval(traj.segs.back().T); + + const auto it = + std::upper_bound(traj.seg_prefix_time.begin(), traj.seg_prefix_time.end(), t); + + size_t i = std::distance(traj.seg_prefix_time.begin(), it) - 1; + i = std::min(i, traj.segs.size() - 1); + + const double t0 = traj.seg_prefix_time[i]; + return traj.segs[i].eval(t - t0); + } + [[nodiscard]] inline GimbalState getStateAtTime(double t) const noexcept { + GimbalState::State yaw = getStateAtTime(t, yaw_traj); + GimbalState::State pitch = getStateAtTime(t, pitch_traj); + return GimbalState(yaw, pitch); + } + }; + + struct VeryAimerConfig: wust_vl::common::utils::ParamGroup { + static constexpr const char* Logger = "Config: very_aimer"; + static constexpr const char* kKey = "very_aimer"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + VeryAimerConfig() { + sample_total_time_param.onChange([this](double o, double n) { + sample_dt = sample_total_time_param.get() / sample_horizon_param.get(); + sample_half_horizon = sample_horizon_param.get() / 2; + }); + sample_horizon_param.onChange([this](double o, double n) { + sample_dt = sample_total_time_param.get() / sample_horizon_param.get(); + sample_half_horizon = sample_horizon_param.get() / 2; + }); + } + static Ptr create() { + return std::make_shared(); + } + std::shared_ptr manual_compensator; + GEN_PARAM(double, sample_total_time); + GEN_PARAM(int, sample_horizon); + GEN_PARAM(double, control_delay); + GEN_PARAM(double, delay_enable_fire_error); + GEN_PARAM(double, max_yaw_acc); + GEN_PARAM(double, max_pitch_acc); + GEN_PARAM(double, prediction_delay); + GEN_PARAM(double, comming_angle); + GEN_PARAM(double, leaving_angle); + GEN_PARAM(double, yaw_limit_deg); + GEN_PARAM(double, shooting_range_h); + GEN_PARAM(double, shooting_range_w_small); + GEN_PARAM(double, shooting_range_w_big); + GEN_PARAM(double, min_enable_pitch_deg); + GEN_PARAM(double, min_enable_yaw_deg); + GEN_PARAM(bool, fuck_test); + GEN_PARAM(double, fuck_test_thresh); + double sample_dt = 0.01; + int sample_half_horizon = 100; + bool first_load = false; + struct Mpc { + int max_iter; + std::vector Q_pitch; + std::vector Q_yaw; + std::vector R_pitch; + std::vector R_yaw; + void load(const YAML::Node& node) { + max_iter = node["max_iter"].as(); + Q_pitch = node["Q_pitch"].as>(); + Q_yaw = node["Q_yaw"].as>(); + R_pitch = node["R_pitch"].as>(); + R_yaw = node["R_yaw"].as>(); + } + } mpc; + void loadSelf(const YAML::Node& node) override { + if (!first_load) { + manual_compensator = std::make_shared(); + std::vector entries; + + if (node["trajectory_offset"]) { + for (const auto& node: node["trajectory_offset"]) { + wust_vl::common::utils::OffsetEntry e; + e.d_min = node["d_min"].as(); + e.d_max = node["d_max"].as(); + e.h_min = node["h_min"].as(); + e.h_max = node["h_max"].as(); + e.pitch_off = node["pitch_off"].as(); + e.yaw_off = node["yaw_off"].as(); + entries.push_back(e); + } + manual_compensator->setBasePitch(node["base_offset"]["pitch"].as()); + manual_compensator->setBaseYaw(node["base_offset"]["yaw"].as()); + } + if (!manual_compensator->updateMapFlow(entries) || entries.size() < 1) { + std::cout << "Trajectory compensator init failed" << std::endl; + } + mpc.load(node); + first_load = true; + } else { + } + shooting_range_h_param.load(node); + shooting_range_w_small_param.load(node); + shooting_range_w_big_param.load(node); + yaw_limit_deg_param.load(node); + min_enable_pitch_deg_param.load(node); + min_enable_yaw_deg_param.load(node); + prediction_delay_param.load(node); + comming_angle_param.load(node); + leaving_angle_param.load(node); + control_delay_param.load(node); + delay_enable_fire_error_param.load(node); + max_yaw_acc_param.load(node); + max_pitch_acc_param.load(node); + sample_total_time_param.load(node); + sample_horizon_param.load(node); + fuck_test_param.load(node); + fuck_test_thresh_param.load(node); + } + }; + + class VeryAimerTrajBase { + public: + using Ptr = std::shared_ptr; + LimitTrajectory target_traj; + Trajectory aim_traj; + ControlPoint cp0; + AutoAimFsm fsm; + Eigen::Vector3d fin_aim_pos; + AimTarget aim_target; + virtual GimbalState getTargetState(double t) const = 0; + virtual GimbalState getControlState(double t) const = 0; + }; + class VeryAimerTrajMpc: public VeryAimerTrajBase { + public: + using Ptr = std::shared_ptr; + VeryAimerTrajMpc( + + ) {} + static Ptr create() { + return std::make_shared(); + } + Trajectory control_traj; + GimbalState getTargetState(double t) const override { + return target_traj.LimitTrajectory::getStateAtTime(t); + } + GimbalState getControlState(double t) const override { + return control_traj.getStateAtTime(t); + } + }; + class VeryAimerTrajSeg: public VeryAimerTrajBase { + public: + using Ptr = std::shared_ptr; + VeryAimerTrajSeg() {} + static Ptr create() { + return std::make_shared(); + } + GimbalState getTargetState(double t) const override { + return target_traj.Trajectory::getStateAtTime(t); + } + GimbalState getControlState(double t) const override { + return target_traj.LimitTrajectory::getStateAtTime(t); + } + }; + enum class Type : int { + Mpc = 0, + Seg = 1, + } type_; + Impl(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) { + auto_aim_config_parameter_ = auto_aim_config_parameter; + reset(); + } + + void reset() { + config_ = VeryAimerConfig::create(); + trajectory_compensator_config_ = TrajectoryCompensatorConfig::create(); + auto_aim_config_parameter_->registerGroup(*trajectory_compensator_config_); + auto_aim_config_parameter_->registerGroup(*config_); + auto_aim_config_parameter_->reloadFromOldPath(); + const auto yaml = auto_aim_config_parameter_->getConfig(); + const auto type = yaml["very_aimer"]["type"].as(); + if (type == "mpc" || type == "MPC") { + type_ = Type::Mpc; + } else if (type == "seg" || type == "SEG") { + type_ = Type::Seg; + + } else { + type_ = Type::Seg; + } + if (type_ == Type::Mpc) { + mpcReset(); + } + } + int max_iter_ = 10; + TinySolver* yaw_solver_; + TinySolver* pitch_solver_; + void mpcReset() { + const int horizon = config_->sample_horizon_param.get(); + const int half_horizon = config_->sample_half_horizon; + const double dt = config_->sample_dt; + Eigen::MatrixXd A_pitch { { 1, dt }, { 0, 1 } }; + Eigen::MatrixXd B_pitch { { 0 }, { dt } }; + Eigen::VectorXd f_pitch { { 0, 0 } }; + Eigen::Matrix Q_p(config_->mpc.Q_pitch.data()); + Eigen::Matrix R_p(config_->mpc.R_pitch.data()); + tiny_setup( + &pitch_solver_, + A_pitch, + B_pitch, + f_pitch, + Q_p.asDiagonal(), + R_p.asDiagonal(), + 1.0, + 2, + 1, + horizon, + 0 + ); + + Eigen::MatrixXd x_min_pitch = Eigen::MatrixXd::Constant(2, horizon, -1e17); + Eigen::MatrixXd x_max_pitch = Eigen::MatrixXd::Constant(2, horizon, 1e17); + Eigen::MatrixXd u_min_pitch = + Eigen::MatrixXd::Constant(1, horizon - 1, -config_->max_pitch_acc_param.get()); + Eigen::MatrixXd u_max_pitch = + Eigen::MatrixXd::Constant(1, horizon - 1, config_->max_pitch_acc_param.get()); + tiny_set_bound_constraints( + pitch_solver_, + x_min_pitch, + x_max_pitch, + u_min_pitch, + u_max_pitch + ); + pitch_solver_->settings->max_iter = max_iter_; + Eigen::MatrixXd A_yaw { { 1, dt }, { 0, 1 } }; + Eigen::MatrixXd B_yaw { { 0 }, { dt } }; + Eigen::VectorXd f_yaw { { 0, 0 } }; + Eigen::Matrix Q_y(config_->mpc.Q_yaw.data()); + Eigen::Matrix R_y(config_->mpc.R_yaw.data()); + tiny_setup( + &yaw_solver_, + A_yaw, + B_yaw, + f_yaw, + Q_y.asDiagonal(), + R_y.asDiagonal(), + 1.0, + 2, + 1, + horizon, + 0 + ); + + Eigen::MatrixXd x_min_yaw = Eigen::MatrixXd::Constant(2, horizon, -1e17); + Eigen::MatrixXd x_max_yaw = Eigen::MatrixXd::Constant(2, horizon, 1e17); + Eigen::MatrixXd u_min_yaw = + Eigen::MatrixXd::Constant(1, horizon - 1, -config_->max_yaw_acc_param.get()); + Eigen::MatrixXd u_max_yaw = + Eigen::MatrixXd::Constant(1, horizon - 1, config_->max_yaw_acc_param.get()); + tiny_set_bound_constraints(yaw_solver_, x_min_yaw, x_max_yaw, u_min_yaw, u_max_yaw); + yaw_solver_->settings->max_iter = max_iter_; + } + int selectArmor(const Target& target, const AutoAimFsm& auto_aim_fsm) const noexcept { + static int lock_id = -1; + const auto [aim_first, aim_center, aim_pair] = getAimStatus(auto_aim_fsm); + const auto armor_list = target.getArmorPosAndYaw(); + const int armor_num = static_cast(armor_list.size()); + int i_chosen = 0; + + const double center_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + + std::vector delta_angles; + delta_angles.reserve(armor_num); + for (int i = 0; i < armor_num; ++i) { + delta_angles.push_back(angles::normalize_angle(armor_list[i][3] - center_yaw)); + } + + const auto pick_best_by_min_delta = [&](const std::vector& idxs) -> int { + int best = -1; + double best_val = std::numeric_limits::infinity(); + for (int i: idxs) { + const double val = std::abs(delta_angles[i]); + if (val < best_val) { + best_val = val; + best = i; + } + } + return best; + }; + + if (aim_first && target.tracked_id_ != ArmorNumber::OUTPOST && armor_num > 0) { + std::vector candidates; + constexpr double in_first = 60.0 / 57.3; + for (int i = 0; i < armor_num; ++i) + if (std::abs(delta_angles[i]) <= in_first) + candidates.push_back(i); + + if (!candidates.empty()) { + if (candidates.size() > 1) { + int a = candidates[0], b = candidates[1]; + if (lock_id != a && lock_id != b) { + lock_id = (std::abs(delta_angles[a]) < std::abs(delta_angles[b])) ? a : b; + } + int pick = (lock_id >= 0 && lock_id < armor_num) + ? lock_id + : pick_best_by_min_delta(candidates); + if (pick >= 0) { + i_chosen = pick; + } + } else { + lock_id = -1; + int pick = candidates[0]; + i_chosen = pick; + } + } + + return i_chosen; + } + if (armor_num > 0) { + int best_idx = -1; + + if (target.tracked_id_ != ArmorNumber::OUTPOST) { + const double coming_angle = config_->comming_angle_param.get() * M_PI / 180.0; + const double leaving_angle = config_->leaving_angle_param.get() * M_PI / 180.0; + + for (int i = 0; i < armor_num; ++i) { + if (std::abs(delta_angles[i]) > coming_angle) + continue; + + if (target.target_state_.vyaw() > 0 && delta_angles[i] < leaving_angle) + best_idx = i; + if (target.target_state_.vyaw() < 0 && delta_angles[i] > -leaving_angle) + best_idx = i; + } + } + + if (best_idx < 0) { + std::vector all(armor_num); + std::iota(all.begin(), all.end(), 0); + best_idx = pick_best_by_min_delta(all); + } + if (aim_pair && target.tracked_id_ != ArmorNumber::OUTPOST) { + std::vector all; + if (target.target_state_.h() > 0) { + all.push_back(1); + all.push_back(3); + } else { + all.push_back(0); + all.push_back(2); + } + best_idx = pick_best_by_min_delta(all); + } + + i_chosen = best_idx; + } + + return i_chosen; + } + ControlPoint + getControlPoint(Eigen::Vector3d aim_target_pos, double diff_yaw, double bullet_speed) const { + ControlPoint cp; + double control_yaw = std::atan2(aim_target_pos.y(), aim_target_pos.x()); + double raw_pitch = std::atan2( + aim_target_pos.z(), + std::sqrt( + aim_target_pos.x() * aim_target_pos.x() + aim_target_pos.y() * aim_target_pos.y() + ) + ); + if (!trajectory_compensator_config_->trajectory_compensator + ->compensate(aim_target_pos, raw_pitch, bullet_speed)) + { + WUST_ERROR("very_aimer") << " traj compense error"; + + break_this_ = true; + return cp; + } + + double control_pitch = raw_pitch; + const auto offs = config_->manual_compensator->angleHardCorrect( + aim_target_pos.head(2).norm(), + aim_target_pos.z() + ); + control_yaw = angles::normalize_angle((control_yaw + offs[1] * M_PI / 180.0)); + control_pitch = (control_pitch + offs[0] * M_PI / 180.0); + cp.pitch = control_pitch; + cp.yaw = control_yaw; + cp.xyza.head<3>() = aim_target_pos; + cp.xyza[3] = diff_yaw; + + return cp; + } + std::tuple calEnableDiff( + Eigen::Vector3d aim_target_pos, + double diff_yaw, + const AutoAimFsm& auto_aim_fsm, + bool is_big + ) const noexcept { + const double distance = aim_target_pos.norm(); + double shooting_range_yaw; + if (!is_big) { + shooting_range_yaw = + std::abs(atan2(config_->shooting_range_w_small_param.get() / 2, distance)); + } else { + shooting_range_yaw = + std::abs(atan2(config_->shooting_range_w_big_param.get() / 2, distance)); + } + double shooting_range_pitch = + std::abs(atan2(config_->shooting_range_h_param.get() / 2, distance)); + double yaw_factor = 0.0; + + const double yaw_rad = diff_yaw; + if (auto_aim_fsm != AutoAimFsm::AIM_SINGLE_ARMOR) { + if (std::abs(yaw_rad) <= config_->yaw_limit_deg_param.get() / 180.0 * M_PI) { + yaw_factor = std::cos(yaw_rad); + } + } else { + yaw_factor = std::cos(yaw_rad); + } + + const double pitch_factor = std::cos(15.0 * M_PI / 180); + + shooting_range_yaw = + std::max(shooting_range_yaw, config_->min_enable_yaw_deg_param.get() * M_PI / 180); + shooting_range_pitch = + std::max(shooting_range_pitch, config_->min_enable_pitch_deg_param.get() * M_PI / 180); + shooting_range_yaw *= yaw_factor; + shooting_range_pitch *= pitch_factor; + + return std::make_tuple(std::abs(shooting_range_yaw), std::abs(shooting_range_pitch)); + } + std::tuple getAimStatus(const AutoAimFsm& auto_aim_fsm) const noexcept { + const bool aim_first = (auto_aim_fsm == AutoAimFsm::AIM_SINGLE_ARMOR); + const bool aim_center = (auto_aim_fsm == AutoAimFsm::AIM_WHOLE_CAR_CENTER); + const bool aim_pair = (auto_aim_fsm == AutoAimFsm::AIM_WHOLE_CAR_PAIR); + return std::make_tuple(aim_first, aim_center, aim_pair); + } + + ControlPoint choseAndGetControlPoint( + const Target& target, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm + ) const noexcept { + const auto [aim_first, aim_center, aim_pair] = getAimStatus(auto_aim_fsm); + const int target_select = selectArmor(target, auto_aim_fsm); + const auto armors_xyza = target.getArmorPosAndYaw(); + Eigen::Vector3d aim_pos = armors_xyza[target_select].head<3>(); + + if (aim_center) { + const double raw_z = aim_pos.z(); + double c_xy_dis = std::sqrt( + target.target_state_.cx() * target.target_state_.cx() + + target.target_state_.cy() * target.target_state_.cy() + ); + const double c_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + c_xy_dis -= target.getArmor2CenterXYDis(target_select); + aim_pos.x() = c_xy_dis * std::cos(c_yaw); + aim_pos.y() = c_xy_dis * std::sin(c_yaw); + aim_pos.z() = raw_z; + } + const double center_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + const double d_angle = + angles::shortest_angular_distance(center_yaw, armors_xyza[target_select][3]); + ControlPoint cp = getControlPoint(aim_pos, d_angle, bullet_speed); + cp.id_in_target = target_select; + return cp; + } + struct FireResult { + bool fire; + double enable_yaw_diff; + double enable_pitch_diff; + FireResult(bool f, double ey, double ep): + fire(f), + enable_yaw_diff(ey), + enable_pitch_diff(ep) {} + }; + inline FireResult + canFireAtTime(const VeryAimerTrajBase::Ptr& traj, double t, bool is_big) const noexcept { + auto cal_delay = [&](double _t) { + const auto target_delay = traj->getTargetState(_t + config_->control_delay_param.get()); + const auto control_delay = + traj->getControlState(_t + config_->control_delay_param.get()); + + if (std::hypot( + angles::normalize_angle(target_delay.yaw_state.p + traj->cp0.yaw) + - angles::normalize_angle(control_delay.yaw_state.p + traj->cp0.yaw), + angles::normalize_angle(target_delay.pitch_state.p) + - angles::normalize_angle(control_delay.pitch_state.p) + ) + >= config_->delay_enable_fire_error_param.get()) + { + return false; + } + return true; + }; + if (!cal_delay(t)) { + return { false, 0, 0 }; + } + if (!cal_delay(-t)) { + return { false, 0, 0 }; + } + const auto target = traj->getTargetState(t); + const auto control = traj->getControlState(t); + const auto aim = traj->aim_traj.getStateAtTime(t); + + const double target_yaw = angles::normalize_angle(target.yaw_state.p + traj->cp0.yaw); + const double control_yaw = angles::normalize_angle(control.yaw_state.p + traj->cp0.yaw); + + const auto [enable_yaw, enable_pitch] = + calEnableDiff(aim.pos, aim.d_angle, traj->fsm, is_big); + + return { std::abs(angles::shortest_angular_distance(target_yaw, control_yaw)) <= enable_yaw + && std::abs(angles::shortest_angular_distance( + target.pitch_state.p, + control.pitch_state.p + )) + <= enable_pitch, + enable_yaw, + enable_pitch }; + } + std::pair> getTrajectory( + Target& target, + const ControlPoint& cp0, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm + ) const { + LimitTrajectory traj; + Trajectory aim_traj; + const int horizon = config_->sample_horizon_param.get(); + const int half_horizon = config_->sample_half_horizon; + const double dt = config_->sample_dt; + + traj.reserve(horizon); + aim_traj.reserve(horizon); + + // prepare: roll the target back so we start from same relative time as original impl + target.predictSimple(-dt * (half_horizon + 1)); + + // compute first two cps (target state is mutated between calls but choseAndGetControlPoint takes const&) + auto cp_last = choseAndGetControlPoint(target, bullet_speed, auto_aim_fsm); + target.predictSimple(dt); + auto cp = choseAndGetControlPoint(target, bullet_speed, auto_aim_fsm); + + for (int i = 0; i < horizon; ++i) { + target.predictSimple(dt); + const auto cp_next = choseAndGetControlPoint(target, bullet_speed, auto_aim_fsm); + GimbalState pt; + pt.yaw_state.p = angles::normalize_angle(cp.yaw - cp0.yaw); + pt.pitch_state.p = cp.pitch; + pt.aim_id = cp.id_in_target; + traj.push_back(pt, dt); + AimPoint aim_pt; + aim_pt.d_angle = cp.xyza[3]; + aim_pt.pos = cp.xyza.head<3>(); + aim_traj.push_back(aim_pt, dt); + cp_last = cp; + cp = cp_next; + } + + return { std::move(traj), std::move(aim_traj) }; + } + template + std::shared_ptr buildVeryAimerCommon( + Target& target, + int fin_target_select, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm, + FinalizeFn finalize + ) { + auto res = std::make_shared(); + + const auto fin_armors_xyza = target.getArmorPosAndYaw(); + res->fin_aim_pos = fin_armors_xyza[fin_target_select].head<3>(); + + const auto [aim_first, aim_center, aim_pair] = getAimStatus(auto_aim_fsm); + + if (aim_center) { + const double raw_z = res->fin_aim_pos.z(); + double c_xy_dis = std::hypot(target.target_state_.cx(), target.target_state_.cy()); + const double c_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + + c_xy_dis -= target.getArmor2CenterXYDis(fin_target_select); + + res->fin_aim_pos.x() = c_xy_dis * std::cos(c_yaw); + res->fin_aim_pos.y() = c_xy_dis * std::sin(c_yaw); + res->fin_aim_pos.z() = raw_z; + } + + // AimTarget + { + AimTarget at; + at.pos = res->fin_aim_pos; + at.valid = true; + + Eigen::Vector3d euler; + euler.x() = M_PI / 2.0; + euler.y() = (target.tracked_id_ == ArmorNumber::OUTPOST) ? -0.2618 : 0.2618; + euler.z() = target.target_state_.yaw(); + + at.ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX); + res->aim_target = at; + } + + const double center_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + + const double d_angle = + angles::shortest_angular_distance(center_yaw, fin_armors_xyza[fin_target_select][3]); + + res->cp0 = getControlPoint(res->fin_aim_pos, d_angle, bullet_speed); + res->cp0.id_in_target = fin_target_select; + res->cp0.xyza = fin_armors_xyza[fin_target_select]; + + auto traj = getTrajectory(target, res->cp0, bullet_speed, auto_aim_fsm); + + finalize(res, traj); // 差异点交给外部 + + return res; + } + VeryAimerTrajSeg::Ptr buildVerAimerTrajSeg( + Target& target, + int fin_target_select, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm + ) { + return buildVeryAimerCommon( + target, + fin_target_select, + bullet_speed, + auto_aim_fsm, + [this](auto& res, auto& traj) { + traj.first.buildLimit( + config_->max_yaw_acc_param.get(), + config_->max_pitch_acc_param.get() + ); + res->target_traj = traj.first; + res->aim_traj = traj.second; + } + ); + } + + VeryAimerTrajMpc::Ptr buildVerAimerTrajMpc( + Target& target, + int fin_target_select, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm + ) { + return buildVeryAimerCommon( + target, + fin_target_select, + bullet_speed, + auto_aim_fsm, + [this](auto& res, auto& traj) { + traj.first.buildSimple(); + res->target_traj = traj.first; + res->aim_traj = traj.second; + res->control_traj = solveTrajectoryMpc(res->target_traj); + } + ); + } + + Trajectory solveTrajectoryMpc(const Trajectory& traj) { + const double total_time = traj.getTotalDuration(); + const int horizon = config_->sample_horizon_param.get(); + const int half_horizon = config_->sample_half_horizon; + const double dt = config_->sample_dt; + const auto trajVecToEigen = [&](const Trajectory& traj) { + Eigen::Matrix mat(4, horizon); + + const double half_t = total_time * 0.5; + for (int k = 0; k < horizon; ++k) { + int i = k - half_horizon; + double t = i * dt + half_t; + auto state = traj.Trajectory::getStateAtTime(t); + mat(0, k) = state.yaw_state.p; + mat(1, k) = state.yaw_state.v; + mat(2, k) = state.pitch_state.p; + mat(3, k) = state.pitch_state.v; + } + return mat; + }; + + auto traj_eigen = trajVecToEigen(traj); + Eigen::VectorXd x0(2); + x0 << traj_eigen(0, 0), traj_eigen(1, 0); + tiny_set_x0(yaw_solver_, x0); + yaw_solver_->work->Xref = traj_eigen.block(0, 0, 2, horizon); + tiny_solve(yaw_solver_); + + x0 << traj_eigen(2, 0), traj_eigen(3, 0); + tiny_set_x0(pitch_solver_, x0); + pitch_solver_->work->Xref = traj_eigen.block(2, 0, 2, horizon); + tiny_solve(pitch_solver_); + Trajectory control_traj; + control_traj.reserve(horizon); + for (int i = 0; i < horizon; i++) { + GimbalState tp; + tp.yaw_state.p = yaw_solver_->work->x(0, i); + tp.yaw_state.v = yaw_solver_->work->x(1, i); + tp.pitch_state.p = pitch_solver_->work->x(0, i); + tp.pitch_state.v = pitch_solver_->work->x(1, i); + tp.yaw_state.a = yaw_solver_->work->u(0, i); + tp.pitch_state.a = pitch_solver_->work->u(0, i); + control_traj.push_back(tp, dt); + } + return control_traj; + } + struct PredictResult { + double fly_time; + int fin_target_select; + std::vector fin_armors_xyza; + }; + + PredictResult + predictAndSelect(Target& target, double bullet_speed, const AutoAimFsm& auto_aim_fsm) { + const int roughly_select = selectArmor(target, auto_aim_fsm); + + const auto now = wust_vl::common::utils::time_utils::now(); + target.predictSimple(now); + + const auto ap = target.getArmorPositions(); + double prev_fly_time = trajectory_compensator_config_->trajectory_compensator + ->getFlyingTime(ap[roughly_select].head<3>(), bullet_speed); + + std::vector iteration_target(10, target); + + for (int iter = 0; iter < 10; ++iter) { + iteration_target[iter].predictSimple(prev_fly_time); + + const int iter_select = selectArmor(iteration_target[iter], auto_aim_fsm); + + const auto iter_poss = iteration_target[iter].getArmorPositions(); + + const double iter_fly_time = + trajectory_compensator_config_->trajectory_compensator->getFlyingTime( + iter_poss[roughly_select], + bullet_speed + ); + + if (std::abs(iter_fly_time - prev_fly_time) < 1e-3) + break; + + prev_fly_time = iter_fly_time; + } + + const double predict_time = prev_fly_time + config_->prediction_delay_param.get(); + + target.predictSimple(predict_time); + + PredictResult res; + res.fly_time = prev_fly_time; + res.fin_armors_xyza = target.getArmorPosAndYaw(); + res.fin_target_select = selectArmor(target, auto_aim_fsm); + return res; + } + + template + GimbalCmd veryAimImpl( + Target target, + double bullet_speed, + const AutoAimFsm& auto_aim_fsm, + BuildFn build_fn + ) { + GimbalCmd cmd; + if (!trajectory_compensator_config_->trajectory_compensator) { + cmd.appear = false; + return cmd; + } + + const auto [aim_first, aim_center, aim_pair] = getAimStatus(auto_aim_fsm); + + auto predict = predictAndSelect(target, bullet_speed, auto_aim_fsm); + + VeryAimerPtr build; + try { + build = build_fn(target, predict.fin_target_select, bullet_speed, auto_aim_fsm); + } catch (...) { + WUST_WARN("very_aimer") << "build failed"; + cmd.appear = false; + return cmd; + } + + cmd.aim_target = build->aim_target; + + double target_yaw = build->cp0.yaw; + double target_pitch = build->cp0.pitch; + + if (aim_center) { + const double center_yaw = + std::atan2(target.target_state_.cy(), target.target_state_.cx()); + + const double d_angle = angles::shortest_angular_distance( + center_yaw, + predict.fin_armors_xyza[predict.fin_target_select][3] + ); + + auto cp = getControlPoint( + predict.fin_armors_xyza[predict.fin_target_select].head<3>(), + d_angle, + bullet_speed + ); + + target_yaw = cp.yaw; + target_pitch = cp.pitch; + } + + const int half_horizon = config_->sample_half_horizon; + const double half_t = build->target_traj.getPrefixTimeAtIdx(half_horizon); + + const auto cs = build->getControlState(half_t); + const double yaw = angles::normalize_angle(cs.yaw_state.p + build->cp0.yaw); + + cmd.yaw = rad2deg(yaw); + cmd.v_yaw = rad2deg(cs.yaw_state.v); + cmd.a_yaw = rad2deg(cs.yaw_state.a); + + cmd.pitch = rad2deg(cs.pitch_state.p); + cmd.v_pitch = rad2deg(cs.pitch_state.v); + cmd.a_pitch = rad2deg(cs.pitch_state.a); + + cmd.target_yaw = rad2deg(target_yaw); + + cmd.target_pitch = rad2deg(target_pitch); + + cmd.distance = build->fin_aim_pos.norm(); + cmd.fly_time = predict.fly_time; + bool is_big = + target.tracked_id_ == ArmorNumber::NO1 || target.tracked_id_ == ArmorNumber::BASE; + auto fire = canFireAtTime(build, half_t, is_big); + if (config_->fuck_test_param.get()) { + double center_yaw = std::atan2(target.target_state_.cy(), target.target_state_.cx()); + Eigen::Vector3d vel = target.target_state_.vel(); + double vx_center = std::cos(center_yaw) * vel.x() + std::sin(center_yaw) * vel.y(); + double vy_center = -std::sin(center_yaw) * vel.x() + std::cos(center_yaw) * vel.y(); + + double thresh = config_->fuck_test_thresh_param.get(); + bool no_shoot = (target.target_state_.vyaw() > 0 && vy_center < thresh) + || (target.target_state_.vyaw() <= 0 && vy_center > -thresh); + + if (no_shoot) { + fire.fire = false; + fire.enable_pitch_diff = 0.0; + fire.enable_yaw_diff = 0.0; + } + } + + cmd.enable_yaw_diff = rad2deg(fire.enable_yaw_diff); + cmd.enable_pitch_diff = rad2deg(fire.enable_pitch_diff); + cmd.fire_advice = fire.fire; + + cmd.appear = cmd.isValid(); + if (break_this_) { + cmd.appear = false; + } + + return cmd; + } + GimbalCmd veryAimSeg(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm) { + return veryAimImpl( + std::move(target), + bullet_speed, + auto_aim_fsm, + [this](auto& t, int sel, double bs, const AutoAimFsm& fsm) { + return buildVerAimerTrajSeg(t, sel, bs, fsm); + } + ); + } + + GimbalCmd veryAimMpc(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm) { + return veryAimImpl( + std::move(target), + bullet_speed, + auto_aim_fsm, + [this](auto& t, int sel, double bs, const AutoAimFsm& fsm) { + return buildVerAimerTrajMpc(t, sel, bs, fsm); + } + ); + } + + GimbalCmd veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm) { + break_this_ = false; + if (type_ == Type::Mpc) { + return veryAimMpc(target, bullet_speed, auto_aim_fsm); + } else if (type_ == Type::Seg) { + return veryAimSeg(target, bullet_speed, auto_aim_fsm); + } else { + return veryAimSeg(target, bullet_speed, auto_aim_fsm); + } + } + TrajectoryCompensatorConfig::Ptr trajectory_compensator_config_; + wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_; + VeryAimerConfig::Ptr config_; + mutable bool break_this_ = false; +}; +VeryAimer::VeryAimer(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) { + _impl = std::make_unique(auto_aim_config_parameter); +} +VeryAimer::~VeryAimer() { + _impl.reset(); +} +GimbalCmd VeryAimer::veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm) + +{ + return _impl->veryAim(target, bullet_speed, auto_aim_fsm); +} +} // namespace wust_vision::auto_aim \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.hpp b/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.hpp new file mode 100644 index 0000000..c61a2f8 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_control/very_aimer.hpp @@ -0,0 +1,27 @@ +#pragma once +#include +namespace wust_vl::common::utils { +class Parameter; +} +using wust_vlParamterPtr = std::shared_ptr; +namespace wust_vision { +struct GimbalCmd; +} +namespace wust_vision::auto_aim { +enum class AutoAimFsm; + +class Target; +class VeryAimer { +public: + using Ptr = std::shared_ptr; + VeryAimer(wust_vlParamterPtr auto_aim_config_parameter); + static Ptr create(wust_vlParamterPtr auto_aim_config_parameter) { + return std::make_shared(auto_aim_config_parameter); + }; + ~VeryAimer(); + [[nodiscard]] GimbalCmd + veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm); + struct Impl; + std::unique_ptr _impl; +}; +} // namespace wust_vision::auto_aim \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_base.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_base.hpp new file mode 100644 index 0000000..0d50ff4 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_base.hpp @@ -0,0 +1,70 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "tasks/auto_aim/type.hpp" +#include "tasks/type_common.hpp" +namespace wust_vision { +namespace auto_aim { + struct LightParams { + // width / height + double min_ratio; + double max_ratio; + // vertical angle + double max_angle; + // judge color + int color_diff_thresh; + double max_angle_diff; + int binary_thres; + void load(const YAML::Node& config) { + binary_thres = config["binary_thres"].as(); + min_ratio = config["min_ratio"].as(); + max_ratio = config["max_ratio"].as(); + max_angle = config["max_angle"].as(); + max_angle_diff = config["max_angle_diff"].as(); + color_diff_thresh = config["color_diff_thresh"].as(); + } + }; + struct ArmorParams { + double min_light_ratio; + // light pairs distance + double min_small_center_distance; + double max_small_center_distance; + double min_large_center_distance; + double max_large_center_distance; + // horizontal angle + double max_angle; + void load(const YAML::Node& config) { + min_light_ratio = config["min_light_ratio"].as(); + min_small_center_distance = config["min_small_center_distance"].as(); + max_small_center_distance = config["max_small_center_distance"].as(); + min_large_center_distance = config["min_large_center_distance"].as(); + max_large_center_distance = config["max_large_center_distance"].as(); + max_angle = config["max_angle"].as(); + } + }; + class ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + virtual ~ArmorDetectorBase() = default; + + virtual void + pushInput(CommonFrame& frame, const std::optional& target_number) = 0; + + using DetectorCallback = + std::function&, const CommonFrame&)>; + + virtual void setCallback(DetectorCallback cb) = 0; + }; +} // namespace auto_aim +} // namespace wust_vision diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.cpp new file mode 100644 index 0000000..b178fe7 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.cpp @@ -0,0 +1,473 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp" +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +#include "tasks/auto_aim/armor_detect/number_classifier/factory.hpp" +#include "tasks/utils/utils.hpp" +#include "wust_vl/common/utils/timer.hpp" +namespace wust_vision { +namespace auto_aim { + struct ArmorDetectorCommon::Impl { + public: + Impl(const YAML::Node& config) { + params_.load(config); + number_classifier_ = NumberClassifierFactory::createNumberClassifier( + params_.classify_backend, + params_.classify_model_path, + params_.classify_label_path + ); + } + + bool extractNetImage(const cv::Mat& src, ArmorObject& armor) const noexcept { + constexpr int light_length = 12; + constexpr int warp_height = 28; + constexpr int small_armor_width = 32; + constexpr int large_armor_width = 54; + const cv::Size roi_size(20, 28); + + if (src.empty() || src.cols < 10 || src.rows < 10) { + std::cerr << "[extractNetImage] input src is empty or too small!" << std::endl; + return false; + } + + const auto ordered = armor.sortCorners(armor.pts); + + const cv::Point2f& p0 = ordered[0]; + const cv::Point2f& p1 = ordered[1]; + const cv::Point2f& p2 = ordered[2]; + const cv::Point2f& p3 = ordered[3]; + + const float l1_len = cv::norm(p1 - p0); + const float l2_len = cv::norm(p2 - p3); + const cv::Point2f c1 = (p0 + p1) * 0.5f; + const cv::Point2f c2 = (p2 + p3) * 0.5f; + + const float avg_light_len = 0.5f * (l1_len + l2_len); + const float center_dist = + avg_light_len > 1e-3f ? cv::norm(c1 - c2) / avg_light_len : 0.f; + + const bool is_large = center_dist > params_.armor_params.min_large_center_distance; + + const cv::Rect bbox = cv::boundingRect(armor.pts); + if (bbox.width <= 0 || bbox.height <= 0) + return false; + + if (bbox.width > src.cols || bbox.height > src.rows) + return false; + const int dw = static_cast(bbox.width * (params_.expand_ratio_w - 1.f)); + const int dh = static_cast(bbox.height * (params_.expand_ratio_h - 1.f)); + + int new_x = bbox.x - (dw >> 1); + int new_y = bbox.y - (dh >> 1); + new_x = std::max(new_x, 0); + new_y = std::max(new_y, 0); + + int new_w = std::min(bbox.width + dw, src.cols - new_x); + int new_h = std::min(bbox.height + dh, src.rows - new_y); + + if (new_w <= 0 || new_h <= 0) + return false; + + const cv::Rect expanded_rect(new_x, new_y, new_w, new_h); + + cv::Mat litroi_color = src(expanded_rect); + if (litroi_color.empty()) + return false; + + cv::Mat litroi_gray; + try { + cv::cvtColor(litroi_color, litroi_gray, cv::COLOR_BGR2GRAY); + } catch (...) { + return false; + } + + armor.whole_gray_img = litroi_gray; + if (params_.enable_cv) { + cv::Mat litroi_binary; + try { + cv::threshold( + litroi_gray, + litroi_binary, + params_.light_params.binary_thres, + 255, + cv::THRESH_BINARY + ); + armor.whole_binary_img = litroi_binary; + } catch (...) { + return false; + } + } + const cv::Point2f offset(static_cast(new_x), static_cast(new_y)); + + if (params_.enable_classify) { + cv::Point2f src_vertices[4] = { armor.pts[1] - offset, + armor.pts[0] - offset, + armor.pts[3] - offset, + armor.pts[2] - offset }; + + const int warp_width = is_large ? large_armor_width : small_armor_width; + const int top_light_y = (warp_height - light_length) / 2 - 1; + const int bottom_light_y = top_light_y + light_length; + if (warp_width <= 0 || warp_height <= 0) + return false; + cv::Point2f dst_vertices[4] = { + { 0.f, static_cast(bottom_light_y) }, + { 0.f, static_cast(top_light_y) }, + { static_cast(warp_width - 1), static_cast(top_light_y) }, + { static_cast(warp_width - 1), static_cast(bottom_light_y) } + }; + + const cv::Mat warp_mat = cv::getPerspectiveTransform(src_vertices, dst_vertices); + + cv::Mat number_image; + cv::warpPerspective( + litroi_gray, + number_image, + warp_mat, + cv::Size(warp_width, warp_height), + cv::INTER_LINEAR, + cv::BORDER_CONSTANT, + 0 + ); + + const int roi_x = (warp_width - roi_size.width) >> 1; + const cv::Rect num_roi(roi_x, 0, roi_size.width, roi_size.height); + + if ((num_roi & cv::Rect(0, 0, warp_width, warp_height)) != num_roi) + return false; + + cv::Mat num_crop = number_image(num_roi); + + cv::threshold( + num_crop, + armor.number_img, + 0, + 255, + cv::THRESH_BINARY | cv::THRESH_OTSU + ); + } + + armor.whole_rgb_img = litroi_color; + armor.local_offset = offset; + + return true; + } + + bool refineLightsFromArmorPts(ArmorObject& armor) const noexcept { + armor.center = (armor.pts[0] + armor.pts[1] + armor.pts[2] + armor.pts[3]) * 0.25f; + + const int n_lights = static_cast(armor.lights.size()); + if (n_lights < 2) + return false; + + const auto ordered = armor.sortCorners(armor.pts); + + const cv::Point2f ref_centers[2] = { (ordered[0] + ordered[1]) * 0.5f, + (ordered[2] + ordered[3]) * 0.5f }; + + int best0 = -1, best1 = -1; + float best0_d2 = std::numeric_limits::max(); + float best1_d2 = std::numeric_limits::max(); + + for (int i = 0; i < n_lights; ++i) { + const cv::Point2f& c = armor.lights[i].center; + + const cv::Point2f d0 = c - ref_centers[0]; + const float dist0 = d0.dot(d0); + if (dist0 < best0_d2) { + best0_d2 = dist0; + best0 = i; + } + + const cv::Point2f d1 = c - ref_centers[1]; + const float dist1 = d1.dot(d1); + if (dist1 < best1_d2) { + best1_d2 = dist1; + best1 = i; + } + } + + if (best0 == best1) { + best1 = -1; + best1_d2 = std::numeric_limits::max(); + + for (int i = 0; i < n_lights; ++i) { + if (i == best0) + continue; + + const cv::Point2f d = armor.lights[i].center - ref_centers[1]; + const float dist = d.dot(d); + if (dist < best1_d2) { + best1_d2 = dist; + best1 = i; + } + } + } + + if (best0 < 0 || best1 < 0) + return false; + + const auto& l0 = armor.lights[best0]; + const auto& l1 = armor.lights[best1]; + + if (l0.center.x < l1.center.x) { + armor.lights[0] = l0; + armor.lights[1] = l1; + } else { + armor.lights[0] = l1; + armor.lights[1] = l0; + } + + return true; + } + + std::vector + findLights(const cv::Mat& color_img, const cv::Mat& binary_img, ArmorObject& armor) + const noexcept { + std::vector> contours; + contours.reserve(64); + + cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + + std::vector all_lights; + all_lights.reserve(contours.size()); + + for (const auto& contour: contours) { + const int n = static_cast(contour.size()); + if (n < 6) + continue; + + Light light(contour); + if (!isLight(light)) + continue; + + int sum_r = 0; + int sum_b = 0; + + for (const auto& pt: contour) { + const cv::Vec3b* row = color_img.ptr(pt.y); + const cv::Vec3b& pix = row[pt.x]; + sum_r += pix[0]; + sum_b += pix[2]; + } + + const int avg_diff = std::abs(sum_r - sum_b) / n; + if (avg_diff <= params_.light_params.color_diff_thresh) + continue; + + light.color = (sum_r > sum_b) ? 0 : 1; // 0=红, 1=蓝 + all_lights.emplace_back(std::move(light)); + } + + std::sort(all_lights.begin(), all_lights.end(), [](const Light& a, const Light& b) { + return a.center.x < b.center.x; + }); + + armor.lights = all_lights; + return all_lights; + } + + bool isLight(const Light& light) const noexcept { + // width / length 比例 + const float ratio = light.width / light.length; + + if (ratio <= params_.light_params.min_ratio || ratio >= params_.light_params.max_ratio) + return false; + + if (light.tilt_angle >= params_.light_params.max_angle) + return false; + + return true; + } + bool isArmor(const Light& l1, const Light& l2) const noexcept { + const float len1 = l1.length; + const float len2 = l2.length; + if (len1 <= 1e-3f || len2 <= 1e-3f) + return false; + + const float min_len = (len1 < len2) ? len1 : len2; + const float max_len = (len1 < len2) ? len2 : len1; + if (min_len / max_len <= params_.armor_params.min_light_ratio) + return false; + + const cv::Point2f d = l1.center - l2.center; + const float dist2 = d.dot(d); + + const float avg_len = 0.5f * (len1 + len2); + + const float min_small = params_.armor_params.min_small_center_distance * avg_len; + const float max_small = params_.armor_params.max_small_center_distance * avg_len; + const float min_large = params_.armor_params.min_large_center_distance * avg_len; + const float max_large = params_.armor_params.max_large_center_distance * avg_len; + + const float min_small2 = min_small * min_small; + const float max_small2 = max_small * max_small; + const float min_large2 = min_large * min_large; + const float max_large2 = max_large * max_large; + + const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2); + const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2); + + if (!(small_ok || large_ok)) + return false; + + static const float tan_max_angle = + std::tan(params_.armor_params.max_angle * CV_PI / 180.0f); + + if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle) + return false; + + if (l1.color != l2.color) + return false; + + return true; + } + + std::vector detectNet( + const cv::Mat& src_img, + std::vector& objs_result, + Eigen::Matrix3f transform_matrix, + int detect_color, + const std::optional& target_number + ) const noexcept { + std::vector armors; + + if (!src_img.data || src_img.empty()) { + std::cout << "img data nullptr or empty" << std::endl; + return armors; + } + + if (objs_result.empty()) { + return armors; + } + + for (auto& armor_in: objs_result) { + ArmorObject armor = armor_in; + + if ((detect_color == 0 && armor.color == ArmorColor::BLUE) + || (detect_color == 1 && armor.color == ArmorColor::RED)) + { + continue; + } + if (params_.enable_classify || params_.enable_cv) { + bool ok = false; + + ok = extractNetImage(src_img, armor); + + if (!ok) + continue; + } + + if (params_.enable_classify) { + number_classifier_->classifyNumber(armor); + if (armor.confidence < params_.classifier_threshold) + continue; + } + + if (target_number.has_value()) { + if (!isSameTarget(target_number.value(), armor.number)) { + continue; + } + } + + if (armor.color == ArmorColor::NONE || armor.color == ArmorColor::PURPLE) { + armor.is_ok = false; + armor.transform(transform_matrix); + armors.emplace_back(armor); + continue; + } + if (params_.enable_cv) { + findLights(armor.whole_rgb_img, armor.whole_binary_img, armor); + + if (refineLightsFromArmorPts(armor)) { + if (isArmor(armor.lights[0], armor.lights[1])) { + armor.is_ok = true; + for (auto& light: armor.lights) { + light.addOffset(armor.local_offset); + } + } + } + + if (armor.is_ok) { + armor.is_ok = armor.checkOkptsRight(params_.max_pts_error); + } + } + + if (!armor.is_ok) { + auto ordered = armor.sortCorners(armor.pts); + Light l1, l2; + l1.length = cv::norm(ordered[1] - ordered[0]); + l1.center = (ordered[0] + ordered[1]) / 2.0; + l2.length = cv::norm(ordered[2] - ordered[3]); + l2.center = (ordered[2] + ordered[3]) / 2.0; + if (!isArmor(l1, l2)) { + continue; + } + } + + armor.transform(transform_matrix); + + armors.emplace_back(armor); + } + + return armors; + } + std::unique_ptr number_classifier_; + struct ArmorDetectCommonParams { + std::string classify_backend = "opencv"; + std::string classify_model_path; + std::string classify_label_path; + double classifier_threshold = 0.5; + LightParams light_params; + ArmorParams armor_params; + float expand_ratio_w = 1.1f; + float expand_ratio_h = 1.1f; + double max_pts_error = 20.0; + bool enable_cv = false; + bool enable_classify = true; + void load(const YAML::Node& config) { + expand_ratio_w = config["cv"]["light"]["expand_ratio_w"].as(1.1); + expand_ratio_h = config["cv"]["light"]["expand_ratio_h"].as(1.1); + max_pts_error = config["cv"]["light"]["max_pts_error"].as(20.0); + enable_cv = config["cv"]["enable"].as(); + light_params.load(config["cv"]["light"]); + armor_params.load(config["cv"]["armor"]); + enable_classify = config["classify"]["enable"].as(); + classify_model_path = + utils::expandEnv(config["classify"]["model_path"].as()); + classify_label_path = + utils::expandEnv(config["classify"]["label_path"].as()); + classify_backend = config["classify"]["backend"].as(); + classifier_threshold = config["classify"]["threshold"].as(); + } + } params_; + }; + ArmorDetectorCommon::ArmorDetectorCommon(const YAML::Node& config) { + _impl = std::make_unique(config); + } + ArmorDetectorCommon::~ArmorDetectorCommon() { + _impl.reset(); + } + std::vector ArmorDetectorCommon::detectNet( + const cv::Mat& src_img, + std::vector& objs_result, + Eigen::Matrix3f transform_matrix, + int detect_color, + const std::optional& target_number + ) { + return _impl + ->detectNet(src_img, objs_result, transform_matrix, detect_color, target_number); + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.hpp new file mode 100644 index 0000000..2a6b9df --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_common.hpp @@ -0,0 +1,41 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "tasks/auto_aim/type.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorCommon { + public: + using Ptr = std::unique_ptr; + ArmorDetectorCommon(const YAML::Node& config); + static Ptr create(const YAML::Node& config) { + return std::make_unique(config); + } + ~ArmorDetectorCommon(); + + std::vector detectNet( + const cv::Mat& src_img, + std::vector& objs_result, + Eigen::Matrix3f transform_matrix, + int detect_color, + const std::optional& target_number = std::nullopt + ); + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_factory.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_factory.hpp new file mode 100644 index 0000000..4581e82 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_detector_factory.hpp @@ -0,0 +1,134 @@ +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "armor_detector_base.hpp" +#include "tasks/utils/config.hpp" +#include +#include + +#ifdef USE_OPENVINO + #include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp" +#endif +#ifdef USE_TRT + #include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp" +#endif +#ifdef USE_NCNN + #include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp" +#endif +#ifdef USE_ORT + #include "tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp" +#endif +#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp" + +namespace wust_vision { +namespace auto_aim { + + class DetectorFactory { + public: + static ArmorDetectorBase::Ptr createArmorDetector( + const std::string& backend, + bool use_armor_detect_common, + std::string cv_config_path = OPENCV_CONFIG, + std::string ml_config_path = ML_CONFIG + ) { + // 检查编译时是否支持 + auto isBackendEnabled = [&backend]() -> bool { +#ifdef USE_OPENVINO + if (backend == "openvino") + return true; +#endif +#ifdef USE_TRT + if (backend == "tensorrt") + return true; +#endif +#ifdef USE_NCNN + if (backend == "ncnn") + return true; +#endif +#ifdef USE_ORT + if (backend == "onnxruntime") + return true; +#endif + if (backend == "opencv") + return true; + return false; + }; + + if (!isBackendEnabled()) { + std::cout << "Backend " << backend << " is not enabled at compile time." + << std::endl; + throw std::runtime_error("Backend " + backend + " is not enabled at compile time."); + } + + auto getConfigPath = [&](const std::string& backend) -> std::string { + if (backend == "opencv") + return cv_config_path; + else + return ml_config_path; + }; + + std::string config_path = getConfigPath(backend); + if (config_path.empty()) { + std::cout << "No config path for backend: " << backend << std::endl; + throw std::runtime_error("No config path for backend: " + backend); + } + + YAML::Node armor_detect_config = YAML::LoadFile(config_path); + + // 创建对应后端实例 +#if defined(USE_OPENVINO) + if (backend == "openvino") { + return ArmorDetectorOpenVino::create( + armor_detect_config["armor_detector"], + use_armor_detect_common + ); + } +#endif +#if defined(USE_TRT) + if (backend == "tensorrt") { + return ArmorDetectorTrt::create( + armor_detect_config["armor_detector"], + use_armor_detect_common + ); + } +#endif +#if defined(USE_NCNN) + if (backend == "ncnn") { + return ArmorDetectorNCNN::create( + armor_detect_config["armor_detector"], + use_armor_detect_common + ); + } +#endif +#if defined(USE_ORT) + if (backend == "onnxruntime") { + return ArmorDetectorOnnxRuntime::create( + armor_detect_config["armor_detector"], + use_armor_detect_common + ); + } +#endif + if (backend == "opencv") { + return ArmorDetectorOpenCV::create(armor_detect_config["armor_detector"]); + } + std::cout << "Unsupported armor detector backend (or not compiled): " << backend + << std::endl; + throw std::runtime_error( + "Unsupported armor detector backend (or not compiled): " + backend + ); + } + }; + +} // namespace auto_aim +} // namespace wust_vision diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.cpp new file mode 100644 index 0000000..c5a8c2e --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.cpp @@ -0,0 +1,198 @@ +#include "armor_infer.hpp" +namespace wust_vision::auto_aim::armor_infer { +struct GridAndStride { + int grid0; + int grid1; + int stride; +}; +[[nodiscard]] static inline std::vector +generate_grids_and_stride(int target_w, int target_h, const std::vector& strides) noexcept { + std::vector grid_strides; + for (int stride: strides) { + const int num_w = target_w / stride; + const int num_h = target_h / stride; + grid_strides.reserve(grid_strides.size() + num_w * num_h); + for (int gy = 0; gy < num_h; ++gy) { + for (int gx = 0; gx < num_w; ++gx) { + grid_strides.push_back(GridAndStride { gx, gy, stride }); + } + } + } + return grid_strides; +} +std::vector ArmorInfer::postProcessTUP_impl(const cv::Mat& out) const { + static std::optional> _grid_strides; + if (!_grid_strides) { + _grid_strides = generate_grids_and_stride(inputW(), inputH(), { 8, 16, 32 }); + } + const auto& grid_strides = _grid_strides.value(); + std::vector out_objs; + const int num_anchors = + static_cast(std::min(grid_strides.size(), static_cast(out.rows))); + for (int a = 0; a < num_anchors; ++a) { + const float confidence = out.at(a, 8); + if (confidence < conf_threshold_) + continue; + + const auto& gs = grid_strides[a]; + const int gx = gs.grid0, gy = gs.grid1, stride = gs.stride; + + // color & class + const int color_offset = 9; + const int num_colors = ModelTraits::NUM_COLORS; + const int num_classes = ModelTraits::NUM_CLASSES; + + cv::Mat color_scores = out.row(a).colRange(color_offset, color_offset + num_colors); + cv::Mat class_scores = + out.row(a).colRange(color_offset + num_colors, color_offset + num_colors + num_classes); + + double max_color, max_class; + cv::Point color_id, class_id; + cv::minMaxLoc(color_scores, nullptr, &max_color, nullptr, &color_id); + cv::minMaxLoc(class_scores, nullptr, &max_class, nullptr, &class_id); + + const float x1 = (out.at(a, 0) + gx) * stride; + const float y1 = (out.at(a, 1) + gy) * stride; + const float x2 = (out.at(a, 2) + gx) * stride; + const float y2 = (out.at(a, 3) + gy) * stride; + const float x3 = (out.at(a, 4) + gx) * stride; + const float y3 = (out.at(a, 5) + gy) * stride; + const float x4 = (out.at(a, 6) + gx) * stride; + const float y4 = (out.at(a, 7) + gy) * stride; + + ArmorObject obj; + obj.pts = { cv::Point2f(x1, y1), + cv::Point2f(x2, y2), + cv::Point2f(x3, y3), + cv::Point2f(x4, y4) }; + obj.box = cv::boundingRect(obj.pts); + obj.color = static_cast(color_id.x); + obj.number = static_cast(class_id.x); + obj.confidence = confidence; + out_objs.push_back(std::move(obj)); + } + return topKAndNms(out_objs, top_k_, nms_threshold_); +} + +std::vector ArmorInfer::postProcessRP_impl(const cv::Mat& out) const { + std::vector out_objs; + const int rows = out.rows; + const int color_offset = 9; + const int num_colors = ModelTraits::NUM_COLORS; + const int num_classes = ModelTraits::NUM_CLASSES; + + for (int r = 0; r < rows; ++r) { + float conf_raw = out.at(r, 8); + const float confidence = static_cast(sigmoid(conf_raw)); + if (confidence < conf_threshold_) + continue; + + cv::Mat color_scores = out.row(r).colRange(color_offset, color_offset + num_colors); + cv::Mat class_scores = + out.row(r).colRange(color_offset + num_colors, color_offset + num_colors + num_classes); + + double max_color_score, max_class_score; + cv::Point color_id, class_id; + cv::minMaxLoc(color_scores, nullptr, &max_color_score, nullptr, &color_id); + cv::minMaxLoc(class_scores, nullptr, &max_class_score, nullptr, &class_id); + + ArmorObject obj; + obj.pts.resize(4); + for (int k = 0; k < 4; ++k) { + const float x = out.at(r, 0 + k * 2); + const float y = out.at(r, 1 + k * 2); + obj.pts[k] = cv::Point2f(x, y); + } + obj.box = cv::boundingRect(obj.pts); + + obj.color = static_cast(color_id.x); + + obj.number = static_cast(class_id.x); + obj.confidence = confidence; + out_objs.push_back(std::move(obj)); + } + + return topKAndNms(out_objs, top_k_, nms_threshold_); +} + +std::vector ArmorInfer::postProcessAT_impl(const cv::Mat& out) const { + std::vector out_objs; + + constexpr int nkpt = ModelTraits::NUM_KPTS; + constexpr int nk = nkpt * 2; // keypoints flattened + auto max_det = out.rows; + auto det_dim = out.cols; + auto output_ptr = out.ptr(); + for (int i = 0; i < max_det; ++i) { + const float* row = output_ptr + i * det_dim; + float conf = row[4]; + if (!std::isfinite(conf) || conf < conf_threshold_) + continue; + float x1 = row[0]; + float y1 = row[1]; + float x2 = row[2]; + float y2 = row[3]; + int cls = static_cast(row[5]); + + if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2) + || x2 <= x1 || y2 <= y1) + continue; + + ArmorObject obj; + obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1); + obj.confidence = conf; + auto color_num = ModelTraits::CLASSES[cls]; + obj.color = color_num.first; + obj.number = color_num.second; + + obj.pts.reserve(nkpt); + for (int k = 0; k < nkpt; ++k) { + float kx = row[6 + 2 * k]; + float ky = row[6 + 2 * k + 1]; + obj.pts.emplace_back(kx, ky); + } + + out_objs.emplace_back(std::move(obj)); + } + + return out_objs; +} +std::vector ArmorInfer::postProcessBOX_impl(const cv::Mat& out) const { + std::vector out_objs; + auto max_det = out.rows; + auto det_dim = out.cols; + auto output_ptr = out.ptr(); + for (int i = 0; i < max_det; ++i) { + const float* row = output_ptr + i * det_dim; + float conf = row[4]; + if (!std::isfinite(conf) || conf < conf_threshold_) + continue; + float x1 = row[0]; + float y1 = row[1]; + float x2 = row[2]; + float y2 = row[3]; + int cls = static_cast(row[5]); + + if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2) + || x2 <= x1 || y2 <= y1) + continue; + + ArmorObject obj; + obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1); + obj.confidence = conf; + auto color_num = ModelTraits::CLASSES[cls]; + obj.color = color_num.first; + obj.number = color_num.second; + std::vector pts; + pts.resize(4); + pts[0] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y + obj.box.height); // 右下 + pts[1] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y); // 右上 + pts[2] = cv::Point2f(obj.box.x, obj.box.y); // 左上 + pts[3] = cv::Point2f(obj.box.x, obj.box.y + obj.box.height); // 左下 + obj.pts = std::move(pts); + + out_objs.emplace_back(std::move(obj)); + } + return out_objs; +} +} // namespace wust_vision::auto_aim::armor_infer \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.hpp new file mode 100644 index 0000000..f291de7 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.hpp @@ -0,0 +1,324 @@ +#pragma once +#include "tasks/auto_aim/type.hpp" + +namespace wust_vision::auto_aim::armor_infer { + +static constexpr float MERGE_CONF_ERROR = 0.15f; +static constexpr float MERGE_MIN_IOU = 0.9f; + +enum class Mode { TUP, RP, AT, BOX416, BOX320, BOX }; + +inline Mode modeFromString(const std::string& m) { + if (m == "tup" || m == "TUP") + return Mode::TUP; + if (m == "rp" || m == "RP") + return Mode::RP; + if (m == "at" || m == "AT") + return Mode::AT; + if (m == "box416" || m == "BOX416") + return Mode::BOX416; + if (m == "box320" || m == "BOX320") + return Mode::BOX320; + return Mode::TUP; +} + +// ------------------------- model traits ------------------------- +template +struct ModelTraits; // declare +// TUP +template<> +struct ModelTraits { + static constexpr int INPUT_W = 416; + static constexpr int INPUT_H = 416; + static constexpr int NUM_CLASSES = 8; + static constexpr int NUM_COLORS = 4; + static constexpr bool USE_NORM = false; + static constexpr bool INPUT_RGB = true; +}; + +// RP +template<> +struct ModelTraits { + static constexpr int INPUT_W = 640; + static constexpr int INPUT_H = 640; + static constexpr int NUM_CLASSES = 9; + static constexpr int NUM_COLORS = 4; + static constexpr bool USE_NORM = true; + static constexpr bool INPUT_RGB = false; +}; + +template<> +struct ModelTraits { + static constexpr int INPUT_W = 640; + static constexpr int INPUT_H = 640; + static constexpr int NUM_KPTS = 4; + static constexpr bool USE_NORM = true; + static constexpr bool INPUT_RGB = false; + static constexpr std::array, 64> CLASSES = { { + { ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 }, + { ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 }, + { ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 }, + { ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE }, + { ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 }, + { ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 }, + { ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 }, + { ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE }, + { ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 }, + { ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 }, + { ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 }, + { ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE }, + { ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 }, + { ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 }, + { ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 }, + { ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE }, + { ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 }, + { ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 }, + { ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 }, + { ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE }, + { ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 }, + { ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 }, + { ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 }, + { ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE }, + { ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 }, + { ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 }, + { ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 }, + { ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE }, + { ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 }, + { ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 }, + { ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 }, + { ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE }, + } }; +}; +template<> +struct ModelTraits { + static constexpr int INPUT_W = 416; + static constexpr int INPUT_H = 416; +}; +template<> +struct ModelTraits { + static constexpr int INPUT_W = 320; + static constexpr int INPUT_H = 320; +}; +template<> +struct ModelTraits { + static constexpr int INPUT_W = 416; + static constexpr int INPUT_H = 416; + static constexpr bool INPUT_RGB = false; + static constexpr bool USE_NORM = true; + static constexpr std::array, 12> CLASSES = { { + { ArmorColor::BLUE, ArmorNumber::NO1 }, + { ArmorColor::BLUE, ArmorNumber::NO2 }, + { ArmorColor::BLUE, ArmorNumber::NO3 }, + { ArmorColor::BLUE, ArmorNumber::NO4 }, + { ArmorColor::BLUE, ArmorNumber::NO5 }, + { ArmorColor::BLUE, ArmorNumber::SENTRY }, + { ArmorColor::RED, ArmorNumber::NO1 }, + { ArmorColor::RED, ArmorNumber::NO2 }, + { ArmorColor::RED, ArmorNumber::NO3 }, + { ArmorColor::RED, ArmorNumber::NO4 }, + { ArmorColor::RED, ArmorNumber::NO5 }, + { ArmorColor::RED, ArmorNumber::SENTRY }, + + } }; +}; + +[[nodiscard]] inline double sigmoid(double x) noexcept { + return x >= 0 ? 1.0 / (1.0 + std::exp(-x)) : std::exp(x) / (1.0 + std::exp(x)); +} + +[[nodiscard]] inline float rectIoU(const cv::Rect2f& a, const cv::Rect2f& b) noexcept { + const cv::Rect2f inter = a & b; + const float inter_area = inter.area(); + const float union_area = a.area() + b.area() - inter_area; + if (union_area <= 0.f || std::isnan(union_area)) + return 0.f; + return inter_area / union_area; +} + +// Merge / NMS helpers that mimic original intent but clearer +inline void nms_merge_sorted_bboxes( + std::vector& objs, + std::vector& out_indices, + float nms_threshold +) { + out_indices.clear(); + const size_t n = objs.size(); + std::vector areas(n); + for (size_t i = 0; i < n; ++i) + areas[i] = objs[i].box.area(); + + for (size_t i = 0; i < n; ++i) { + ArmorObject& a = objs[i]; + bool keep = true; + for (int idx: out_indices) { + ArmorObject& b = objs[idx]; + const float iou = rectIoU(a.box, b.box); + if (std::isnan(iou) || iou > nms_threshold) { + keep = false; + if (a.number == b.number && a.color == b.color && iou > MERGE_MIN_IOU + && std::abs(a.confidence - b.confidence) < MERGE_CONF_ERROR) + { + // accumulate points for later averaging + for (const auto& pt: a.pts) + b.pts.push_back(pt); + } + break; + } + } + if (keep) + out_indices.push_back(static_cast(i)); + } +} + +inline std::vector +topKAndNms(std::vector& objs, int top_k, float nms_threshold) { + std::sort(objs.begin(), objs.end(), [](const ArmorObject& a, const ArmorObject& b) { + return a.confidence > b.confidence; + }); + if (static_cast(objs.size()) > top_k) + objs.resize(static_cast(top_k)); + + std::vector indices; + nms_merge_sorted_bboxes(objs, indices, nms_threshold); + + std::vector result; + result.reserve(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + result.push_back(std::move(objs[indices[i]])); + // average merged extra points if any + auto& ro = result.back(); + if (ro.pts.size() >= 8) { + const size_t npts = ro.pts.size(); + std::array accum { { { 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 } } }; + for (size_t j = 0; j < npts; ++j) + accum[j % 4] += ro.pts[j]; + ro.pts.resize(4); + for (int k = 0; k < 4; ++k) { + float denom = static_cast(npts) / 4.0f; + ro.pts[k].x = accum[k].x / denom; + ro.pts[k].y = accum[k].y / denom; + } + } + } + return result; +} + +// However providing full modern unified class below that delegates using templated helpers. + +class ArmorInfer { +public: + ArmorInfer( + Mode mode = Mode::TUP, + float conf_threshold = 0.25f, + float nms_threshold = 0.45f, + int top_k = 100 + ) noexcept: + mode_(mode), + conf_threshold_(conf_threshold), + nms_threshold_(nms_threshold), + top_k_(top_k) { + setMode(mode_); + } + + void setMode(Mode m) noexcept { + mode_ = m; + switch (mode_) { + case Mode::TUP: { + input_w_ = ModelTraits::INPUT_W; + input_h_ = ModelTraits::INPUT_H; + use_norm_ = ModelTraits::USE_NORM; + input_rgb_ = ModelTraits::INPUT_RGB; + break; + } + case Mode::RP: { + input_w_ = ModelTraits::INPUT_W; + input_h_ = ModelTraits::INPUT_H; + use_norm_ = ModelTraits::USE_NORM; + input_rgb_ = ModelTraits::INPUT_RGB; + break; + } + case Mode::AT: { + input_w_ = ModelTraits::INPUT_W; + input_h_ = ModelTraits::INPUT_H; + use_norm_ = ModelTraits::USE_NORM; + input_rgb_ = ModelTraits::INPUT_RGB; + break; + } + case Mode::BOX416: { + input_w_ = ModelTraits::INPUT_W; + input_h_ = ModelTraits::INPUT_H; + use_norm_ = ModelTraits::USE_NORM; + input_rgb_ = ModelTraits::INPUT_RGB; + break; + } + case Mode::BOX320: { + input_w_ = ModelTraits::INPUT_W; + input_h_ = ModelTraits::INPUT_H; + use_norm_ = ModelTraits::USE_NORM; + input_rgb_ = ModelTraits::INPUT_RGB; + break; + } break; + } + } + + void setConfThreshold(float t) noexcept { + conf_threshold_ = t; + } + void setNmsThreshold(float t) noexcept { + nms_threshold_ = t; + } + void setTopK(int k) noexcept { + top_k_ = k; + } + + int inputW() const noexcept { + return input_w_; + } + int inputH() const noexcept { + return input_h_; + } + bool useNorm() const noexcept { + return use_norm_; + } + bool inputRGB() const noexcept { + return input_rgb_; + } + + // main dispatching entry (keeps original signature) + [[nodiscard]] std::vector postProcess(const cv::Mat& output_buffer) const { + switch (mode_) { + case Mode::TUP: + return postProcessTUP_impl(output_buffer); + case Mode::RP: + return postProcessRP_impl(output_buffer); + case Mode::AT: + return postProcessAT_impl(output_buffer); + case Mode::BOX416: + return postProcessBOX_impl(output_buffer); + case Mode::BOX320: + return postProcessBOX_impl(output_buffer); + } + return {}; + } + +private: + std::vector postProcessTUP_impl(const cv::Mat& out) const; + + std::vector postProcessRP_impl(const cv::Mat& out) const; + + std::vector postProcessAT_impl(const cv::Mat& out) const; + + std::vector postProcessBOX_impl(const cv::Mat& out) const; + +private: + Mode mode_; + int input_w_ { 0 }; + int input_h_ { 0 }; + float conf_threshold_ { 0.25f }; + float nms_threshold_ { 0.45f }; + int top_k_ { 100 }; + bool use_norm_ { false }; + bool input_rgb_ { false }; +}; + +} // namespace wust_vision::auto_aim::armor_infer \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.cpp new file mode 100644 index 0000000..4b1e8a9 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.cpp @@ -0,0 +1,222 @@ + +#ifdef USE_NCNN + #include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp" + #include "tasks/auto_aim/armor_detect/armor_detector_common.hpp" + #include "tasks/auto_aim/armor_detect/armor_infer.hpp" + #include "wust_vl/ml_net/ncnn/ncnn_net.hpp" +namespace wust_vision { +namespace auto_aim { + + struct ArmorDetectorNCNN::Impl { + public: + Impl(const YAML::Node& config, bool use_armor_detect_common) { + if (use_armor_detect_common) { + armor_detect_common_ = ArmorDetectorCommon::create(config); + } + std::string model_type = config["ncnn"]["model_type"].as(); + auto model = armor_infer::modeFromString(model_type); + float conf_threshold = config["ncnn"]["conf_threshold"].as(); + int top_k = config["ncnn"]["top_k"].as(); + float nms_threshold = config["ncnn"]["nms_threshold"].as(); + armor_infer_ = std::make_unique( + model, + conf_threshold, + nms_threshold, + top_k + ); + std::string model_path_param = + utils::expandEnv(config["ncnn"]["model_path_param"].as()); + std::string model_path_bin = + utils::expandEnv(config["ncnn"]["model_path_bin"].as()); + bool use_gpu = config["ncnn"]["use_gpu"].as(); + int cpu_threads = config["ncnn"]["cpu_threads"].as(); + bool use_lightmode = config["ncnn"]["use_lightmode"].as(); + auto input_name = config["ncnn"]["input_name"].as(); + auto output_name = config["ncnn"]["output_name"].as(); + int device_id = config["ncnn"]["device_id"].as(); + wust_vl::ml_net::NCNNNet::Params params; + params.model_path_param = model_path_param; + params.model_path_bin = model_path_bin; + params.input_name = input_name; + params.output_name = output_name; + params.use_vulkan = use_gpu; + params.device_id = device_id; + params.use_light_mode = use_lightmode; + params.cpu_threads = cpu_threads; + ncnn_net_ = std::make_unique(); + ncnn_net_->init(params); + } + static Ptr create(const YAML::Node& config, bool use_armor_detect_common) { + return std::make_unique(config, use_armor_detect_common); + } + ~Impl() { + armor_detect_common_.reset(); + ncnn_net_.reset(); + } + cv::Mat ncnnMatToCvMat(const ncnn::Mat& m) { + cv::Mat img(m.h, m.w, CV_8UC3); + m.to_pixels(img.data, ncnn::Mat::PIXEL_RGB2BGR); + + return img; + } + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + static ncnn::Mat letterbox_to_ncnn( + const cv::Mat& img, + Eigen::Matrix3f& transform_matrix, + int out_w, + int out_h, + float norm, + bool swap_rb = true + ) { + const int img_w = img.cols; + const int img_h = img.rows; + + float scale = std::min(out_w * 1.0f / img_w, out_h * 1.0f / img_h); + int resize_w = static_cast(round(img_w * scale)); + int resize_h = static_cast(round(img_h * scale)); + + int pad_w = out_w - resize_w; + int pad_h = out_h - resize_h; + int pad_left = static_cast(round(pad_w / 2.0f - 0.1f)); + int pad_top = static_cast(round(pad_h / 2.0f - 0.1f)); + + transform_matrix << 1.0f / scale, 0, -pad_left / scale, 0, 1.0f / scale, + -pad_top / scale, 0, 0, 1; + ncnn::Mat out; + if (swap_rb) { + out = ncnn::Mat::from_pixels_resize( + img.data, + ncnn::Mat::PIXEL_BGR2RGB, + img_w, + img_h, + resize_w, + resize_h + ); + } else { + out = ncnn::Mat::from_pixels_resize( + img.data, + ncnn::Mat::PIXEL_RGB, + img_w, + img_h, + resize_w, + resize_h + ); + } + + int pad_right = out_w - resize_w - pad_left; + int pad_bottom = out_h - resize_h - pad_top; + + ncnn::Mat padded; + ncnn::copy_make_border( + out, + padded, + pad_top, + pad_bottom, + pad_left, + pad_right, + ncnn::BORDER_CONSTANT, + 114.f + ); + + std::array mean_vals; + std::array norm_vals; + + mean_vals = { 0.f, 0.f, 0.f }; + norm_vals = { norm, norm, norm }; + + padded.substract_mean_normalize(mean_vals.data(), norm_vals.data()); + + return padded; + } + void + processCallback(const CommonFrame& frame, const std::optional& target_number) { + // Eigen::Matrix3f transform_matrix; + // cv::Mat resized_img = letterbox(frame.src_img, transform_matrix); + // ncnn::Mat in = + // ncnn::Mat::from_pixels(resized_img.data, ncnn::Mat::PIXEL_BGR2RGB, INPUT_W, INPUT_H); + Eigen::Matrix3f transform_matrix; + auto roi = frame.img_frame.src_img(frame.expanded); + const bool swap_rb = armor_infer_->inputRGB() + != (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB); + const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f; + ncnn::Mat in = letterbox_to_ncnn( + roi.clone(), + transform_matrix, + armor_infer_->inputW(), + armor_infer_->inputH(), + scale, + swap_rb + ); + cv::Mat resized_img = ncnnMatToCvMat(in); + + auto out = ncnn_net_->infer(in); + + cv::Mat output_buffer(out.h, out.w, CV_32F, out.data); + + // Parse YOLO output + auto objs_result = armor_infer_->postProcess(output_buffer); + std::vector armors; + if (armor_detect_common_) { + armors = armor_detect_common_->detectNet( + resized_img, + objs_result, + transform_matrix, + frame.detect_color, + target_number + ); + // Call callback function + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } else { + for (auto obj: objs_result) { + auto detect_color = frame.detect_color; + if (detect_color == 0 && obj.color == ArmorColor::BLUE) { + continue; + } else if (detect_color == 1 && obj.color == ArmorColor::RED) { + continue; + } + obj.transform(transform_matrix); + armors.push_back(obj); + } + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } + + return; + } + void pushInput(CommonFrame& frame, const std::optional& target_number) { + frame.id = current_id_++; + processCallback(frame, target_number); + } + + private: + DetectorCallback infer_callback_; + std::unique_ptr armor_detect_common_; + std::unique_ptr armor_infer_; + int current_id_ = 0; + std::unique_ptr ncnn_net_; + }; + ArmorDetectorNCNN::ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common) { + _impl = std::make_unique(config, use_armor_detect_common); + } + ArmorDetectorNCNN::~ArmorDetectorNCNN() { + _impl.reset(); + } + void ArmorDetectorNCNN::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } + void ArmorDetectorNCNN::pushInput( + CommonFrame& frame, + const std::optional& target_number + ) { + _impl->pushInput(frame, target_number); + } +} // namespace auto_aim +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp new file mode 100644 index 0000000..1cf1a07 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp @@ -0,0 +1,36 @@ +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorNCNN: public ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + explicit ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common); + static Ptr create(const YAML::Node& config, bool use_armor_detect_common) { + return std::make_unique(config, use_armor_detect_common); + } + ~ArmorDetectorNCNN(); + void setCallback(DetectorCallback callback) override; + void + pushInput(CommonFrame& frame, const std::optional& target_number) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/base.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/base.hpp new file mode 100644 index 0000000..ce4e428 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/base.hpp @@ -0,0 +1,12 @@ +#pragma once +#include "tasks/auto_aim/type.hpp" +namespace wust_vision { +namespace auto_aim { + class NumberClassifierBase { + public: + virtual ~NumberClassifierBase() = default; + virtual bool classifyNumber(ArmorObject& armor) = 0; + virtual void initNumberClassifier() = 0; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/factory.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/factory.hpp new file mode 100644 index 0000000..1bf4e82 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/factory.hpp @@ -0,0 +1,34 @@ +#pragma once +#include "number_classifier.hpp" +#ifdef USE_TRT + #include "number_classifier_trt.hpp" +#endif +namespace wust_vision { +namespace auto_aim { + class NumberClassifierFactory { + public: + static std::unique_ptr createNumberClassifier( + const std::string& backend, + const std::string& classify_model_path, + const std::string& classify_label_path + ) { +#if defined(USE_TRT) + if (backend == "tensorrt") { + return std::make_unique( + classify_model_path, + classify_label_path + ); + } +#endif + + if (backend == "opencv") { + return std::make_unique(classify_model_path, classify_label_path); + } + + throw std::runtime_error( + "Unsupported number classifier backend (or not compiled): " + backend + ); + } + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.cpp new file mode 100644 index 0000000..6364199 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.cpp @@ -0,0 +1,116 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "number_classifier.hpp" +#include +namespace wust_vision { +namespace auto_aim { + NumberClassifier::NumberClassifier( + const std::string& classify_model_path, + const std::string& classify_label_path + ): + classify_model_path_(classify_model_path), + classify_label_path_(classify_label_path) { + initNumberClassifier(); + } + void NumberClassifier::initNumberClassifier() { + const std::string model_path = classify_model_path_; + std::unique_ptr number_net_ = + std::make_unique(cv::dnn::readNetFromONNX(model_path)); + + if (number_net_->empty()) { + WUST_ERROR("number_classifier") + << "Failed to load number classifier model from " << model_path; + std::exit(EXIT_FAILURE); + } else { + WUST_INFO("number_classifier") + << "Successfully loaded number classifier model from " << model_path; + } + + const std::string label_path = classify_label_path_; + std::ifstream label_file(label_path); + std::string line; + + class_names_.clear(); + + while (std::getline(label_file, line)) { + class_names_.push_back(line); + } + + if (class_names_.empty()) { + WUST_ERROR("number_classifier") << "Failed to load labels from " << label_path; + std::exit(EXIT_FAILURE); + } else { + WUST_INFO("number_classifier") + << "Successfully loaded " << class_names_.size() << " labels from " << label_path; + } + number_net_.reset(); + } + bool NumberClassifier::classifyNumber(ArmorObject& armor) { + static thread_local std::unique_ptr thread_net; + if (armor.number_img.empty()) { + return false; + } + + if (!thread_net) { + thread_net = + std::make_unique(cv::dnn::readNetFromONNX(classify_model_path_)); + WUST_DEBUG("number_classifier") << "Loaded number classifier model for this thread"; + if (thread_net->empty()) { + WUST_ERROR("number_classifier") + << "Failed to load thread-local number classifier model."; + return false; + } + } + + const cv::Mat image = armor.number_img; + + cv::Mat blob; + cv::dnn::blobFromImage(image, blob, 1.0 / 255.0); + + thread_net->setInput(blob); + cv::Mat outputs = thread_net->forward(); + double max_val; + cv::minMaxLoc(outputs, nullptr, &max_val); + + cv::Mat prob; + cv::exp(outputs - max_val, prob); + prob /= cv::sum(prob)[0]; + + double confidence; + cv::Point class_id; + cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id); + + const int label_id = class_id.x; + const double raw_conf = armor.confidence; + armor.confidence = confidence; + + static const std::map label_to_armor_number = { + { 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 }, + { 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST }, + { 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE } + }; + + if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) { + armor.number = label_to_armor_number.at(label_id); + + return true; + } else { + armor.confidence = raw_conf; + return false; + } + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp new file mode 100644 index 0000000..7415194 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp @@ -0,0 +1,35 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "base.hpp" +#include "tasks/auto_aim/type.hpp" +namespace wust_vision { +namespace auto_aim { + class NumberClassifier: public NumberClassifierBase { + public: + NumberClassifier( + const std::string& classify_model_path, + const std::string& classify_label_path + ); + void initNumberClassifier() override; + bool classifyNumber(ArmorObject& armor) override; + + private: + std::vector class_names_; + std::string classify_model_path_; + std::string classify_label_path_; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.cpp new file mode 100644 index 0000000..9a2ae42 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.cpp @@ -0,0 +1,118 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifdef USE_TRT + #include "number_classifier_trt.hpp" + #include +namespace wust_vision { +namespace auto_aim { + NumberClassifierTRT::NumberClassifierTRT( + const std::string& classify_model_path, + const std::string& classify_label_path + ): + classify_model_path_(classify_model_path), + classify_label_path_(classify_label_path) { + initNumberClassifier(); + } + void NumberClassifierTRT::initNumberClassifier() { + const std::string model_path = classify_model_path_; + trt_net_ = std::make_unique(); + wust_vl::ml_net::TensorRTNet::Params trt_params; + trt_params.model_path = model_path; + trt_params.input_dims = nvinfer1::Dims4 { 1, 1, 20, 28 }; + trt_net_->init(trt_params); + const auto input_output_dims = trt_net_->getInputOutputDims(); + input_dims_ = std::get<0>(input_output_dims); + output_dims_ = std::get<1>(input_output_dims); + + const std::string label_path = classify_label_path_; + std::ifstream label_file(label_path); + std::string line; + + class_names_.clear(); + + while (std::getline(label_file, line)) { + class_names_.push_back(line); + } + + if (class_names_.empty()) { + WUST_ERROR("number_classifier_trt") << "Failed to load labels from " << label_path; + std::exit(EXIT_FAILURE); + } else { + WUST_INFO("number_classifier_trt") + << "Successfully loaded " << class_names_.size() << " labels from " << label_path; + } + } + bool NumberClassifierTRT::classifyNumber(ArmorObject& armor) { + static thread_local std::unique_ptr ctx; + if (armor.number_img.empty()) { + return false; + } + + if (!ctx) { + auto c = trt_net_->getAContext(); + ctx = std::unique_ptr(c); + WUST_DEBUG("number_classifier_trt") << "Loaded number classifier model for this thread"; + if (!ctx) { + WUST_ERROR("number_classifier_trt") + << "Failed to load thread-local number classifier model."; + return false; + } + } + + const cv::Mat image = armor.number_img; + cv::Mat blob; + cv::dnn::blobFromImage(image, blob, 1.0 / 255.0, cv::Size(28, 20)); + trt_net_->input2Device(blob.ptr()); + void* input_tensor_ptr = trt_net_->getInputTensorPtr(); + trt_net_->infer(input_tensor_ptr, ctx.get()); + + const float* out = static_cast(trt_net_->output2Host()); + + cv::Mat outputs(1, 9, CV_32F); + std::memcpy(outputs.data, out, 9 * sizeof(float)); + + double max_val; + cv::minMaxLoc(outputs, nullptr, &max_val); + + cv::Mat prob; + cv::exp(outputs - max_val, prob); + prob /= cv::sum(prob)[0]; + + double confidence; + cv::Point class_id; + cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id); + + const int label_id = class_id.x; + const double raw_conf = armor.confidence; + armor.confidence = confidence; + + static const std::map label_to_armor_number = { + { 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 }, + { 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST }, + { 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE } + }; + + if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) { + armor.number = label_to_armor_number.at(label_id); + + return true; + } else { + armor.confidence = raw_conf; + return false; + } + } +} // namespace auto_aim +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.hpp new file mode 100644 index 0000000..d845075 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/number_classifier/number_classifier_trt.hpp @@ -0,0 +1,39 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "base.hpp" +#include "tasks/auto_aim/type.hpp" +#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp" +namespace wust_vision { +namespace auto_aim { + class NumberClassifierTRT: public NumberClassifierBase { + public: + NumberClassifierTRT( + const std::string& classify_model_path, + const std::string& classify_label_path + ); + void initNumberClassifier() override; + bool classifyNumber(ArmorObject& armor) override; + + private: + std::vector class_names_; + std::string classify_model_path_; + std::string classify_label_path_; + std::unique_ptr trt_net_; + nvinfer1::Dims input_dims_; + nvinfer1::Dims output_dims_; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.cpp new file mode 100644 index 0000000..f0ad698 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.cpp @@ -0,0 +1,141 @@ +#ifdef USE_ORT + + #include "armor_detector_onnxruntime.hpp" + #include "tasks/auto_aim/armor_detect/armor_detector_common.hpp" + #include "tasks/auto_aim/armor_detect/armor_infer.hpp" + #include "tasks/utils/utils.hpp" + #include "wust_vl/ml_net/onnxruntime/onnxruntime_net.hpp" +namespace wust_vision { +namespace auto_aim { + struct ArmorDetectorOnnxRuntime::Impl { + public: + Impl(const YAML::Node& config, bool use_armor_detect_common) { + if (use_armor_detect_common) { + armor_detect_common_ = std::make_unique(config); + } + std::string model_type = config["onnxruntime"]["model_type"].as(); + auto model = armor_infer::modeFromString(model_type); + float conf_threshold = config["onnxruntime"]["conf_threshold"].as(); + int top_k = config["onnxruntime"]["top_k"].as(); + float nms_threshold = config["onnxruntime"]["nms_threshold"].as(); + armor_infer_ = std::make_unique( + model, + conf_threshold, + nms_threshold, + top_k + ); + std::string provider = config["onnxruntime"]["provider"].as("CPU"); + provider_ = wust_vl::ml_net::string2OrtProvider(provider); + onnxruntime_net_ = std::make_unique(); + wust_vl::ml_net::OnnxRuntimeNet::Params params; + std::string model_path = + utils::expandEnv(config["onnxruntime"]["model_path"].as()); + params.model_path = model_path; + params.provider = provider_; + onnxruntime_net_->init(params); + } + + ~Impl() { + onnxruntime_net_.reset(); + armor_detect_common_.reset(); + } + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + void + processCallback(const CommonFrame& frame, const std::optional& target_number) { + Eigen::Matrix3f transform_matrix; + auto roi = frame.img_frame.src_img(frame.expanded); + cv::Mat resized_img = utils::letterbox( + roi, + transform_matrix, + armor_infer_->inputW(), + armor_infer_->inputH() + ); + const bool swap_rb = armor_infer_->inputRGB() + != (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB); + const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f; + cv::Mat blob = cv::dnn::blobFromImage( + resized_img, + scale, + cv::Size(armor_infer_->inputW(), armor_infer_->inputH()), + cv::Scalar(0, 0, 0), + swap_rb + ); + + auto output_data = onnxruntime_net_->infer(blob.ptr(), blob.total()); + + auto output_shape = onnxruntime_net_->getOutputShape(); + int rows = static_cast(output_shape[1]); + int cols = static_cast(output_shape[2]); + cv::Mat output_buffer(rows, cols, CV_32F, output_data); + + // Parsed variable + std::vector objs_result; + objs_result = armor_infer_->postProcess(output_buffer); + std::vector armors; + if (armor_detect_common_) { + armors = armor_detect_common_->detectNet( + resized_img, + objs_result, + transform_matrix, + frame.detect_color, + target_number + ); + // Call callback function + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } else { + for (auto obj: objs_result) { + auto detect_color = frame.detect_color; + if (detect_color == 0 && obj.color == ArmorColor::BLUE) { + continue; + } else if (detect_color == 1 && obj.color == ArmorColor::RED) { + continue; + } + obj.transform(transform_matrix); + armors.push_back(obj); + } + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } + + return; + } + void pushInput(CommonFrame& frame, const std::optional& target_number) { + frame.id = current_id_++; + processCallback(frame, target_number); + } + + wust_vl::ml_net::OrtProvider provider_ = wust_vl::ml_net::OrtProvider::CPU; + DetectorCallback infer_callback_; + std::unique_ptr armor_detect_common_; + std::unique_ptr armor_infer_; + int current_id_ = 0; + std::unique_ptr onnxruntime_net_; + }; + ArmorDetectorOnnxRuntime::ArmorDetectorOnnxRuntime( + const YAML::Node& config, + bool use_armor_detect_common + ) { + _impl = std::make_unique(config, use_armor_detect_common); + } + ArmorDetectorOnnxRuntime::~ArmorDetectorOnnxRuntime() { + _impl.reset(); + } + void ArmorDetectorOnnxRuntime::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } + void ArmorDetectorOnnxRuntime::pushInput( + CommonFrame& frame, + const std::optional& target_number + ) { + _impl->pushInput(frame, target_number); + } +} // namespace auto_aim +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp new file mode 100644 index 0000000..84bb05f --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp @@ -0,0 +1,36 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorOnnxRuntime: public ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + + explicit ArmorDetectorOnnxRuntime(const YAML::Node& config, bool use_armor_detect_common); + static Ptr create(const YAML::Node& config, bool use_armor_detect_common) { + return std::make_unique(config, use_armor_detect_common); + } + ~ArmorDetectorOnnxRuntime(); + void + pushInput(CommonFrame& frame, const std::optional& target_number) override; + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.cpp new file mode 100644 index 0000000..b25c01e --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.cpp @@ -0,0 +1,382 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// +// Additional modifications and features by Chengfu Zou, Labor. Licensed under +// Apache License 2.0. +// +// Copyright (C) FYT Vision Group. All rights reserved. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp" +#include "tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp" +#include "tasks/utils/utils.hpp" +namespace wust_vision { +namespace auto_aim { + struct ArmorDetectorOpenCV::Impl { + public: + Impl(const YAML::Node& config) { + auto classify_model_path = + utils::expandEnv(config["classify"]["model_path"].as()); + auto classify_label_path = + utils::expandEnv(config["classify"]["label_path"].as()); + double classify_threshold = config["classify"]["threshold"].as(); + number_classifier_ = + std::make_unique(classify_model_path, classify_label_path); + light_params_.load(config["light"]); + armor_params_.load(config["armor"]); + } + + std::vector detect( + const cv::Mat& input, + int detect_color, + const std::optional& target_number + ) noexcept { + if (input.empty()) + return {}; + + std::vector lights_; + + cv::Mat binary_img, gray_img; + std::tie(binary_img, gray_img) = preprocessImage(input); + lights_ = findLights(input, binary_img); + std::vector armors = matchLights(lights_, detect_color); + std::vector valid_armors; + + for (auto& armor: armors) { + try { + armor.number_img = extractNumber(gray_img, armor); + if (armor.number_img.empty()) + continue; + + if (!number_classifier_->classifyNumber(armor)) + continue; + if (target_number.has_value()) { + if (!isSameTarget(target_number.value(), armor.number)) { + continue; + } + } + if (armor.confidence < classifier_threshold_) + continue; + + if (armor.number != ArmorNumber::NO1 && armor.number != ArmorNumber::BASE + && armor.type == ArmorType::LARGE) + { + continue; + } + + valid_armors.push_back(armor); + + } catch (const std::exception& e) { + std::cerr << "[detect] Exception: " << e.what() << std::endl; + } + } + return valid_armors; + } + + std::tuple preprocessImage(const cv::Mat& img) noexcept { + cv::Mat gray_img; + cv::Mat binary_img; + + if (img.empty()) { + return { binary_img, gray_img }; // 空图直接返回空 + } + + if (img.channels() == 3) { + cv::cvtColor(img, gray_img, cv::COLOR_RGB2GRAY); + } else if (img.channels() == 1) { + cv::cvtColor(img, gray_img, cv::COLOR_BayerRG2GRAY); + } else { + return { binary_img, gray_img }; + } + + cv::threshold(gray_img, binary_img, light_params_.binary_thres, 255, cv::THRESH_BINARY); + + return { binary_img, gray_img }; + } + + std::vector findLights(const cv::Mat& img, const cv::Mat& binary_img) noexcept { + std::vector> contours; + contours.reserve(64); + cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + cv::Mat color_img; + if (img.channels() == 3) { + color_img = img; + } else if (img.channels() == 1) { + cv::cvtColor(img, color_img, cv::COLOR_BayerRG2BGR); + } else { + return {}; + } + std::vector lights; + lights.reserve(contours.size()); + + for (const auto& contour: contours) { + const int n = static_cast(contour.size()); + if (n < 6) + continue; + Light light(contour); + if (!isLight(light)) + continue; + int sum_r = 0; + int sum_b = 0; + for (const auto& pt: contour) { + const cv::Vec3b& pix = color_img.at(pt); + sum_r += pix[0]; + sum_b += pix[2]; + } + + const int avg_diff = std::abs(sum_r - sum_b) / n; + if (avg_diff <= light_params_.color_diff_thresh) + continue; + + light.color = (sum_r > sum_b) ? 0 : 1; + lights.emplace_back(std::move(light)); + } + + std::sort(lights.begin(), lights.end(), [](const Light& a, const Light& b) { + return a.center.x < b.center.x; + }); + + return lights; + } + + bool isLight(const Light& light) noexcept { + // width / length 比例 + const float ratio = light.width / light.length; + + if (ratio <= light_params_.min_ratio || ratio >= light_params_.max_ratio) + return false; + + if (light.tilt_angle >= light_params_.max_angle) + return false; + + return true; + } + + std::vector + matchLights(const std::vector& lights, int detect_color) noexcept { + const int n = static_cast(lights.size()); + std::vector armors; + armors.reserve(n); + + for (int i = 0; i < n; ++i) { + const Light& l1 = lights[i]; + if (l1.color != detect_color) + continue; + + const float max_dx = l1.length * armor_params_.max_large_center_distance; + + for (int j = i + 1; j < n; ++j) { + const Light& l2 = lights[j]; + + if (l2.color != detect_color) + continue; + + const float dx = l2.center.x - l1.center.x; + if (dx > max_dx) + break; + + ArmorType type = isArmor(l1, l2); + if (type == ArmorType::INVALID) + continue; + + if (containLight(i, j, lights)) + continue; + + ArmorObject armor(l1, l2); + armor.type = type; + armor.color = (detect_color == 0) ? ArmorColor::RED : ArmorColor::BLUE; + + armors.emplace_back(std::move(armor)); + } + } + + return armors; + } + + // Check if there is another light in the boundingRect formed by the 2 lights + bool containLight(const int i, const int j, const std::vector& lights) noexcept { + const Light& l1 = lights[i]; + const Light& l2 = lights[j]; + + float min_x = std::min({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x }); + float max_x = std::max({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x }); + float min_y = std::min({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y }); + float max_y = std::max({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y }); + + const float avg_len = 0.5f * (l1.length + l2.length); + const float avg_wid = 0.5f * (l1.width + l2.width); + + for (int k = i + 1; k < j; ++k) { + const Light& t = lights[k]; + + if (t.width > 2.0f * avg_wid) + continue; + if (t.length < 0.5f * avg_len) + continue; + + const cv::Point2f& c = t.center; + if (c.x >= min_x && c.x <= max_x && c.y >= min_y && c.y <= max_y) { + return true; + } + } + return false; + } + + ArmorType isArmor(const Light& l1, const Light& l2) noexcept { + const float len1 = l1.length; + const float len2 = l2.length; + if (len1 <= 1e-3f || len2 <= 1e-3f) + return ArmorType::INVALID; + + const float min_len = (len1 < len2) ? len1 : len2; + const float max_len = (len1 < len2) ? len2 : len1; + if (min_len / max_len <= armor_params_.min_light_ratio) + return ArmorType::INVALID; + const cv::Point2f d = l1.center - l2.center; + const float dist2 = d.dot(d); + + const float avg_len = 0.5f * (len1 + len2); + + const float min_small = armor_params_.min_small_center_distance * avg_len; + const float max_small = armor_params_.max_small_center_distance * avg_len; + const float min_large = armor_params_.min_large_center_distance * avg_len; + const float max_large = armor_params_.max_large_center_distance * avg_len; + + const float min_small2 = min_small * min_small; + const float max_small2 = max_small * max_small; + const float min_large2 = min_large * min_large; + const float max_large2 = max_large * max_large; + + const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2); + const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2); + + if (!(small_ok || large_ok)) + return ArmorType::INVALID; + + static const float tan_max_angle = std::tan(armor_params_.max_angle * CV_PI / 180.0f); + + if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle) + return ArmorType::INVALID; + + float delta_angle = std::fabs(l1.angle - l2.angle); + if (delta_angle > 90.0f) + delta_angle = 180.0f - delta_angle; + + if (delta_angle >= light_params_.max_angle_diff) + return ArmorType::INVALID; + + return large_ok ? ArmorType::LARGE : ArmorType::SMALL; + } + + cv::Mat extractNumber(const cv::Mat& src, const ArmorObject& armor) const noexcept { + constexpr int light_length = 12; + constexpr int warp_height = 28; + constexpr int small_armor_width = 32; + constexpr int large_armor_width = 54; + const cv::Size roi_size(20, 28); + + cv::Point2f src_pts[4] = { armor.lights[0].bottom, + armor.lights[0].top, + armor.lights[1].top, + armor.lights[1].bottom }; + + const int warp_width = + (armor.type == ArmorType::SMALL) ? small_armor_width : large_armor_width; + + const int top_y = (warp_height - light_length) / 2 - 1; + const int bottom_y = top_y + light_length; + + cv::Point2f dst_pts[4] = { + { 0.f, static_cast(bottom_y) }, + { 0.f, static_cast(top_y) }, + { static_cast(warp_width - 1), static_cast(top_y) }, + { static_cast(warp_width - 1), static_cast(bottom_y) } + }; + + cv::Mat warp_mat = cv::getPerspectiveTransform(src_pts, dst_pts); + + cv::Mat warped; + cv::warpPerspective( + src, + warped, + warp_mat, + cv::Size(warp_width, warp_height), + cv::INTER_LINEAR, + cv::BORDER_CONSTANT, + 0 + ); + + const int roi_x = (warp_width - roi_size.width) >> 1; + if (roi_x < 0 || roi_x + roi_size.width > warp_width) + return cv::Mat(); + + cv::Mat number = warped(cv::Rect(roi_x, 0, roi_size.width, roi_size.height)); + + cv::threshold(number, number, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU); + + return number; + } + + void setCallback(DetectorCallback callback) { + this->infer_callback_ = callback; + } + void toPts(ArmorObject& armor) { + if (armor.lights.size() != 2) { + armor.is_ok = false; + return; + } + armor.pts[0] = armor.lights[0].top; + armor.pts[1] = armor.lights[0].bottom; + armor.pts[2] = armor.lights[1].bottom; + armor.pts[3] = armor.lights[1].top; + } + + void pushInput(CommonFrame& frame, const std::optional& target_number) { + frame.id = current_id_++; + std::vector objs_result; + auto roi = frame.img_frame.src_img(frame.expanded); + + objs_result = detect(roi, frame.detect_color, target_number); + + if (this->infer_callback_) { + this->infer_callback_(objs_result, frame); + return; + } + return; + } + + LightParams light_params_; + ArmorParams armor_params_; + double classifier_threshold_ = 0.5; + std::unique_ptr number_classifier_; + DetectorCallback infer_callback_; + int current_id_ = 0; + }; + ArmorDetectorOpenCV::ArmorDetectorOpenCV(const YAML::Node& config) { + _impl = std::make_unique(config); + } + ArmorDetectorOpenCV::~ArmorDetectorOpenCV() { + _impl.reset(); + } + void ArmorDetectorOpenCV::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } + void ArmorDetectorOpenCV::pushInput( + CommonFrame& frame, + const std::optional& target_number + ) { + _impl->pushInput(frame, target_number); + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp new file mode 100644 index 0000000..9f690e7 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp @@ -0,0 +1,42 @@ +// Copyright Chen Jun 2023. Licensed under the MIT License. +// +// Additional modifications and features by Chengfu Zou, Labor. Licensed under +// Apache License 2.0. +// +// Copyright (C) FYT Vision Group. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +#include "tasks/type_common.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorOpenCV: public ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + explicit ArmorDetectorOpenCV(const YAML::Node& config); + static Ptr create(const YAML::Node& config) { + return std::make_unique(config); + } + ~ArmorDetectorOpenCV(); + void + pushInput(CommonFrame& frame, const std::optional& target_number) override; + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.cpp new file mode 100644 index 0000000..47e1b8f --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.cpp @@ -0,0 +1,206 @@ +// Copyright 2025 Zikang Xie +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifdef USE_OPENVINO + #include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp" + #include "tasks/auto_aim/armor_detect/armor_detector_common.hpp" + #include "tasks/auto_aim/armor_detect/armor_infer.hpp" + #include "tasks/utils/utils.hpp" + #include "wust_vl/ml_net/openvino/openvino_net.hpp" +namespace wust_vision { +namespace auto_aim { + struct ArmorDetectorOpenVino::Impl { + public: + Impl(const YAML::Node& config, bool use_armor_detect_common) { + if (use_armor_detect_common) { + armor_detect_common_ = std::make_unique(config); + } + std::string model_type = config["openvino"]["model_type"].as(); + auto model = armor_infer::modeFromString(model_type); + float conf_threshold = config["openvino"]["conf_threshold"].as(); + int top_k = config["openvino"]["top_k"].as(); + float nms_threshold = config["openvino"]["nms_threshold"].as(); + bool use_throughputmode = config["openvino"]["use_throughputmode"].as(); + armor_infer_ = std::make_unique( + model, + conf_threshold, + nms_threshold, + top_k + ); + std::string model_path = + utils::expandEnv(config["openvino"]["model_path"].as()); + + auto device_name = config["openvino"]["device_name"].as(); + ov_params_.model_path = model_path; + ov_params_.device_name = device_name; + ov_params_.mode = use_throughputmode ? ov::hint::PerformanceMode::THROUGHPUT + : ov::hint::PerformanceMode::LATENCY; + initOpenVINO(); + } + void + initOpenVINO(wust_vl::video::PixelFormat pixel_format = wust_vl::video::PixelFormat::BGR) { + openvino_net_.reset(); + openvino_net_ = std::make_unique(); + const auto ppp_init_fun = [this, pixel_format](ov::preprocess::PrePostProcessor& ppp) { + if (pixel_format == wust_vl::video::PixelFormat::RGB) { + ppp.input() + .tensor() + .set_element_type(ov::element::u8) + .set_layout("NHWC") + .set_color_format(ov::preprocess::ColorFormat::RGB); + } else { + ppp.input() + .tensor() + .set_element_type(ov::element::u8) + .set_layout("NHWC") + .set_color_format(ov::preprocess::ColorFormat::BGR); + } + pixel_format_ = pixel_format; + const bool RGB = armor_infer_->inputRGB(); + const float scale = armor_infer_->useNorm() ? 255.0f : 1.0f; + if (RGB) { + ppp.input() + .preprocess() + .convert_element_type(ov::element::f32) + .convert_color(ov::preprocess::ColorFormat::RGB) + .scale(scale); + } else { + ppp.input() + .preprocess() + .convert_element_type(ov::element::f32) + .convert_color(ov::preprocess::ColorFormat::BGR) + .scale(scale); + } + + ppp.input().model().set_layout("NCHW"); + + ppp.output().tensor().set_element_type(ov::element::f32); + }; + + openvino_net_->init(ov_params_, ppp_init_fun); + } + + ~Impl() { + openvino_net_.reset(); + armor_detect_common_.reset(); + } + + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + bool processCallback( + const CommonFrame& frame, + const std::optional& target_number + ) const { + const auto start = std::chrono::steady_clock::now(); + Eigen::Matrix3f transform_matrix; + const auto roi = frame.img_frame.src_img(frame.expanded); + cv::Mat resized_img = utils::letterbox( + roi, + transform_matrix, + armor_infer_->inputW(), + armor_infer_->inputH() + ); + const auto input_info = openvino_net_->getInputInfo(); + const auto input_tensor = + ov::Tensor(input_info.first, input_info.second, resized_img.data); + const auto output = openvino_net_->infer_thread_local(input_tensor); + + // Process output data + const auto output_shape = output.get_shape(); + + const float* ptr = output.data(); + cv::Mat + output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast(ptr)); + + // Parsed variable + auto objs_result = armor_infer_->postProcess(output_buffer); + + std::vector armors; + if (armor_detect_common_) { + armors = armor_detect_common_->detectNet( + resized_img, + objs_result, + transform_matrix, + frame.detect_color, + target_number + ); + // Call callback function + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return true; + } + } else { + for (auto obj: objs_result) { + auto detect_color = frame.detect_color; + if (detect_color == 0 && obj.color == ArmorColor::BLUE) { + continue; + } else if (detect_color == 1 && obj.color == ArmorColor::RED) { + continue; + } + obj.transform(transform_matrix); + armors.push_back(obj); + } + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return true; + } + } + + return false; + } + void pushInput(CommonFrame& frame, const std::optional& target_number) { + if (resetting_) { + return; + } + frame.id = current_id_++; + if (frame.img_frame.pixel_format != pixel_format_) { + resetting_ = true; + initOpenVINO(frame.img_frame.pixel_format); + resetting_ = false; + } + processCallback(frame, target_number); + } + + private: + wust_vl::video::PixelFormat pixel_format_ = wust_vl::video::PixelFormat::BGR; + std::unique_ptr openvino_net_; + DetectorCallback infer_callback_; + std::unique_ptr armor_detect_common_; + std::unique_ptr armor_infer_; + int current_id_ = 0; + wust_vl::ml_net::OpenvinoNet::Params ov_params_; + bool resetting_ = false; + }; + ArmorDetectorOpenVino::ArmorDetectorOpenVino( + const YAML::Node& config, + bool use_armor_detect_common + ) { + _impl = std::make_unique(config, use_armor_detect_common); + } + ArmorDetectorOpenVino::~ArmorDetectorOpenVino() { + _impl.reset(); + } + void ArmorDetectorOpenVino::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } + void ArmorDetectorOpenVino::pushInput( + CommonFrame& frame, + const std::optional& target_number + ) { + _impl->pushInput(frame, target_number); + } +} // namespace auto_aim +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp new file mode 100644 index 0000000..5a1564a --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp @@ -0,0 +1,38 @@ +// Copyright 2023 Yunlong Feng +// Copyright 2025 Lihan Chen +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorOpenVino: public ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + explicit ArmorDetectorOpenVino(const YAML::Node& config, bool use_armor_detect_common); + static Ptr create(const YAML::Node& config, bool use_armor_detect_common) { + return std::make_unique(config, use_armor_detect_common); + } + ~ArmorDetectorOpenVino(); + void + pushInput(CommonFrame& frame, const std::optional& target_number) override; + + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.cpp b/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.cpp new file mode 100644 index 0000000..d743cf8 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.cpp @@ -0,0 +1,320 @@ +// Copyright 2025 Zikang Xie +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifdef USE_TRT + #include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp" + #include "cuda_infer/armor_infer.hpp" + #include "tasks/auto_aim/armor_detect/armor_detector_common.hpp" + #include "tasks/auto_aim/armor_detect/armor_infer.hpp" + #include "tasks/utils/utils.hpp" + #include "wust_vl/common/concurrency/adaptive_resource_pool.hpp" + #include "wust_vl/common/utils/logger.hpp" + #include "wust_vl/common/utils/timer.hpp" + #include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp" +namespace wust_vision { +namespace auto_aim { + static constexpr int MAX_SRC_IMG_W = 1920; + static constexpr int MAX_SRC_IMG_H = 1440; + struct ArmorDetectorTrt::Impl { + public: + struct Infer { + std::unique_ptr context; + std::unique_ptr cuda_infer; + }; + + Impl(const YAML::Node& config, bool use_armor_detect_common) { + if (use_armor_detect_common) { + armor_detect_common_ = std::make_unique(config); + } + const double conf_threshold = config["tensorrt"]["conf_threshold"].as(); + const double nms_threshold = config["tensorrt"]["nms_threshold"].as(); + const int top_k = config["tensorrt"]["top_k"].as(); + const int max_infer_running = config["tensorrt"]["max_infer_running"].as(); + const double min_free_mem_ratio = config["tensorrt"]["min_free_mem_ratio"].as(); + use_cuda_pre_ = config["tensorrt"]["use_cuda_pre"].as(); + log_time_ = config["tensorrt"]["log_time"].as(); + const std::string model_type = config["tensorrt"]["model_type"].as(); + const std::string model_path = + utils::expandEnv(config["tensorrt"]["model_path"].as()); + const int device_id = config["tensorrt"]["device_id"].as(); + cudaSetDevice(device_id); + const auto model = armor_infer::modeFromString(model_type); + armor_infer_ = std::make_unique( + model, + conf_threshold, + nms_threshold, + top_k + ); + trt_net_ = std::make_unique(); + wust_vl::ml_net::TensorRTNet::Params trt_params; + trt_params.model_path = model_path; + trt_params.input_dims = + nvinfer1::Dims4 { 1, 3, armor_infer_->inputH(), armor_infer_->inputW() }; + trt_net_->init(trt_params); + const auto input_output_dims = trt_net_->getInputOutputDims(); + input_dims_ = std::get<0>(input_output_dims); + output_dims_ = std::get<1>(input_output_dims); + + wust_vl::common::concurrency::AdaptiveResourcePool::Params pool_params; + pool_params.resource_initializer = [&]() { + std::vector> infers; + for (int i = 0; i < max_infer_running; ++i) { + auto infer = std::make_unique(); + auto ctx = trt_net_->getAContext(); + infer->context = std::unique_ptr(ctx); + if (use_cuda_pre_) { + infer->cuda_infer = std::make_unique(); + infer->cuda_infer->init( + MAX_SRC_IMG_W, + MAX_SRC_IMG_H, + armor_infer_->inputW(), + armor_infer_->inputH() + ); + } + if (!infer->context) { + WUST_ERROR("TRT") << "create infer failed, missing context" + << " index:" << i; + continue; + } + if (use_cuda_pre_ && !infer->cuda_infer) { + WUST_ERROR("TRT") << "create infer failed, missing cuda_infer" + << " index:" << i; + continue; + } + + size_t free_mem, total_mem; + cudaMemGetInfo(&free_mem, &total_mem); + WUST_DEBUG("TRT") << "Free GPU memory:" << free_mem / 1024.0 / 1024.0 << "MB" + << "Total GPU memory:" << total_mem / 1024.0 / 1024.0 << "MB"; + double free_mem_ratio = + static_cast(free_mem) / static_cast(total_mem); + if (free_mem_ratio < min_free_mem_ratio && i > 0) { + WUST_WARN("TRT") << "GPU memory is not enough!" + << "Free GPU memory:" << free_mem_ratio * 100 << "%"; + WUST_INFO("TRT") << "Cut remaining infer"; + break; + } + infers.emplace_back(std::move(infer)); + WUST_INFO("TRT") << "create execution context success" + << "index:" << i; + } + return infers; + }; + auto release_func = [&](std::unique_ptr& resource) { + if (resource) { + if (resource->cuda_infer) { + resource->cuda_infer.reset(); + } + } + }; + auto restore_func = [&](size_t idx) -> std::unique_ptr { + auto infer = std::make_unique(); + auto ctx = trt_net_->getAContext(); + infer->context = std::unique_ptr(ctx); + if (use_cuda_pre_) { + infer->cuda_infer = std::make_unique(); + infer->cuda_infer->init( + MAX_SRC_IMG_W, + MAX_SRC_IMG_H, + armor_infer_->inputW(), + armor_infer_->inputH() + ); + } + if (!infer->context) { + WUST_ERROR("TRT") << "create infer failed, missing context"; + return nullptr; + } + if ((use_cuda_pre_) && !infer->cuda_infer) { + WUST_ERROR("TRT") << "create infer failed, missing cuda_infer"; + return nullptr; + } + return infer; + }; + pool_params.restore_func = restore_func; + + pool_params.release_func = release_func; + + pool_params.can_restore = [&](size_t active_count) { return false; }; + + pool_params.should_release = [&](size_t active_count) { return false; }; + + pool_params.logger = [](const std::string& msg) { + WUST_INFO("ArmorDetectorTrt:infer pool") << msg; + }; + infer_pool_ = + std::make_unique>( + pool_params + ); + } + + ~Impl() { + if (infer_pool_) { + infer_pool_.reset(); + } + trt_net_.reset(); + armor_detect_common_.reset(); + } + + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + struct Tag {}; + void processCallback( + const CommonFrame& frame, + Infer* infer, + const std::optional& target_number + ) const { + std::vector armors; + const auto t0 = wust_vl::common::utils::time_utils::now(); + Eigen::Matrix3f transform_matrix; + std::vector objs_result; + void* input_tensor_ptr; + const cv::Mat roi = frame.img_frame.src_img(frame.expanded); + + cv::Mat resized_img; + const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f; + const bool swap_rb = armor_infer_->inputRGB() + != (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB); + if (infer->cuda_infer && use_cuda_pre_) { + input_tensor_ptr = + infer->cuda_infer + ->preprocess_pitched( //支持不连续内存,无需拷贝后输入可直接传roi的指针 + roi.data, + roi.cols, + roi.rows, + roi.step, + scale, + swap_rb, + transform_matrix, + trt_net_->getStream() + ); + resized_img = infer->cuda_infer->tensorToMat( //nchw_float_to_hwc_uchar + static_cast(input_tensor_ptr), + armor_infer_->inputW(), + armor_infer_->inputH(), + scale, + trt_net_->getStream() + ); + } else { + resized_img = utils::letterbox( + roi, + transform_matrix, + armor_infer_->inputW(), + armor_infer_->inputH() + ); + const cv::Mat blob = cv::dnn::blobFromImage( + resized_img, + scale, + cv::Size(armor_infer_->inputW(), armor_infer_->inputH()), + cv::Scalar(0, 0, 0), + swap_rb + ); + trt_net_->input2Device(blob.ptr()); + input_tensor_ptr = trt_net_->getInputTensorPtr(); + } + const auto t1 = wust_vl::common::utils::time_utils::now(); + if (infer->context && input_tensor_ptr) { + trt_net_->infer(input_tensor_ptr, infer->context.get()); + } + const auto t2 = wust_vl::common::utils::time_utils::now(); + const cv::Mat + output_mat(output_dims_.d[1], output_dims_.d[2], CV_32F, trt_net_->output2Host()); + cudaStreamSynchronize(trt_net_->getStream()); + objs_result = armor_infer_->postProcess(output_mat); + const auto t3 = wust_vl::common::utils::time_utils::now(); + if (log_time_) { + WUST_INFO("TRT") << std::fixed << std::setprecision(3) << "pre " + << wust_vl::common::utils::time_utils::durationMs(t0, t1) << " " + << "infer " + << wust_vl::common::utils::time_utils::durationMs(t1, t2) << " " + << "post " + << wust_vl::common::utils::time_utils::durationMs(t2, t3) << " " + << "total " + << wust_vl::common::utils::time_utils::durationMs(t0, t3); + } + infer_pool_->release(infer); + + if (armor_detect_common_) { + armors = armor_detect_common_->detectNet( + resized_img, + objs_result, + transform_matrix, + frame.detect_color, + target_number + ); + // Call callback function + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } else { + for (auto obj: objs_result) { + auto detect_color = frame.detect_color; + if (detect_color == 0 && obj.color == ArmorColor::BLUE) { + continue; + } else if (detect_color == 1 && obj.color == ArmorColor::RED) { + continue; + } + obj.transform(transform_matrix); + armors.push_back(obj); + } + if (this->infer_callback_) { + this->infer_callback_(armors, frame); + return; + } + } + + return; + } + + void pushInput(CommonFrame& frame, const std::optional& target_number) { + if (infer_pool_) { + auto infer_ptr = infer_pool_->acquire(); + if (infer_ptr != nullptr) { + frame.id = current_id_++; + this->processCallback(frame, infer_ptr, target_number); + } + } + } + + private: + bool use_cuda_pre_; + bool log_time_; + nvinfer1::Dims input_dims_; + nvinfer1::Dims output_dims_; + DetectorCallback infer_callback_; + std::unique_ptr armor_detect_common_; + std::unique_ptr> infer_pool_; + std::unique_ptr armor_infer_; + int current_id_ = 0; + std::unique_ptr trt_net_; + }; + ArmorDetectorTrt::ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common) { + _impl = std::make_unique(config, use_armor_detect_common); + } + ArmorDetectorTrt::~ArmorDetectorTrt() { + _impl.reset(); + } + void ArmorDetectorTrt::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } + void ArmorDetectorTrt::pushInput( + CommonFrame& frame, + const std::optional& target_number + ) { + _impl->pushInput(frame, target_number); + } +} // namespace auto_aim +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp b/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp new file mode 100644 index 0000000..5aed13a --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp @@ -0,0 +1,40 @@ +// Copyright 2025 Zikang Xie +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorDetectorTrt: public ArmorDetectorBase { + public: + using Ptr = std::unique_ptr; + + explicit ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common); + static Ptr create(const YAML::Node& config, bool use_armor_detect_common) { + return std::make_unique(config, use_armor_detect_common); + } + ~ArmorDetectorTrt(); + void + pushInput(CommonFrame& frame, const std::optional& target_number) override; + + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.cpp b/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.cpp new file mode 100644 index 0000000..f61e8c9 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.cpp @@ -0,0 +1,442 @@ +#include "armor_omni.hpp" +#include "3rdparty/angles.h" +#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp" +#include "tasks/auto_aim/armor_tracker/target.hpp" +#include "tasks/auto_aim/auto_aim_fsm.hpp" +#include "tasks/auto_aim/type.hpp" +#include "wust_vl/common/concurrency/ThreadPool.h" +#include "wust_vl/common/utils/timer.hpp" +#include "wust_vl/video/camera.hpp" +// clang-format off +#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp" +// clang-format on +#include "tasks/auto_aim/armor_tracker/trackerv3.hpp" +#include "tasks/auto_aim/armor_where/armor_where.hpp" +namespace wust_vision::auto_aim { + +struct ArmorOmni::Impl { + struct One { + using Ptr = std::shared_ptr; + + One(int id) { + self_id = id; + total_score = 0; + } + + static Ptr create(int id) { + return std::make_shared(id); + } + + void load( + const YAML::Node& config, + wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter + ) noexcept { + auto yaw_in_big_yaw_deg = config["yaw_in_big_yaw_deg"].as(); + yaw_in_big_yaw = yaw_in_big_yaw_deg / 180.0 * M_PI; + camera = std::make_shared(); + camera->init(config); + std::string camera_info_path = + utils::expandEnv(config["camera_info_path"].as()); + YAML::Node config_camera_info = YAML::LoadFile(camera_info_path); + std::vector camera_k = + config_camera_info["camera_matrix"]["data"].as>(); + std::vector camera_d = + config_camera_info["distortion_coefficients"]["data"].as>(); + + assert(camera_k.size() == 9); + assert(camera_d.size() == 5); + + cv::Mat K(3, 3, CV_64F); + std::memcpy(K.data, camera_k.data(), 9 * sizeof(double)); + + cv::Mat D(1, 5, CV_64F); + std::memcpy(D.data, camera_d.data(), 5 * sizeof(double)); + + camera_info = std::make_pair(K.clone(), D.clone()); + auto gobal_config = auto_aim_config_parameter->getConfig(); + armor_where = ArmorWhere::create(gobal_config["armor_where"], camera_info); + tracker = Tracker::create(auto_aim_config_parameter); + } + + void start() noexcept { + if (camera) + camera->start(); + } + + int self_id; + double total_score; + std::shared_ptr camera; + ArmorWhere::Ptr armor_where; + Tracker::Ptr tracker; + std::pair camera_info; + double yaw_in_big_yaw; + Target target; + }; + + struct Obj { + ArmorObject armor; + double score = 0; + One::Ptr one; + std::chrono::steady_clock::time_point timestamp; + + Obj(const ArmorObject& a, + double s, + const One::Ptr& o, + std::chrono::steady_clock::time_point ts): + armor(a), + score(s), + one(o), + timestamp(ts) { + if (one) + one->total_score += score; + } + + ~Obj() { + if (one) + one->total_score -= score; + } + }; + + static constexpr const char* _ML_CONFIG = "config/omni/detect_ml.yaml"; + static constexpr const char* _OPENCV_CONFIG = "config/omni/detect_opencv.yaml"; + + ~Impl() { + run_flag_ = false; + } + + Impl(bool detect_color_init, const Ctx& ctx) { + ctx_ = ctx; + detect_color_ = detect_color_init; + + config_ = YAML::LoadFile(OMNI_CONFIG); + auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create(); + auto_aim_config_parameter_->loadFromFile(OMNI_CONFIG); + auto cameras = config_["cameras"].as>(); + + for (size_t i = 0; i < cameras.size(); ++i) { + auto real_path = utils::expandEnv(cameras[i]); + One::Ptr one = One::create(i); + one->load(YAML::LoadFile(real_path), auto_aim_config_parameter_); + ones_.emplace_back(one); + } + auto_aim_config_parameter_->reloadFromOldPath(); + fps_ = config_["fps"].as(30); + active_time_ = config_["active_time"].as(0.5); + max_infer_running_ = config_["max_infer_running"].as(0); + min_score_ = config_["min_score"].as(); + const std::string armor_detect_backend = + config_["armor_detect_backend"].as(""); + + armor_detector_ = DetectorFactory::createArmorDetector( + armor_detect_backend, + false, + _OPENCV_CONFIG, + _ML_CONFIG + ); + + armor_detector_->setCallback(std::bind( + &ArmorOmni::Impl::ArmorDetectCallback, + this, + std::placeholders::_1, + std::placeholders::_2 + )); + + thread_pool_ = + std::make_unique(max_infer_running_); + + timer_ = std::make_unique("omni"); + latency_averager_ = std::make_unique>(100); + } + + void start() noexcept { + run_flag_ = true; + + for (auto& one: ones_) { + one->start(); + } + + if (timer_) { + const auto timercallback = + std::bind(&ArmorOmni::Impl::timerCallback, this, std::placeholders::_1); + + const double rate_hz = fps_; + + timer_->start(rate_hz, timercallback); + } + } + + int getOneId() const { + static int one_id = 0; + + int id = one_id; + + one_id = (one_id + 1) % ones_.size(); + + return id; + } + + void timerCallback(double dt_ms) noexcept { + if (!run_flag_ || main_tracking_) + return; + + int one_id = getOneId(); + + auto& one = ones_[one_id]; + + auto frame = one->camera->readImage(); + + if (frame.src_img.empty()) + return; + + CommonFrame common_frame; + + common_frame.img_frame = frame; + common_frame.id = one_id; + common_frame.detect_color = detect_color_; + common_frame.expanded = cv::Rect(0, 0, frame.src_img.cols, frame.src_img.rows); + common_frame.offset = cv::Point2f(0, 0); + common_frame.any_ctx = one; + + detect(common_frame); + } + + void detect(CommonFrame& common_frame) { + if (infer_running_count_ >= max_infer_running_ || !thread_pool_ || !run_flag_) { + return; + } + + infer_running_count_++; + + if (common_frame.img_frame.src_img.empty()) { + infer_running_count_--; + return; + } + + if (armor_detector_) { + armor_detector_->pushInput(common_frame, std::nullopt); + } + + infer_running_count_--; + } + + void setDetectColor(bool flag) noexcept { + detect_color_ = flag; + } + + void updateMainTracking(bool flag) noexcept { + main_tracking_ = flag; + } + + int getBestTarget() noexcept { + update(); + return best_target_; + } + + void + ArmorDetectCallback(const std::vector& objs, const CommonFrame& frame) noexcept { + auto one = std::any_cast(frame.any_ctx); + std::vector sorted_objs; + + for (const auto& obj: objs) { + if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) { + continue; + } + sorted_objs.push_back(obj); + std::lock_guard lock(active_results_mutex_); + active_results_.emplace_back(obj, obj.confidence, one, frame.img_frame.timestamp); + } + update(); + Armors armors; + armors.timestamp = frame.img_frame.timestamp; + Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity(); + auto& car_b = ctx_.car_motion_buffer; + auto& big_yaw_b = ctx_.big_yaw_motion_buffer; + + if (car_b && big_yaw_b) { + const auto t_query = armors.timestamp; + + auto apply_motion = [&](const auto& att, const auto& att2) { + R_gimbal2odom = + Eigen::AngleAxisd( + angles::normalize_angle(att2.data.big_yaw + one->yaw_in_big_yaw), + Eigen::Vector3d::UnitZ() + ) + * Eigen::AngleAxisd(0.0, Eigen::Vector3d::UnitY()) + * Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX()); + }; + + auto car_past_att = car_b->get_interpolated(t_query); + auto big_yaw_past_att = big_yaw_b->get_interpolated(t_query); + + if (car_past_att && big_yaw_past_att) { + apply_motion(*car_past_att, *big_yaw_past_att); + } else { + auto last_att = car_b->get_last(); + auto last_big_yaw = big_yaw_b->get_last(); + + if (last_att && last_big_yaw) { + apply_motion(*last_att, *last_big_yaw); + } + } + } + Eigen::Matrix3d R_camera2gimbal; + R_camera2gimbal << 0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0; + Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform( + R_gimbal2odom, + R_camera2gimbal, + Eigen::Vector3d::Zero() + ); + armors.armors = one->armor_where->where(sorted_objs, T_camera_to_odom); + for (auto& armor: armors.armors) { + armor.timestamp = armors.timestamp; + } + auto& target = one->target; + target = one->tracker->track(armors); + + const auto now = std::chrono::steady_clock::now(); + + const auto latency_ms = + wust_vl::common::utils::time_utils::durationMs(frame.img_frame.timestamp, now); + latency_averager_->add(latency_ms); + latency_ms_ = latency_averager_->average(); + detect_count_++; + + printStats(); + } + + void update() noexcept { + std::lock_guard lock(active_results_mutex_); + + while (!active_results_.empty()) { + auto& obj = active_results_.front(); + + if (std::abs(wust_vl::common::utils::time_utils::durationSec( + obj.timestamp, + wust_vl::common::utils::time_utils::now() + )) + > active_time_) + { + active_results_.pop_front(); + } else { + break; + } + } + + if (active_results_.empty()) { + best_target_ = -1; + return; + } + + double max_score = min_score_; + best_target_ = -1; + for (size_t i = 0; i < ones_.size(); ++i) { + if (ones_[i]->total_score > max_score) { + max_score = ones_[i]->total_score; + best_target_ = ones_[i]->self_id; + } + } + } + GimbalCmd solve(double bullet_speed) { + GimbalCmd gimbal_cmd; + std::optional target; + int best_target = getBestTarget(); + if (best_target < 0) { + target = std::nullopt; + } else { + target = ones_[best_target]->target; + } + auto& very_aimer = ctx_.very_aimer; + if (!very_aimer) { + return gimbal_cmd; + } + if (target.has_value() && target->checkTargetAppear()) { + try { + gimbal_cmd = very_aimer->veryAim( + target.value(), + bullet_speed, + AutoAimFsm::AIM_WHOLE_CAR_CENTER + ); + gimbal_cmd.enable_pitch_diff = 0.0; + gimbal_cmd.enable_yaw_diff = 0.0; + gimbal_cmd.fire_advice = false; + } catch (...) { + WUST_ERROR("omni") << "VeryAim error"; + } + } else { + gimbal_cmd.appear = false; + } + return gimbal_cmd; + } + void printStats() { + utils::XSecOnce( + [&] { + WUST_INFO("armor_omni") << "det: " << detect_count_ << " best: " << best_target_ + << " lat: " << latency_ms_; + + detect_count_ = 0; + }, + 1.0 + ); + } + + int fps_; + int max_infer_running_ = 0; + std::atomic infer_running_count_ { 0 }; + + bool detect_color_; + bool main_tracking_ = false; + bool run_flag_ = false; + + double active_time_ = 0; + + std::deque active_results_; + + mutable std::mutex active_results_mutex_; + + std::vector ones_; + + YAML::Node config_; + + std::unique_ptr thread_pool_; + + std::unique_ptr timer_; + + ArmorDetectorBase::Ptr armor_detector_; + wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_; + + std::unique_ptr> latency_averager_; + int best_target_ = -1; + + int detect_count_ = 0; + double latency_ms_; + double min_score_ = 10.0; + Ctx ctx_; +}; + +ArmorOmni::ArmorOmni(bool detect_color_init, const Ctx& ctx): + _impl(std::make_unique(detect_color_init, ctx)) {} + +ArmorOmni::~ArmorOmni() { + _impl.reset(); +} + +void ArmorOmni::start() { + _impl->start(); +} + +void ArmorOmni::setDetectColor(bool flag) { + _impl->setDetectColor(flag); +} + +void ArmorOmni::updateMainTracking(bool flag) { + _impl->updateMainTracking(flag); +} + +int ArmorOmni::getBestTarget() { + return _impl->getBestTarget(); +} +GimbalCmd ArmorOmni::solve(double bullet_speed) { + return _impl->solve(bullet_speed); +} + +} // namespace wust_vision::auto_aim \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.hpp b/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.hpp new file mode 100644 index 0000000..946806a --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.hpp @@ -0,0 +1,36 @@ +#pragma once +#include "tasks/auto_aim/armor_control/very_aimer.hpp" +#include "tasks/type_common.hpp" +#include +#include +#include +namespace wust_vision { +namespace auto_aim { + class ArmorOmni { + public: + struct Ctx { + std::shared_ptr> + car_motion_buffer; + std::shared_ptr> + big_yaw_motion_buffer; + VeryAimer::Ptr very_aimer; + }; + static constexpr const char* OMNI_CONFIG = "config/omni/omni.yaml"; + using Ptr = std::unique_ptr; + ArmorOmni(bool detect_color_init, const Ctx& ctx); + static Ptr create(bool detect_color_init, const Ctx& ctx) { + return std::make_unique(detect_color_init, ctx); + } + ~ArmorOmni(); + void start(); + void setDetectColor(bool flag); + void updateMainTracking(bool flag); + int getBestTarget(); + GimbalCmd solve(double bullet_speed); + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_model_crazy.hpp b/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_model_crazy.hpp new file mode 100644 index 0000000..4007bd6 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_model_crazy.hpp @@ -0,0 +1,286 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +// ceres +#include +#include +#include +// project +#include "KalmanHyLib/kalman_hybird_lib.hpp" + +namespace crazy { +enum class MotionModel { + CONSTANT_VELOCITY = 0, // Constant velocity + CONSTANT_ROTATION = 1, // Constant rotation velocity + CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity +}; + +constexpr int X_N = 11, Z_N = 8; +using VecZ = Eigen::Matrix; +using VecX = Eigen::Matrix; +enum class Mean : uint8_t { + PLBX = 0, + PLBY = 1, + PLTX = 2, + PLTY = 3, + PRTX = 4, + PRTY = 5, + PRBX = 6, + PRBY = 7, + // IMUY = 4, + // IMUP = 5, + // IMUR = 6, + Z_N = 8 +}; +enum class State : uint8_t { + CX = 0, + VCX = 1, + CY = 2, + VCY = 3, + CZ = 4, + VCZ = 5, + YAW = 6, + VYAW = 7, + R = 8, + L = 9, + H = 10, + outpost01DZ = 9, + outpost02DZ = 10, + X_N = 11 +}; +struct Predict { + Predict() = default; + explicit Predict( + double dt, + MotionModel model = MotionModel::CONSTANT_VEL_ROT, + double vrx = 0.0, + double vry = 0.0, + double vrz = 0.0 + ): + dt(dt), + model(model), + vrx(vrx), + vry(vry), + vrz(vrz) {} + + template + void operator()(const T x0[X_N], T x1[X_N]) const { + for (int i = 0; i < X_N; i++) { + x1[i] = x0[i]; + } + + // v_xyz + if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) { + // linear velocity + x1[(int)State::CX] += x0[(int)State::VCX] * T(dt); + x1[(int)State::CY] += x0[(int)State::VCY] * T(dt); + x1[(int)State::CZ] += x0[(int)State::VCZ] * T(dt); + } else { + // no velocity + x1[(int)State::VCX] *= T(0.); + x1[(int)State::VCY] *= T(0.); + x1[(int)State::VCZ] *= T(0.); + } + + x1[(int)State::CX] -= T(vrx) * T(dt); + x1[(int)State::CY] -= T(vry) * T(dt); + x1[(int)State::CZ] -= T(vrz) * T(dt); + + x1[(int)State::VCZ] *= T(0.); + // v_yaw + if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) { + // angular velocity + x1[(int)State::YAW] += x0[(int)State::VYAW] * T(dt); + } else { + // no rotation + x1[(int)State::VYAW] *= T(0.); + } + } + + double dt; + MotionModel model; + double vrx, vry, vrz; +}; +constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135 +constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55 +constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0; +constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55 + +constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180; + +template +T normalize_angle_t(T angle) { + T two_pi = T(2.0 * M_PI); + return angle - two_pi * floor((angle + T(M_PI)) / two_pi); +} + +template +Eigen::Quaternion +eulerToQuat(const Eigen::Vector& euler, int axis0, int axis1, int axis2, bool extrinsic) { + T rz = euler[0]; + T ry = euler[1]; + T rx = euler[2]; + + Eigen::Quaternion qx(Eigen::AngleAxis(rx, Eigen::Vector3::UnitX())); + Eigen::Quaternion qy(Eigen::AngleAxis(ry, Eigen::Vector3::UnitY())); + Eigen::Quaternion qz(Eigen::AngleAxis(rz, Eigen::Vector3::UnitZ())); + + if (!extrinsic) + std::swap(axis0, axis2); + + Eigen::Quaternion q; + + if (axis0 == 0 && axis1 == 1 && axis2 == 2) + q = qx * qy * qz; + else if (axis0 == 0 && axis1 == 2 && axis2 == 1) + q = qx * qz * qy; + else if (axis0 == 1 && axis1 == 0 && axis2 == 2) + q = qy * qx * qz; + else if (axis0 == 1 && axis1 == 2 && axis2 == 0) + q = qy * qz * qx; + else if (axis0 == 2 && axis1 == 0 && axis2 == 1) + q = qz * qx * qy; + else if (axis0 == 2 && axis1 == 1 && axis2 == 0) + q = qz * qy * qx; + else + throw std::invalid_argument("Unsupported axis order"); + + return q; +} + +template +std::vector buildObjectPoints(const U& w, const U& h) noexcept { + using T = U; + T half2 = T(2.0); + return { PointType(T(0), w / half2, -h / half2), + PointType(T(0), w / half2, h / half2), + PointType(T(0), -w / half2, h / half2), + PointType(T(0), -w / half2, -h / half2) }; +} + +struct Measure { + struct MeasureCtx { + int armor_num = 4; + int id = 0; + Eigen::Matrix4d T_odom_to_camera_d; + cv::Mat camera_intrinsic; + cv::Mat camera_distortion; + bool is_big; + } ctx; + + Measure() = default; + explicit Measure(const MeasureCtx& c): ctx(c) {} + + template + void operator()(const T x[X_N], T z[Z_N]) const { + T id_t = T(ctx.id); + T num_t = T(ctx.armor_num); + T two = T(2.0); + T angle = normalize_angle_t(x[(int)State::YAW] + id_t * two * T(M_PI) / num_t); + + bool outpost = (ctx.armor_num == 3); + bool use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3); + + T r = use_l_h ? x[(int)State::R] + x[(int)State::L] : x[(int)State::R]; + + T armor_x = x[(int)State::CX] - ceres::cos(angle) * r; + T armor_y = x[(int)State::CY] - ceres::sin(angle) * r; + T armor_z = outpost ? getoutpost_armor_z(x) + : use_l_h ? x[(int)State::CZ] + x[(int)State::H] + : x[(int)State::CZ]; + + Eigen::Vector3 euler_odom; + euler_odom[0] = angle; //yaw + euler_odom[1] = outpost ? T(-FIFTTEN_DEGREE_RAD) : T(FIFTTEN_DEGREE_RAD); //pitch + euler_odom[2] = T(M_PI / 2.0); //roll + + Eigen::Quaternion q_odom = eulerToQuat(euler_odom, 2, 1, 0, true); + + Eigen::Matrix4 T_odom_to_camera = ctx.T_odom_to_camera_d.cast(); + + Eigen::Vector4 pos_odom4(armor_x, armor_y, armor_z, T(1.0)); + Eigen::Vector4 pos_camera4 = T_odom_to_camera * pos_odom4; + Eigen::Vector3 pos_camera = pos_camera4.template head<3>(); + + Eigen::Matrix3 R_odom_to_camera = T_odom_to_camera.block(0, 0, 3, 3).template cast(); + Eigen::Matrix3 R_ori_odom = q_odom.normalized().toRotationMatrix(); + Eigen::Matrix3 R_camera = R_odom_to_camera * R_ori_odom; + Eigen::Quaternion q_camera(R_camera); + q_camera.normalize(); + + T w3 = ctx.is_big ? T(LARGE_ARMOR_WIDTH) : T(SMALL_ARMOR_WIDTH); + T h3 = ctx.is_big ? T(LARGE_ARMOR_HEIGHT) : T(SMALL_ARMOR_HEIGHT); + + auto objPts = buildObjectPoints>(w3, h3); + + Eigen::Matrix3 R = q_camera.toRotationMatrix(); + Eigen::Matrix t = pos_camera; + + std::vector> Pc; + Pc.reserve(objPts.size()); + for (const auto& p: objPts) { + Eigen::Matrix v = p; + Pc.push_back(R * v + t); + } + + const cv::Mat& K = ctx.camera_intrinsic; + T fx = T(K.at(0, 0)); + T fy = T(K.at(1, 1)); + T cx = T(K.at(0, 2)); + T cy = T(K.at(1, 2)); + + std::array u, v; + for (int i = 0; i < 4; i++) { + T Xc = Pc[i][0]; + T Yc = Pc[i][1]; + T Zc = Pc[i][2]; + + u[i] = fx * (Xc / Zc) + cx; + v[i] = fy * (Yc / Zc) + cy; + } + + z[0] = u[0]; + z[1] = v[0]; + z[2] = u[1]; + z[3] = v[1]; + z[4] = u[2]; + z[5] = v[2]; + z[6] = u[3]; + z[7] = v[3]; + } + + template + T getoutpost_armor_z(const T x[X_N]) const { + if (ctx.id == 0) + return x[(int)State::CZ]; + if (ctx.id == 1) + return x[(int)State::CZ] + x[(int)State::outpost01DZ]; + if (ctx.id == 2) + return x[(int)State::CZ] + x[(int)State::outpost02DZ]; + return x[(int)State::CZ]; + } + + using VecX = Eigen::Matrix; + using VecZ = Eigen::Matrix; + + void h(const VecX& x, VecZ& z) const { + operator()(x.data(), z.data()); + } +}; + +using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter; +using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF; + +} // namespace crazy \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp b/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp new file mode 100644 index 0000000..98e95a1 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp @@ -0,0 +1,226 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +// ceres +#include +#include +// project +#include "KalmanHyLib/kalman_hybird_lib.hpp" +namespace ypdv2armor_motion_model { + +enum class MotionModel { + CONSTANT_VELOCITY = 0, // Constant velocity + CONSTANT_ROTATION = 1, // Constant rotation velocity + CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity +}; + +// X_N: state dimension, Z_N: measurement dimension +constexpr int X_N = 11, Z_N = 4; +using VecZ = Eigen::Matrix; +using VecX = Eigen::Matrix; +enum class MeasureID : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, Z_N = 4 }; +enum class StateID : uint8_t { + CX = 0, + VCX = 1, + CY = 2, + VCY = 3, + CZ = 4, + VCZ = 5, + YAW = 6, + VYAW = 7, + R = 8, + L = 9, + H = 10, + outpost01DZ = 9, + outpost02DZ = 10, + X_N = 11 +}; +struct State { + VecX x; + [[nodiscard]] inline double cx() const noexcept { + return x((int)StateID::CX); + } + [[nodiscard]] inline double cy() const noexcept { + return x((int)StateID::CY); + } + [[nodiscard]] inline double cz() const noexcept { + return x((int)StateID::CZ); + } + [[nodiscard]] inline Eigen::Vector3d pos() const noexcept { + return Eigen::Vector3d(cx(), cy(), cz()); + } + [[nodiscard]] inline double vcx() const noexcept { + return x((int)StateID::VCX); + } + [[nodiscard]] inline double vcy() const noexcept { + return x((int)StateID::VCY); + } + [[nodiscard]] inline double vcz() const noexcept { + return x((int)StateID::VCZ); + } + [[nodiscard]] inline Eigen::Vector3d vel() const noexcept { + return Eigen::Vector3d(vcx(), vcy(), vcz()); + } + [[nodiscard]] inline double vyaw() const noexcept { + return x((int)StateID::VYAW); + } + [[nodiscard]] inline double yaw() const noexcept { + return x((int)StateID::YAW); + } + [[nodiscard]] inline double r() const noexcept { + return x((int)StateID::R); + } + [[nodiscard]] inline double l() const noexcept { + return x((int)StateID::L); + } + [[nodiscard]] inline double h() const noexcept { + return x((int)StateID::H); + } + [[nodiscard]] inline double outpost01DZ() const noexcept { + return x((int)StateID::outpost01DZ); + } + [[nodiscard]] inline double outpost02DZ() const noexcept { + return x((int)StateID::outpost02DZ); + } +}; +struct Predict { + Predict() = default; + explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT): + dt(dt), + model(model) {} + + template + void operator()(const T x0[X_N], T x1[X_N]) const { + for (int i = 0; i < X_N; i++) { + x1[i] = x0[i]; + } + + // v_xyz + if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) { + // linear velocity + x1[(int)StateID::CX] += x0[(int)StateID::VCX] * T(dt); + x1[(int)StateID::CY] += x0[(int)StateID::VCY] * T(dt); + x1[(int)StateID::CZ] += x0[(int)StateID::VCZ] * T(dt); + } else { + // no velocity + x1[(int)StateID::VCX] *= T(0.); + x1[(int)StateID::VCY] *= T(0.); + x1[(int)StateID::VCZ] *= T(0.); + } + + // v_yaw + if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) { + // angular velocity + x1[(int)StateID::YAW] += x0[(int)StateID::VYAW] * T(dt); + } else { + // no rotation + x1[(int)StateID::VYAW] *= T(0.); + } + clampState(x1); + } + template + void clampState(T x1[X_N]) const { + auto& r = x1[(int)StateID::R]; + auto& l = x1[(int)StateID::L]; + r = std::clamp(r, T(0.1), T(0.5)); + if (r < T(0.1) || r > T(0.5)) { + r = T(0.25); + l = T(0); + } + T sum = r + l; + if (sum < T(0.1) || sum > T(0.5)) { + r = T(0.25); + l = T(0); + } + auto& h = x1[(int)StateID::H]; + h = std::clamp(h, T(-0.5), T(0.5)); + } + void f(const VecX& x0, VecX& x1) const { + assert(x0.size() == X_N); + assert(x1.size() == X_N); + operator()(x0.data(), x1.data()); + } + + double dt; + MotionModel model; +}; +template +T normalize_angle_t(T angle) { + T two_pi = T(2.0 * M_PI); + return angle - two_pi * floor((angle + T(M_PI)) / two_pi); +} + +struct Measure { + struct MeasureCtx { + MeasureCtx() = default; + MeasureCtx(int id, int armor_num): armor_num(armor_num), id(id) {} + int armor_num = 4; + int id = 0; + } ctx; + Measure() = default; + explicit Measure(MeasureCtx c): ctx(c) {} + template + void operator()(const T x[X_N], T z[Z_N]) const { + // Compute armor position + auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x); + T xy_dist = ceres::sqrt(armor_x * armor_x + armor_y * armor_y); + T dist = ceres::sqrt(xy_dist * xy_dist + armor_z * armor_z); + // Observation model + z[(int)MeasureID::YPD_Y] = ceres::atan2(armor_y, armor_x); // yaw + z[(int)MeasureID::YPD_P] = ceres::atan2(armor_z, xy_dist); // pitch + z[(int)MeasureID::YPD_D] = dist; // distance + z[(int)MeasureID::ORI_YAW] = angle; // orientation_yaw + } + template + T get_angle(const T x[X_N]) const { + return normalize_angle_t(x[(int)StateID::YAW] + ctx.id * 2 * M_PI / ctx.armor_num); + } + template + std::tuple h_armor_xyza(const T x[X_N]) const { + T angle = get_angle(x); + auto outpost = ctx.armor_num == 3; + auto use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3); + T r = (use_l_h) ? x[(int)StateID::R] + x[(int)StateID::L] : x[(int)StateID::R]; + T armor_x = x[(int)StateID::CX] - ceres::cos(angle) * r; + T armor_y = x[(int)StateID::CY] - ceres::sin(angle) * r; + T armor_z = (outpost) ? getoutpost_armor_z(x) + : (use_l_h) ? x[(int)StateID::CZ] + x[(int)StateID::H] + : x[(int)StateID::CZ]; + return { armor_x, armor_y, armor_z, angle }; + } + Eigen::Vector4d h_armor_xyza(const VecX& x) const { + assert(x.size() == X_N); + auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x.data()); + + return { armor_x, armor_y, armor_z, angle }; + } + template + T getoutpost_armor_z(const T x[X_N]) const { + return (ctx.id == 0) ? x[(int)StateID::CZ] + : (ctx.id == 1) ? x[(int)StateID::CZ] + x[(int)StateID::outpost01DZ] + : (ctx.id == 2) ? x[(int)StateID::CZ] + x[(int)StateID::outpost02DZ] + : x[(int)StateID::CZ]; + } + void h(const VecX& x, VecZ& z) const { + assert(x.size() == X_N); + assert(z.size() == Z_N); + operator()(x.data(), z.data()); + } +}; + +using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter; +using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF; + +} // namespace ypdv2armor_motion_model diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/target.cpp b/wust_vision-main/tasks/auto_aim/armor_tracker/target.cpp new file mode 100644 index 0000000..71b64d4 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/target.cpp @@ -0,0 +1,379 @@ +#include "target.hpp" +namespace wust_vision { +namespace auto_aim { + Target::Target() { + target_state_.x = Eigen::VectorXd::Zero(MModel::X_N); + } + Target::Target(const Armor& a, TargetConfig::Ptr target_config) { + Eigen::DiagonalMatrix p0; + if (a.number == ArmorNumber::OUTPOST) { + p0.diagonal() << 1, 64, 1, 64, 1, 81, 0.4, 100, 1e-4, 0.1, 0.1; + armor_num_ = 3; + radius_pre_ = 0.2765; + } else if (a.number == ArmorNumber::BASE) { + p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1e-4, 0, 0; + armor_num_ = 3; + radius_pre_ = 0.3205; + } else { + p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1, 1, 1; + armor_num_ = 4; + radius_pre_ = 0.2; + } + target_config_ = target_config; + const auto yfv2 = MModel::Predict(0.005); + ctx_.armor_num = armor_num_; + ctx_.id = 0; + const auto yhv2 = MModel::Measure(ctx_); + const auto yu_qv2 = [this]() { + Eigen::Matrix q; + return q; + }; + + const auto yu_rv2 = [this](const Eigen::Matrix& z) { + Eigen::Matrix r; + return r; + }; + esekf_ypd_ = MModel::RobotStateESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0); + + esekf_ypd_.setResidualFunc([this]( + const Eigen::Matrix& z_pred, + const Eigen::Matrix& z + ) { + Eigen::Matrix r = z - z_pred; + r[0] = angles::shortest_angular_distance( + z_pred[(int)MModel::MeasureID::YPD_Y], + z[(int)MModel::MeasureID::YPD_Y] + ); // yaw + r[3] = angles::shortest_angular_distance( + z_pred[(int)MModel::MeasureID::ORI_YAW], + z[(int)MModel::MeasureID::ORI_YAW] + ); // ori_yaw + return r; + }); + esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get()); + esekf_ypd_.setInjectFunc([this]( + const Eigen::Matrix& delta, + Eigen::Matrix& nominal + ) { + for (int i = 0; i < MModel::X_N; i++) { + if (i == (int)MModel::StateID::YAW) + continue; + nominal[i] += delta[i]; + } + nominal[(int)MModel::StateID::YAW] = angles::normalize_angle( + nominal[(int)MModel::StateID::YAW] + delta[(int)MModel::StateID::YAW] + ); + }); + + const double xa = a.target_pos.x(); + const double ya = a.target_pos.y(); + const double za = a.target_pos.z(); + const double yaw = utils::orientationToYaw(a.target_ori); + + target_state_.x = Eigen::VectorXd::Zero(MModel::X_N); + const double r = radius_pre_; + const double xc = xa + r * cos(yaw); + const double yc = ya + r * sin(yaw); + const double zc = za; + target_state_.x << xc, 0, yc, 0, zc, 0, yaw, 0, r, 0, 0; + esekf_ypd_.setState(target_state_.x); + tracked_id_ = a.number; + type_ = a.type; + last_t_ = a.timestamp; + timestamp_ = a.timestamp; + is_inited = true; + } + Eigen::Matrix + Target::computeMeasurementCovariance(const Eigen::Matrix& z + ) const noexcept { + Eigen::Matrix r; + const double delta_angle = angles::normalize_angle(z[3] - z[0]); + const double abs_delta = std::abs(delta_angle); + + // sin插值函数,小值慢、大值快 + const auto sinInterp = [](double x, double x0, double x1, double y0, double y1) -> double { + double t = (x - x0) / (x1 - x0); + if (t < 0) + t = 0; + if (t > 1) + t = 1; + double s = std::sin(t * M_PI / 2.0); + return y0 + s * (y1 - y0); + }; + // clang-format off + r <yp_r_param.get(), 0, 0, 0, + 0, target_config_->yp_r_param.get() , 0, 0, + 0, 0, sinInterp(abs_delta, 0.0, M_PI/2.0, target_config_->dis_r_front_param.get(), target_config_->dis_r_side_param.get())+z[2]*z[2]*target_config_->dis2_r_ratio_param.get(), 0, + 0, 0, 0,log(std::abs(z[2]) + 1) *target_config_->yaw_r_log_ratio_param.get() + sinInterp(M_PI/2.0-abs_delta, 0.0, M_PI/2.0, target_config_->yaw_r_base_side_param.get(), target_config_->yaw_r_base_front_param.get()); + // clang-format on + return r; + } + Eigen::Matrix Target::computeProcessNoise(double dt + ) const noexcept { + Eigen::Matrix q; + Eigen::Vector3d q_xyz; + double q_yaw; + double q_l, q_h; + if (tracked_id_ == ArmorNumber::OUTPOST) { + q_xyz = target_config_->qxyz_output; // 前哨站加速度方差 + q_yaw = target_config_->qyaw_output_param.get(); // 前哨站角加速度方差 + q_l = target_config_->q_outpost_dz_param.get(); + q_h = target_config_->q_outpost_dz_param.get(); + } else { + q_xyz = target_config_->qxyz_common; // 加速度方差 + q_yaw = target_config_->qyaw_common_param.get(); // 角加速度方差 + q_l = target_config_->q_l_param.get(); + q_h = target_config_->q_h_param.get(); + } + const double t = dt; + const double q_x_x = pow(t, 4) / 4 * q_xyz.x(), q_x_vx = pow(t, 3) / 2 * q_xyz.x(), + q_vx_vx = pow(t, 2) * q_xyz.x(); + const double q_y_y = pow(t, 4) / 4 * q_xyz.y(), q_y_vy = pow(t, 3) / 2 * q_xyz.y(), + q_vy_vy = pow(t, 2) * q_xyz.y(); + const double q_z_z = pow(t, 4) / 4 * q_xyz.z(), q_z_vz = pow(t, 3) / 2 * q_xyz.z(), + q_vz_vz = pow(t, 2) * q_xyz.z(); + const double q_yaw_yaw = pow(t, 4) / 4 * q_yaw, q_yaw_vyaw = pow(t, 3) / 2 * q_yaw, + q_vyaw_vyaw = pow(t, 2) * q_yaw; + const double q_r = target_config_->q_r_param.get(); + + // clang-format off + // xc v_xc yc v_yc zc v_zc yaw v_yaw r l h + q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0, + q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0, + 0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0, + 0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0, + 0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, q_l, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_h; + // clang-format on + return q; + } + MModel::Predict Target::getPredictFunc(double dt) const noexcept { + MModel::Predict predict_func; + if (tracked_id_ == ArmorNumber::OUTPOST) { + predict_func = MModel::Predict { + dt, + MModel::MotionModel::CONSTANT_ROTATION, + }; + } else { + predict_func = MModel::Predict { + dt, + MModel::MotionModel::CONSTANT_VEL_ROT, + }; + } + return predict_func; + } + void Target::predict(std::chrono::steady_clock::time_point t) noexcept { + const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t); + + predict(dt); + + last_t_ = t; + } + void Target::predict(double dt) noexcept { + MModel::Predict predict_func = getPredictFunc(dt); + + esekf_ypd_.setPredictFunc(predict_func); + const auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); }; + + esekf_ypd_.setUpdateQ(yu_qv2); + + target_state_.x = esekf_ypd_.predict(); + if (target_state_.pos().norm() < 0.5) { + is_tracking = false; + } + } + void Target::predictSimple(std::chrono::steady_clock::time_point t) noexcept { + const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t); + + predictSimple(dt); + + last_t_ = t; + } + void Target::predictSimple(double dt) noexcept { + MModel::Predict predict_func = getPredictFunc(dt); + predict_func.f(target_state_.x, target_state_.x); + if (target_state_.pos().norm() < 0.5) { + is_tracking = false; + } + } + bool Target::update(const std::pair& a) noexcept { + const auto armor = a.second; + const auto id = a.first; + const auto yu_rv2 = [this](const Eigen::Matrix& z) { + return this->computeMeasurementCovariance(z); + }; + esekf_ypd_.setUpdateR(yu_rv2); + measurement_ = getMeasure(armor); + + if (id != 0) + jumped = true; + + ctx_.id = id; + esekf_ypd_.setMeasureFunc(MModel::Measure { ctx_ }); + + target_state_.x = esekf_ypd_.update(measurement_); + timestamp_ = armor.timestamp; + last_t_ = timestamp_; + return true; + } + cv::Rect Target::expanded( + Eigen::Matrix4d T_camera_to_odom, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const cv::Size& image_size + ) const noexcept { + const double dt = wust_vl::common::utils::time_utils::durationSec( + timestamp_, + wust_vl::common::utils::time_utils::now() + ); + if (!is_inited || dt > target_config_->lost_time_thres_param.get()) { + return cv::Rect(0, 0, 0, 0); + } + + const float car_box_half = + std::max(target_state_.r(), target_state_.r() + target_state_.l()) + 0.15; + + static std::vector CAR_BOX; + CAR_BOX = { { 0, car_box_half, -car_box_half }, + { 0, -car_box_half, -car_box_half }, + { 0, -car_box_half, car_box_half }, + { 0, car_box_half, car_box_half } }; + + const Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse(); + const Eigen::Vector4d + pos_odom(target_state_.cx(), target_state_.cy(), target_state_.cz(), 1.0); + const Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom; + + if (pos_cam.z() <= 0.2) { + return cv::Rect(0, 0, 0, 0); + } + + const cv::Mat tvec = (cv::Mat_(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z()); + + Eigen::Vector3d euler; + euler.z() = M_PI / 2.0; + euler.y() = 0; + euler.x() = std::atan2(pos_odom.y(), pos_odom.x()); + + const Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX); + const auto target_ori = utils::transformOrientation(ori, T_odom_to_camera); + const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix(); + + const cv::Mat rot_mat = + (cv::Mat_(3, 3) << tf_rot(0, 0), + tf_rot(0, 1), + tf_rot(0, 2), + tf_rot(1, 0), + tf_rot(1, 1), + tf_rot(1, 2), + tf_rot(2, 0), + tf_rot(2, 1), + tf_rot(2, 2)); + + cv::Mat rvec; + cv::Rodrigues(rot_mat, rvec); + + std::vector pts_2d; + cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d); + + const cv::Rect rect = cv::boundingRect(pts_2d); + + const cv::Rect img_rect(0, 0, image_size.width, image_size.height); + if ((rect & img_rect).area() <= 0) { + return cv::Rect(0, 0, 0, 0); + } + + const int base_side = std::max(rect.width, rect.height); + const int max_side = std::max(image_size.width, image_size.height); + + const double lost_dt = target_config_->lost_time_thres_param.get(); + const double dt_clamped = std::max(0.0, std::min(dt, lost_dt)); + + int side = static_cast(base_side + (max_side - base_side) * (dt_clamped / lost_dt)); + + if (dt >= lost_dt) { + side = max_side; + } + + const int cx = rect.x + rect.width / 2; + const int cy = rect.y + rect.height / 2; + cv::Rect square(cx - side / 2, cy - side / 2, side, side); + + square &= img_rect; + + return square; + } + + std::vector> Target::match(const std::vector& armors) noexcept { + std::vector> result; + const int n_obs = static_cast(armors.size()); + const int armors_num = armor_num_; + const double GATE = target_config_->match_gate_param.get(); + const double max_cost = 1e9; + std::vector> cost(n_obs, std::vector(armors_num, max_cost + 1)); + std::vector meas_list(n_obs); + for (int j = 0; j < n_obs; ++j) { + meas_list[j] = getMeasure(armors[j]); + } + for (int j = 0; j < n_obs; ++j) { + for (int id = 0; id < armors_num; ++id) { + MModel::Measure::MeasureCtx tmp_ctx(id, armors_num); + MModel::Measure measure(tmp_ctx); + MModel::VecZ z_pred; + measure.h(target_state_.x, z_pred); + + MModel::VecZ nu = meas_list[j] - z_pred; + nu[(int)MModel::MeasureID::YPD_Y] = + angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_Y]); + nu[(int)MModel::MeasureID::YPD_P] = + angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_P]); + nu[(int)MModel::MeasureID::ORI_YAW] = + angles::normalize_angle(nu[(int)MModel::MeasureID::ORI_YAW]); + auto R = computeMeasurementCovariance(z_pred); + double d2 = nu.transpose() * R.ldlt().solve(nu); + + // 门控 + if (std::isfinite(d2) && d2 < GATE) { + cost[j][id] = d2; + } + } + } + std::vector used_obs(n_obs, false); + std::vector used_id(armors_num, false); + + while (true) { + double best = max_cost; + int best_j = -1; + int best_id = -1; + + for (int j = 0; j < n_obs; ++j) { + if (used_obs[j]) + continue; + for (int id = 0; id < armors_num; ++id) { + if (used_id[id]) + continue; + if (cost[j][id] < best) { + best = cost[j][id]; + best_j = j; + best_id = id; + } + } + } + + if (best_j < 0 || best_id < 0) { + break; + } + + used_obs[best_j] = true; + used_id[best_id] = true; + result.push_back(std::make_pair(best_id, armors[best_j])); + } + return result; + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/target.hpp b/wust_vision-main/tasks/auto_aim/armor_tracker/target.hpp new file mode 100644 index 0000000..498d885 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/target.hpp @@ -0,0 +1,185 @@ +#pragma once +#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp" +#include "tasks/auto_aim/type.hpp" +#include "wust_vl/common/utils/parameter.hpp" +#include +namespace wust_vision { +namespace auto_aim { + namespace MModel = ypdv2armor_motion_model; + + struct TargetConfig: wust_vl::common::utils::ParamGroup { + static constexpr const char* kKey = "armor_tracker"; + const char* key() const override { + return kKey; + } + GEN_PARAM(int, esekf_iter_num); + GEN_PARAM(double, lost_time_thres); + GEN_PARAM(int, tracking_thres); + GEN_PARAM(double, max_yaw_diff_deg); + GEN_PARAM(double, max_dis_diff); + GEN_PARAM(double, match_gate); + GEN_PARAM(double, qyaw_common); + GEN_PARAM(double, qyaw_output); + GEN_PARAM(double, q_r); + GEN_PARAM(double, q_l); + GEN_PARAM(double, q_h); + GEN_PARAM(double, q_outpost_dz); + GEN_PARAM(double, yp_r); + GEN_PARAM(double, dis_r_front); + GEN_PARAM(double, dis_r_side); + GEN_PARAM(double, dis2_r_ratio); + GEN_PARAM(double, yaw_r_base_front); + GEN_PARAM(double, yaw_r_base_side); + GEN_PARAM(double, yaw_r_log_ratio); + GEN_PARAM(std::vector, qxyz_common); + GEN_PARAM(std::vector, qxyz_output); + Eigen::Vector3d qxyz_common = { 100, 100, 100 }; + Eigen::Vector3d qxyz_output = { 10, 10, 10 }; + using Ptr = std::shared_ptr; + TargetConfig() { + qxyz_output_param.onChange([this](auto o, auto n) { + qxyz_common = Eigen::Vector3d(n[0], n[1], n[2]); + }); + qxyz_output_param.onChange([this](auto o, auto n) { + qxyz_output = Eigen::Vector3d(n[0], n[1], n[2]); + }); + } + static Ptr create() { + return std::make_shared(); + } + void loadSelf(const YAML::Node& node) override { + esekf_iter_num_param.load(node); + lost_time_thres_param.load(node); + tracking_thres_param.load(node); + max_yaw_diff_deg_param.load(node); + max_dis_diff_param.load(node); + match_gate_param.load(node); + qyaw_common_param.load(node); + qyaw_output_param.load(node); + qxyz_common_param.load(node); + qxyz_output_param.load(node); + q_r_param.load(node); + q_l_param.load(node); + q_h_param.load(node); + q_outpost_dz_param.load(node); + yp_r_param.load(node); + dis_r_front_param.load(node); + dis_r_side_param.load(node); + yaw_r_base_front_param.load(node); + yaw_r_base_side_param.load(node); + yaw_r_log_ratio_param.load(node); + } + }; + class Target { + public: + Target(); + Target(const Armor& armor, TargetConfig::Ptr target_config); + MModel::Measure::MeasureCtx ctx_; + ArmorNumber tracked_id_; + std::string type_; + MModel::VecZ measurement_ = Eigen::Matrix::Zero(); + MModel::State target_state_ = MModel::State(); + double radius_pre_; + + int armor_num_ = 4; + bool jumped = false; + bool is_inited = false; + bool is_tracking = false; + std::chrono::steady_clock::time_point last_t_; + std::chrono::steady_clock::time_point timestamp_; + MModel::RobotStateESEKF esekf_ypd_; + TargetConfig::Ptr target_config_; + [[nodiscard]] cv::Rect expanded( + Eigen::Matrix4d T_camera_to_odom, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const cv::Size& image_size + ) const noexcept; + void predict(std::chrono::steady_clock::time_point t) noexcept; + void predict(double dt) noexcept; + void predictSimple(std::chrono::steady_clock::time_point t) noexcept; + void predictSimple(double dt) noexcept; + [[nodiscard]] MModel::Predict getPredictFunc(double dt) const noexcept; + bool update(const std::pair& armor) noexcept; + [[nodiscard]] Eigen::Matrix + computeMeasurementCovariance(const Eigen::Matrix& z) const noexcept; + [[nodiscard]] Eigen::Matrix computeProcessNoise(double dt + ) const noexcept; + [[nodiscard]] std::optional getArmorNumber() const noexcept { + if (!checkTargetAppear()) { + return std::nullopt; + } + return tracked_id_; + } + + [[nodiscard]] std::vector getArmorYaws() const noexcept { + std::vector yaw_list; + yaw_list.reserve(armor_num_); + for (int i = 0; i < armor_num_; i++) { + MModel::Measure::MeasureCtx _ctx(i, armor_num_); + MModel::Measure measure(_ctx); + yaw_list.push_back(measure.get_angle(target_state_.x.data())); + } + return yaw_list; + } + [[nodiscard]] std::vector getArmorPositions() const noexcept { + std::vector armor_positions; + armor_positions.reserve(armor_num_); + for (int i = 0; i < armor_num_; i++) { + MModel::Measure::MeasureCtx _ctx(i, armor_num_); + MModel::Measure measure(_ctx); + const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x); + armor_positions.push_back(xyza.head<3>()); + } + return armor_positions; + } + [[nodiscard]] std::vector getArmorPosAndYaw() const noexcept { + std::vector pos_yaw; + pos_yaw.reserve(armor_num_); + for (int i = 0; i < armor_num_; ++i) { + MModel::Measure::MeasureCtx _ctx(i, armor_num_); + MModel::Measure measure(_ctx); + const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x); + pos_yaw.push_back(xyza); + } + return pos_yaw; + } + [[nodiscard]] double getMeanZ() const noexcept { + double mean = 0; + for (const auto& p: getArmorPositions()) { + mean += p.z(); + } + return mean / armor_num_; + } + [[nodiscard]] double getArmor2CenterXYDis(int id) const noexcept { + const auto use_l_h = (armor_num_ == 4) && (id == 1 || id == 3); + const auto r = (use_l_h) ? target_state_.r() + target_state_.l() : target_state_.r(); + return r; + } + [[nodiscard]] std::vector> match(const std::vector& armors + ) noexcept; + + [[nodiscard]] inline bool checkTargetAppear() const noexcept { + const bool appear = is_tracking + && wust_vl::common::utils::time_utils::durationSec( + timestamp_, + wust_vl::common::utils::time_utils::now() + ) < target_config_->lost_time_thres_param.get(); + return appear; + } + + [[nodiscard]] Eigen::Matrix getMeasure(const Armor& a) noexcept { + const auto p = a.target_pos; + const double measured_yaw = utils::orientationToYaw(a.target_ori); + double ypd_y = std::atan2(p.y(), p.x()); + static double last_ypd_y = 0; + ypd_y = last_ypd_y + angles::shortest_angular_distance(last_ypd_y, ypd_y); + last_ypd_y = ypd_y; + const double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y())); + const double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z()); + + return Eigen::Vector4d(ypd_y, ypd_p, ypd_d, measured_yaw); + } + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.cpp b/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.cpp new file mode 100644 index 0000000..5036e21 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.cpp @@ -0,0 +1,209 @@ +#include "trackerv3.hpp" + +namespace wust_vision { +namespace auto_aim { + struct Tracker::Impl { + public: + enum State { + LOST, + DETECTING, + TRACKING, + TEMP_LOST, + } tracker_state = LOST; + Impl(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) { + tracker_state = LOST; + target_config_ = TargetConfig::create(); + auto_aim_config_parameter->registerGroup(*target_config_); + auto_aim_config_parameter->reloadFromOldPath(); + target_ = Target(); + } + + Target track(const Armors& armors_msg) noexcept { + const double dt = + std::chrono::duration(armors_msg.timestamp - last_time_).count(); + last_time_ = armors_msg.timestamp; + lost_thres_ = + std::abs(static_cast(target_config_->lost_time_thres_param.get() / dt)); + Armors armors; + armors = armors_msg; + std::erase_if(armors.armors, [this](const Armor& a) { + double center_yaw = + std::atan2(target_.target_state_.cy(), target_.target_state_.cx()); + bool state_check = tracker_state == TRACKING; + bool outpost_check = target_.tracked_id_ == ArmorNumber::OUTPOST && !a.is_ok; + bool pose_check = + (std::abs(angles::normalize_angle( + orientationToYaw(a.target_ori, center_yaw) - center_yaw + )) > (target_config_->max_yaw_diff_deg_param.get() * M_PI / 180.0) + || std::abs((a.target_pos - target_.target_state_.pos()).norm()) + > target_config_->max_dis_diff_param.get()) + && target_.is_inited + && std::abs(wust_vl::common::utils::time_utils::durationMs( + target_.timestamp_, + wust_vl::common::utils::time_utils::now() + )) + < 1000.0; + + return state_check && outpost_check || pose_check; + }); + + std::sort( + armors.armors.begin(), + armors.armors.end(), + [](const Armor& a, const Armor& b) { + return a.distance_to_image_center < b.distance_to_image_center; + } + ); + bool found; + if (tracker_state == LOST) { + found = initTarget(armors); + } else { + found = updateTarget(armors); + } + updateFsm(found); + + return target_; + } + void updateFsm(bool found) noexcept { + switch (tracker_state) { + case DETECTING: + if (found) { + if (++detect_count_ > target_config_->tracking_thres_param.get()) { + detect_count_ = 0; + tracker_state = TRACKING; + } + } else { + detect_count_ = 0; + tracker_state = LOST; + } + break; + + case TRACKING: + if (!found) { + tracker_state = TEMP_LOST; + lost_count_ = 1; + } + break; + + case TEMP_LOST: + if (!found) { + if (++lost_count_ > lost_thres_) { + lost_count_ = 0; + tracker_state = LOST; + } + } else { + lost_count_ = 0; + tracker_state = TRACKING; + } + break; + + default: + break; + } + + target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST); + + if (found) + ++found_count_; + } + + bool initTarget(const Armors& armors) noexcept { + if (armors.armors.empty()) { + return false; + } + bool found = false; + Armor init_target; + Armors others = armors; + others.armors.clear(); + for (auto& a: armors.armors) { + if (!a.is_none_purple && !found) { + init_target = a; + found = true; + continue; + } + others.armors.push_back(a); + } + if (!found) { + return false; + } + target_ = Target(init_target, target_config_); + // updateTarget(others); + tracker_state = DETECTING; + return true; + } + bool updateTarget(const Armors& armors) noexcept { + if (armors.armors.empty()) + return false; + + target_.predict(armors.timestamp); + std::vector candidates; + candidates.reserve(armors.armors.size()); + + for (const auto& a: armors.armors) { + if (isSameTarget(a.number, target_.tracked_id_) && !a.is_none_purple) { + candidates.emplace_back(a); + } + } + + if (candidates.empty()) + return false; + + int updated = 0; + const auto matches = target_.match(candidates); + + for (const auto& m: matches) { + if (m.second.is_none_purple) { + if (++is_none_purple_count_ > 100) + continue; + } else { + is_none_purple_count_ = 0; + } + + if (target_.update(m)) + ++updated; + } + + return updated > 0; + } + + int lost_thres_; + int detect_count_ = 0; + int lost_count_ = 0; + int is_none_purple_count_ = 0; + int found_count_ = 0; + Target target_; + std::chrono::steady_clock::time_point last_time_; + TargetConfig::Ptr target_config_; + + double orientationToYaw(const Eigen::Quaterniond& q, double from) noexcept { + double roll, pitch, yaw; + Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false); + yaw = euler[0]; + yaw = from + angles::shortest_angular_distance(from, yaw); + return yaw; + } + }; + Tracker::Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) { + _impl = std::make_unique(auto_aim_config_parameter); + } + Tracker::~Tracker() { + _impl.reset(); + } + Target Tracker::track(const Armors& armors) noexcept { + return _impl->track(armors); + } + int Tracker::getFoundCount() const noexcept { + return _impl->found_count_; + } + void Tracker::setFoundCount(int count) noexcept { + _impl->found_count_ = count; + } + std::chrono::steady_clock::time_point Tracker::getLastTime() const noexcept { + return _impl->last_time_; + } + void Tracker::setLastTime(std::chrono::steady_clock::time_point t) noexcept { + _impl->last_time_ = t; + } + +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.hpp b/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.hpp new file mode 100644 index 0000000..9f5e161 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "target.hpp" +namespace wust_vision { +namespace auto_aim { + class Tracker { + public: + using Ptr = std::unique_ptr; + Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter); + static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) { + return std::make_unique(auto_aim_config_parameter); + } + ~Tracker(); + [[nodiscard]] Target track(const Armors& armors) noexcept; + int getFoundCount() const noexcept; + void setFoundCount(int count) noexcept; + void setLastTime(std::chrono::steady_clock::time_point t) noexcept; + std::chrono::steady_clock::time_point getLastTime() const noexcept; + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_where/armor_where.cpp b/wust_vision-main/tasks/auto_aim/armor_where/armor_where.cpp new file mode 100644 index 0000000..7269f51 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_where/armor_where.cpp @@ -0,0 +1,307 @@ +// Created by Labor 2023.8.25 +// Maintained by Chengfu Zou, Labor +// Copyright (C) FYT Vision Group. All rights reserved. +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "armor_where.hpp" +#include "wust_vl/algorithm/pnp_solver.hpp" +#include +namespace wust_vision { +namespace auto_aim { + + struct ArmorWhere::Impl { + public: + Impl(const YAML::Node& config, const std::pair& camera_info) { + camera_info_ = camera_info; + params_.load(config); + pnp_solver_ = std::make_unique(cv::SOLVEPNP_IPPE); + pnp_solver_->setObjectPoints( + "small", + ArmorObject::buildObjectPoints(SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT) + ); + pnp_solver_->setObjectPoints( + "large", + ArmorObject::buildObjectPoints(LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT) + ); + } + struct Params { + enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode; + OptMode fromString(const std::string& mode) { + if (mode == "golden" || mode == "GOLDEN") { + return OptMode::GOLDEN; + } else if (mode == "none" || mode == "NONE") { + return OptMode::NONE; + } else { + return OptMode::NONE; + } + } + int golden_search_side_deg = 60; + double distance_fix_a2 = 0; + void load(const YAML::Node& node) { + opt_mode = fromString(node["yaw_opt"]["mode"].as()); + golden_search_side_deg = node["yaw_opt"]["golden_search_side_deg"].as(); + distance_fix_a2 = node["distance_fix_a2"].as(); + } + } params_; + + std::vector where( + const std::vector& armors, + Eigen::Matrix4d T_camera_to_odom + ) const noexcept { + std::vector armors_msg; + + const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0); + auto makeArmor = + [&](const ArmorObject& obj, const Eigen::Vector3d& t, const Eigen::Matrix3d& R) { + Armor msg; + msg.type = (obj.number == ArmorNumber::NO1 || obj.number == ArmorNumber::BASE) + ? "large" + : "small"; + msg.number = obj.number; + Eigen::Quaterniond q(R); + Eigen::Quaterniond add_roll { + Eigen::AngleAxisd(M_PI / 2, Eigen::Vector3d::UnitX()) + }; + Eigen::Quaterniond new_q = q * add_roll; + + auto [yaw, pitch, dist] = utils::xyz2ypd_rad(t.x(), t.y(), t.z()); + dist += params_.distance_fix_a2 * dist * dist; + auto [x, y, z] = utils::ypd2xyz_rad(yaw, pitch, dist); + + msg.pos = { x, y, z }; + msg.ori = new_q; + auto_aim::transformArmorData(msg, T_camera_to_odom); + msg.distance_to_image_center = + pnp_solver_->calculateDistanceToCenter(obj.center, camera_info_.first); + msg.is_ok = obj.is_ok; + if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) { + msg.is_none_purple = true; + } else { + msg.is_none_purple = false; + } + return msg; + }; + + for (auto const& a: armors) { + cv::Mat rvec, tvec; + std::string type = (a.number == ArmorNumber::NO1 || a.number == ArmorNumber::BASE) + ? "large" + : "small"; + + if (!pnp_solver_->solvePnP( + a.landmarks(), + rvec, + tvec, + type, + camera_info_.first, + camera_info_.second + )) + { + WUST_WARN("PNP") << "PNP failed"; + continue; + } + cv::Mat R_cv; + cv::Rodrigues(rvec, R_cv); + Eigen::Matrix3d R = utils::cvToEigen(R_cv); + Eigen::Vector3d t = utils::cvToEigen(tvec); + if (params_.opt_mode != Params::OptMode::NONE) { + Eigen::Matrix3d R0 = R; + R = solveBa_R(a, t, R0, R_imu_cam, type); + } + + armors_msg.push_back(makeArmor(a, t, R)); + } + + return armors_msg; + } + std::vector reprojectionArmor( + double yaw, + const std::vector& object_points, + const std::vector& landmarks, + const Eigen::Matrix3d& Rci, + double pitch, + double roll, + const Eigen::Vector3d& t + ) const noexcept { + const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ()); + const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY()); + const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX()); + const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix(); + + cv::Mat rvec, R_cv; + cv::eigen2cv(R, R_cv); + cv::Rodrigues(R_cv, rvec); + + const cv::Mat tvec = (cv::Mat_(3, 1) << t.x(), t.y(), t.z()); + + std::vector pts_2d; + pts_2d.reserve(object_points.size()); + cv::projectPoints( + object_points, + rvec, + tvec, + camera_info_.first, + camera_info_.second, + pts_2d + ); + + std::vector image_points; + image_points.reserve(pts_2d.size()); + + for (const auto& p: pts_2d) { + image_points.emplace_back(p.x, p.y); + } + + return image_points; + } + + double reprojectionErrorYaw( + double yaw, + const std::vector& object_points, + const std::vector& landmarks, + const std::vector>& sym_pairs, + const Eigen::Matrix3d& Rci, + double pitch, + double roll, + const Eigen::Vector3d& t + ) const noexcept { + const auto image_points = + reprojectionArmor(yaw, object_points, landmarks, Rci, pitch, roll, t); + double cost = 0.0; + + // for (size_t i = 0; i < image_points.size(); ++i) { + // Eigen::Vector2d obs(landmarks[i].x, landmarks[i].y); + // cost += (image_points[i] - obs).squaredNorm(); + // } + + for (auto& p: sym_pairs) { + const Eigen::Vector2d mid = 0.5 * (image_points[p.first] + image_points[p.second]); + + const Eigen::Vector2d meas = 0.5 + * (Eigen::Vector2d(landmarks[p.first].x, landmarks[p.first].y) + + Eigen::Vector2d(landmarks[p.second].x, landmarks[p.second].y)); + + cost += (mid - meas).squaredNorm(); + } + return cost; + } + + double goldenYaw( + double init, + const std::vector& obj, + const std::vector& lm, + const std::vector>& sym_pairs, + const Eigen::Matrix3d& Rci, + double pitch, + double roll, + const Eigen::Vector3d& t + ) const noexcept { + constexpr double phi = 1.618033988749894848; //(1.0 + std::sqrt(5.0)) * 0.5; + double l = init - params_.golden_search_side_deg * M_PI / 180.0; + double r = init + params_.golden_search_side_deg * M_PI / 180.0; + + double y1 = r - (r - l) / phi; + double y2 = l + (r - l) / phi; + + double f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t); + double f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t); + + while (r - l > 0.0001) { + if (f1 < f2) { + r = y2; + y2 = y1; + f2 = f1; + y1 = r - (r - l) / phi; + f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t); + } else { + l = y1; + y1 = y2; + f1 = f2; + y2 = l + (r - l) / phi; + f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t); + } + } + + return 0.5 * (l + r); + } + + Eigen::Matrix3d solveBa_R( + const ArmorObject& armor, + const Eigen::Vector3d& t_camera_armor, + const Eigen::Matrix3d& R_camera_armor, + const Eigen::Matrix3d& R_imu_camera, + const std::string& type + ) const noexcept { + const Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor; + const Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose(); + //double roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2)); + const double roll = 0; + // initial yaw + const double yaw_init = std::atan2(-R_imu_armor(0, 1), R_imu_armor(1, 1)); + + const double armor_pitch = + (armor.number == ArmorNumber::OUTPOST) ? -FIFTTEN_DEGREE_RAD : FIFTTEN_DEGREE_RAD; + + const Eigen::Vector2d armor_size = (type == "large") + ? Eigen::Vector2d { LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT } + : Eigen::Vector2d { SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT }; + + const auto objPts = + ArmorObject::buildObjectPoints(armor_size.x(), armor_size.y()); + const auto& lm = armor.landmarks(); + const auto& sym_pairs = ArmorObject::buildSymPairs(); + double yaw = yaw_init; + if (params_.opt_mode == Params::OptMode::GOLDEN) { + yaw = goldenYaw( + yaw_init, + objPts, + lm, + sym_pairs, + R_camera_imu, + armor_pitch, + roll, + t_camera_armor + ); + } + + const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ()); + const Eigen::AngleAxisd ap(armor_pitch, Eigen::Vector3d::UnitY()); + const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX()); + + const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix(); + return R_result; + } + + private: + std::pair camera_info_; + std::unique_ptr pnp_solver_; + }; + ArmorWhere::ArmorWhere( + const YAML::Node& config, + const std::pair& camera_info + ) { + _impl = std::make_unique(config, camera_info); + } + ArmorWhere::~ArmorWhere() { + _impl.reset(); + } + std::vector ArmorWhere::where( + const std::vector& armors, + Eigen::Matrix4d T_camera_to_odom + ) const noexcept { + return _impl->where(armors, T_camera_to_odom); + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/armor_where/armor_where.hpp b/wust_vision-main/tasks/auto_aim/armor_where/armor_where.hpp new file mode 100644 index 0000000..298fc8b --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/armor_where/armor_where.hpp @@ -0,0 +1,26 @@ + +#pragma once +#include "tasks/auto_aim/type.hpp" +namespace wust_vision { +namespace auto_aim { + class ArmorWhere { + public: + using Ptr = std::unique_ptr; + ArmorWhere(const YAML::Node& config, const std::pair& camera_info); + static Ptr + create(const YAML::Node& config, const std::pair& camera_info) { + return std::make_unique(config, camera_info); + } + ~ArmorWhere(); + std::vector where( + const std::vector& armors, + Eigen::Matrix4d T_camera_to_odom + ) const noexcept; + + private: + struct Impl; + std::unique_ptr _impl; + }; + +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/auto_aim.cpp b/wust_vision-main/tasks/auto_aim/auto_aim.cpp new file mode 100644 index 0000000..8a08922 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/auto_aim.cpp @@ -0,0 +1,404 @@ +#include "auto_aim.hpp" +#include "tasks/auto_aim/armor_control/very_aimer.hpp" +#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp" +#include "tasks/auto_aim/armor_tracker/target.hpp" +#include "tasks/auto_aim/armor_tracker/trackerv3.hpp" +#include "tasks/auto_aim/armor_where/armor_where.hpp" +#include "tasks/auto_aim/debug.hpp" +#include "tasks/type_common.hpp" +#include "tasks/utils/config.hpp" +#include "wust_vl/common/concurrency/queues.hpp" + +namespace wust_vision { +namespace auto_aim { + + struct AutoAim::Impl { + ~Impl() { + run_flag_ = false; + if (armor_detector_) { + armor_detector_.reset(); + } + armor_queue_->stop(); + if (processing_thread_) { + processing_thread_->stop(); + wust_vl::common::concurrency::ThreadManager::instance().unregisterThread( + processing_thread_->getName() + ); + } + } + Impl( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ) { + tf_config_ = tf_config; + camera_info_ = camera_info; + auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create(); + auto config = YAML::LoadFile(config_path); + auto_aim_config_parameter_->loadFromFile(config_path); + auto_exposure_cfg_ = AutoExposureCfg::create(); + very_aimer_ = VeryAimer::create(auto_aim_config_parameter_); + auto_aim_config_parameter_->registerGroup(*auto_exposure_cfg_); + auto_aim_config_parameter_->registerGroup(*auto_aim_fsm_cl_.config_); + tracker_ = Tracker::create(auto_aim_config_parameter_); + auto_aim_config_parameter_->reloadFromOldPath(); + + wust_vl::common::utils::ParameterManager::instance().registerParameter( + *auto_aim_config_parameter_.get() + ); + max_detect_armors_ = config["max_detect_armors"].as(10); + armor_where_ = ArmorWhere::create(config["armor_where"], camera_info_); + const std::string armor_detect_backend = + config["armor_detect_backend"].as(""); + armor_detector_ = DetectorFactory::createArmorDetector(armor_detect_backend, true); + armor_detector_->setCallback(std::bind( + &AutoAim::Impl::ArmorDetectCallback, + this, + std::placeholders::_1, + std::placeholders::_2 + )); + WUST_MAIN(logger_) << "Using Armor Detector: " << armor_detect_backend; + armor_queue_ = + std::make_unique>(50, 500); + latency_averager_ = + std::make_unique>(100); + } + void start() { + if (run_flag_) { + return; + } + run_flag_ = true; + processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create( + "AutoAimProcessingThread", + [this](wust_vl::common::concurrency::MonitoredThread::Ptr self) { + this->processingLoop(self); + } + ); + wust_vl::common::concurrency::ThreadManager::instance().registerThread( + processing_thread_ + ); + } + void pushInput(CommonFrame& frame) { + img_recv_count_++; + + auto bbox = target_.expanded( + T_camera_to_odom_, + camera_info_.first, + camera_info_.second, + frame.img_frame.src_img.size() + ); + if (bbox.area() > 100) { + frame.expanded = bbox; + frame.offset = cv::Point2f(bbox.x, bbox.y); + } + expanded_ = frame.expanded; + const std::optional target_number = target_.getArmorNumber(); + if (armor_detector_) { + armor_detector_->pushInput(frame, target_number); + } + } + + void ArmorDetectCallback(const std::vector& objs, const CommonFrame& frame) { + std::vector sorted_objs = objs; + + if (sorted_objs.size() > max_detect_armors_) [[unlikely]] { + WUST_WARN(logger_) << "Detected " << sorted_objs.size() + << " objects, too many, keeping top " << max_detect_armors_; + + std::partial_sort( + sorted_objs.begin(), + sorted_objs.begin() + max_detect_armors_, + sorted_objs.end(), + [](const ArmorObject& a, const ArmorObject& b) { + return a.confidence > b.confidence; + } + ); + + sorted_objs.resize(max_detect_armors_); + } + for (auto& obj: sorted_objs) { + obj.addOffset(frame.offset); + } + + Armors armors; + armors.timestamp = frame.img_frame.timestamp; + armors.id = frame.id; + Eigen::Vector3d v = Eigen::Vector3d::Zero(); + Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity(); + auto ctx = std::any_cast(frame.any_ctx); + std::pair gimbal_py; + if (ctx.motion_buffer) { + const auto delay = + std::chrono::microseconds(static_cast(ctx.communication_delay_μs)); + const auto t_query = armors.timestamp + delay; + auto apply_motion = [&](const auto& att) { + v << att.data.vx, att.data.vy, att.data.vz; + R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ()) + * Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY()) + * Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX()); + gimbal_py = std::make_pair(att.data.pitch, att.data.yaw); + }; + if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) { + apply_motion(*past_att); + } else if (auto last_att = ctx.motion_buffer->get_last()) { + apply_motion(*last_att); + } + } + autoExposureControl(frame.img_frame.src_img, ctx.camera); + T_camera_to_odom_ = utils::computeCameraToOdomTransform( + R_gimbal2odom, + tf_config_->R_camera2gimbal, + tf_config_->t_camera2gimbal + ); + armors.armors = armor_where_->where(sorted_objs, T_camera_to_odom_); + armors.v = v; + for (auto& armor: armors.armors) { + armor.timestamp = armors.timestamp; + } + + armor_queue_->enqueue(armors); + ++detect_finish_count_; + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + auto& dbg = auto_aim_debug_; + + dbg.img_frame = frame.img_frame; + dbg.armors = armors; + dbg.T_camera_to_odom = T_camera_to_odom_; + dbg.detect_color = frame.detect_color; + dbg.armor_objs = sorted_objs; + dbg.expanded = frame.expanded; + dbg.gimbal_py = gimbal_py; + } + } + void armorsCallback(const Armors& armors) { + if (armors.timestamp <= tracker_->getLastTime()) { + WUST_WARN(logger_) << "Received out-of-order armor data, discarded."; + return; + } + Target target = tracker_->track(armors); + auto_aim_fsm_cl_.update(std::abs(target.target_state_.vyaw()), target.jumped); + const auto now = std::chrono::steady_clock::now(); + { + std::lock_guard lock(target_mutex_); + target_ = target; + } + const auto latency_ms = + wust_vl::common::utils::time_utils::durationMs(armors.timestamp, now); + latency_averager_->add(latency_ms); + auto& dbg = auto_aim_debug_; + dbg.latency_ms = latency_averager_->average(); + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + dbg.target = target; + dbg.fsm = auto_aim_fsm_cl_.fsm_state_; + } + } + Target getTarget() { + Target target; + { + std::lock_guard lock(target_mutex_); + target = target_; + } + return target; + } + GimbalCmd solve(double bullet_speed) { + GimbalCmd gimbal_cmd; + Target target; + { + std::lock_guard lock(target_mutex_); + target = target_; + } + AimTarget aim_target; + const bool appear = target.checkTargetAppear(); + if (appear && target.target_state_.pos().norm() > 0.1) { + try { + gimbal_cmd = + very_aimer_->veryAim(target, bullet_speed, auto_aim_fsm_cl_.fsm_state_); + aim_target = gimbal_cmd.aim_target; + } catch (...) { + WUST_ERROR(logger_) << "VeryAim error"; + } + } + if (gimbal_cmd.fire_advice) { + fire_count_++; + } + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + auto_aim_debug_.gimbal_cmd = gimbal_cmd; + auto_aim_debug_.aim_target = aim_target; + } + timer_cout_++; + return gimbal_cmd; + } + void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) { + while (!self->isAlive()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + while (self->isAlive() && run_flag_) { + if (!self->waitPoint()) + break; + self->heartbeat(); + printStats(); + Armors armors; + // bool skip; + // if (armor_queue_->dequeue_wait(armors, skip)) { + // armorsCallback(armors); + // tracker_finish_count_++; + // if (skip) { + // WUST_DEBUG(logger_) << "OrderQueue skip"; + // } + // } + if (armor_queue_->try_dequeue(armors)) { + armorsCallback(armors); + tracker_finish_count_++; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + } + void doDebug() { + debug_mode_ = true; + AutoAimDebug dbg; + { + std::lock_guard lock(dbg_mutex_); + dbg = auto_aim_debug_; + } + drawDebugOverlayShm(dbg, camera_info_, false); + debuglog(dbg); + } + + void printStats() { + utils::XSecOnce( + [&] { + double found_ratio = 0.0; + if (img_recv_count_ > 0) { + found_ratio = static_cast(tracker_->getFoundCount()) + / static_cast(img_recv_count_); + } + + WUST_INFO(logger_) + << "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_ + << ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_ + << ", Lat: " << auto_aim_debug_.latency_ms << "ms" + << ", Fire: " << fire_count_ << ", Found: " << tracker_->getFoundCount() + << ", Found_ratio: " << found_ratio; + + img_recv_count_ = 0; + detect_finish_count_ = 0; + fire_count_ = 0; + tracker_finish_count_ = 0; + timer_cout_ = 0; + tracker_->setFoundCount(0); + }, + 1.0 + ); + } + void + autoExposureControl(const cv::Mat& frame, std::shared_ptr camera) { + const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0; + utils::XSecOnce( + [&] { + if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) { + return; + } + if (auto* hik = dynamic_cast(camera->getDevice())) { + cv::Mat i_use = frame(expanded_); + if (expanded_.area() < 100 || i_use.empty()) { + i_use = frame; + } + const double brightness = utils::computeBrightness(i_use); + + const double diff = + brightness - auto_exposure_cfg_->target_brightness_param.get(); + const double exposure_min = auto_exposure_cfg_->exposure_min_param.get(); + const double exposure_max = auto_exposure_cfg_->exposure_max_param.get(); + double exposure_time = hik->getExposureTime(); + const double last_exposure_time = exposure_time; + if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get() + && exposure_time > 0.0) { + exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get(); + } else { + exposure_time -= auto_exposure_cfg_->decay_step_param.get(); + } + if (exposure_time < exposure_min) + exposure_time = exposure_min; + if (exposure_time > exposure_max) + exposure_time = exposure_max; + if (std::abs(exposure_time - last_exposure_time) > 10) { + hik->setExposureTime(exposure_time); + } + } + }, + dt + ); + } + wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_; + Tracker::Ptr tracker_; + ArmorDetectorBase::Ptr armor_detector_; + std::string logger_ = "auto_aim"; + std::unique_ptr> armor_queue_; + wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_; + std::unique_ptr timer_; + VeryAimer::Ptr very_aimer_; + ArmorWhere::Ptr armor_where_; + AutoAimFsmController auto_aim_fsm_cl_; + AutoExposureCfg::Ptr auto_exposure_cfg_; + + cv::Rect expanded_; + int max_detect_armors_; + bool run_flag_ = false; + int detect_finish_count_ = 0; + int img_recv_count_ = 0; + int tracker_finish_count_ = 0; + int fire_count_ = 0; + int timer_cout_ = 0; + Target target_; + bool debug_mode_ = false; + AutoAimDebug auto_aim_debug_; + std::unique_ptr> latency_averager_; + TFConfig::Ptr tf_config_; + std::pair camera_info_; + Eigen::Matrix4d T_camera_to_odom_; + std::mutex target_mutex_; + std::mutex dbg_mutex_; + }; + AutoAim::AutoAim( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ): + _impl(std::make_unique(config_path, tf_config, camera_info, debug)) {} + AutoAim::~AutoAim() { + _impl.reset(); + } + + void AutoAim::start() { + _impl->start(); + } + void AutoAim::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); + } + + GimbalCmd AutoAim::solve(double bullet_speed) { + return _impl->solve(bullet_speed); + } + wust_vl::common::concurrency::MonitoredThread::Ptr AutoAim::getThread() { + return _impl->processing_thread_; + } + Target AutoAim::getTarget() { + return _impl->getTarget(); + } + void AutoAim::doDebug() { + _impl->doDebug(); + } + wust_vl::common::utils::Parameter::Ptr AutoAim::getParameter() { + return _impl->auto_aim_config_parameter_; + } + VeryAimer::Ptr AutoAim::getVeryAimer() { + return _impl->very_aimer_; + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/auto_aim.hpp b/wust_vision-main/tasks/auto_aim/auto_aim.hpp new file mode 100644 index 0000000..ab1d54d --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/auto_aim.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "tasks/auto_aim/armor_control/very_aimer.hpp" +#include "tasks/auto_aim/armor_tracker/target.hpp" +#include "tasks/imodule.hpp" +#include "tasks/type_common.hpp" +#include "wust_vl/video/camera.hpp" +#include +namespace wust_vision { +namespace auto_aim { + + class AutoAim: public IModule { + public: + using Ptr = std::shared_ptr; + AutoAim( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ); + static Ptr create( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ) { + return std::make_shared(config_path, tf_config, camera_info, debug); + } + ~AutoAim(); + + void start() override; + void doDebug() override; + void pushInput(CommonFrame& frame) override; + Target getTarget(); + GimbalCmd solve(double bullet_speed) override; + wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override; + wust_vl::common::utils::Parameter::Ptr getParameter(); + VeryAimer::Ptr getVeryAimer(); + struct Impl; + std::unique_ptr _impl; + }; + inline AutoAim::Ptr toAutoAim(IModule::Ptr module) { + return std::dynamic_pointer_cast(module); + } +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/auto_aim_fsm.hpp b/wust_vision-main/tasks/auto_aim/auto_aim_fsm.hpp new file mode 100644 index 0000000..87036da --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/auto_aim_fsm.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include "wust_vl/common/utils/parameter.hpp" +namespace wust_vision { +namespace auto_aim { + enum class AutoAimFsm { + AIM_WHOLE_CAR_ARMOR, + AIM_WHOLE_CAR_CENTER, + AIM_SINGLE_ARMOR, + AIM_WHOLE_CAR_PAIR + }; + + inline std::string auto_aim_fsm_to_string(AutoAimFsm state) { + switch (state) { + case AutoAimFsm::AIM_WHOLE_CAR_ARMOR: + return "AIM_WHOLE_CAR_ARMOR"; + case AutoAimFsm::AIM_WHOLE_CAR_CENTER: + return "AIM_WHOLE_CAR_CENTER"; + case AutoAimFsm::AIM_SINGLE_ARMOR: + return "AIM_SINGLE_ARMOR"; + case AutoAimFsm::AIM_WHOLE_CAR_PAIR: + return "AIM_WHOLE_CAR_PAIR"; + default: + return "UNKNOWN"; + } + } + + class AutoAimFsmController { + public: + AutoAimFsmController() { + config_ = std::make_shared(); + } + AutoAimFsm fsm_state_ { AutoAimFsm::AIM_SINGLE_ARMOR }; + struct AutoAimFsmConfig: wust_vl::common::utils::ParamGroup { + public: + static constexpr const char* Logger = "Config: auto_aim::auto_aim_fsm"; + static constexpr const char* kKey = "auto_aim_fsm"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + AutoAimFsmConfig() {} + GEN_PARAM(int, transfer_thresh); + GEN_PARAM(double, single_whole_up); + GEN_PARAM(double, single_whole_down); + GEN_PARAM(double, whole_pair_up); + GEN_PARAM(double, whole_pair_down); + GEN_PARAM(double, pair_center_up); + GEN_PARAM(double, pair_center_down); + void loadSelf(const YAML::Node& node) override { + transfer_thresh_param.load(node); + single_whole_up_param.load(node); + single_whole_down_param.load(node); + whole_pair_up_param.load(node); + whole_pair_down_param.load(node); + pair_center_up_param.load(node); + pair_center_down_param.load(node); + } + }; + AutoAimFsmConfig::Ptr config_; + int overflow_count_ = 0; + + void update(double v_yaw, bool target_jumped) { + // 无跳变:直接退回单装甲,并清状态 + if (!target_jumped) { + fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR; + overflow_count_ = 0; + return; + } + + const double av = std::abs(v_yaw); + + switch (fsm_state_) { + case AutoAimFsm::AIM_SINGLE_ARMOR: { + overflow_count_ = + (av > config_->single_whole_up_param.get()) ? overflow_count_ + 1 : 0; + if (overflow_count_ > config_->transfer_thresh_param.get()) { + fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_ARMOR; + overflow_count_ = 0; + } + break; + } + + case AutoAimFsm::AIM_WHOLE_CAR_ARMOR: { + if (av > config_->whole_pair_up_param.get()) + ++overflow_count_; + else if (av < config_->single_whole_down_param.get()) + --overflow_count_; + else + overflow_count_ = 0; + + if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) { + fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_PAIR + : AutoAimFsm::AIM_SINGLE_ARMOR; + overflow_count_ = 0; + } + break; + } + + case AutoAimFsm::AIM_WHOLE_CAR_PAIR: { + if (av > config_->pair_center_up_param.get()) + ++overflow_count_; + else if (av < config_->whole_pair_down_param.get()) + --overflow_count_; + else + overflow_count_ = 0; + + if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) { + fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_CENTER + : AutoAimFsm::AIM_WHOLE_CAR_ARMOR; + overflow_count_ = 0; + } + break; + } + + case AutoAimFsm::AIM_WHOLE_CAR_CENTER: { + overflow_count_ = + (av < config_->pair_center_down_param.get()) ? overflow_count_ + 1 : 0; + if (overflow_count_ > config_->transfer_thresh_param.get()) { + fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_PAIR; + overflow_count_ = 0; + } + break; + } + + default: + fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR; + overflow_count_ = 0; + break; + } + } + }; +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/debug.cpp b/wust_vision-main/tasks/auto_aim/debug.cpp new file mode 100644 index 0000000..8a65a57 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/debug.cpp @@ -0,0 +1,684 @@ +#include "debug.hpp" +namespace wust_vision::auto_aim { +void drawDebugArmorContent( + cv::Mat& debug_img, + const AutoAimDebug& dbg, + std::pair camera_info +) { + if (debug_img.empty()) { + std::cout << "debug_img is empty" << std::endl; + return; + } + const auto now = std::chrono::steady_clock::now(); + const auto& armors = dbg.armors; + const auto& gimbal_cmd = dbg.gimbal_cmd; + const auto& target = dbg.target; + auto aim_target = dbg.aim_target; + const auto& armor_objs = dbg.armor_objs; + const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows); + const cv::Rect roi = dbg.expanded & img_rect; + cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2); + + static const int next_indices[] = { 2, 0, 3, 1 }; + + for (size_t i = 0; i < armor_objs.size(); i++) { + const auto pts = armor_objs[i].toPts(); + + for (size_t j = 0; j < 4; ++j) { + const cv::Scalar color = + armor_objs[i].is_ok ? cv::Scalar(50, 255, 50) : cv::Scalar(50, 255, 255); + cv::line(debug_img, pts[j], pts[next_indices[j]], color, 2); + } + + auto armorName = [](auto_aim::ArmorNumber num) { + switch (num) { + case auto_aim::ArmorNumber::SENTRY: + return "SENTRY"; + case auto_aim::ArmorNumber::BASE: + return "BASE"; + case auto_aim::ArmorNumber::OUTPOST: + return "OUTPOST"; + case auto_aim::ArmorNumber::NO1: + return "NO1"; + case auto_aim::ArmorNumber::NO2: + return "NO2"; + case auto_aim::ArmorNumber::NO3: + return "NO3"; + case auto_aim::ArmorNumber::NO4: + return "NO4"; + case auto_aim::ArmorNumber::NO5: + return "NO5"; + default: + return "UNKNOWN"; + } + }; + + const std::string armor_str = armorName(armor_objs[i].number); + cv::putText( + debug_img, + armor_str, + pts[1] + cv::Point2f(0, 50), + cv::FONT_HERSHEY_SIMPLEX, + 0.7, + cv::Scalar(0, 200, 200), + 2 + ); + } + + const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms); + cv::putText( + debug_img, + latency_str, + cv::Point(10, 30), + cv::FONT_HERSHEY_SIMPLEX, + 0.8, + cv::Scalar(255, 255, 255), + 2 + ); + + { + // static std::deque> traj3d; + + // double _now = + // std::chrono::duration(std::chrono::steady_clock::now().time_since_epoch()).count(); + + // traj3d.emplace_back(aim_target.pos, _now); + + // while (!traj3d.empty() && _now - traj3d.front().second > 1.0) + // traj3d.pop_front(); + aim_target.tf(dbg.T_camera_to_odom.inverse()); + const auto pts = aim_target.toPts(camera_info.first, camera_info.second); + + if (!pts.empty()) { + // if (traj3d.size() > 1) { + // std::vector> img_pts; + + // for (auto& p: traj3d) { + // auto p3d_odom = p.first; + + // Eigen::Vector4d p_odom(p3d_odom.x(), p3d_odom.y(), p3d_odom.z(), 1); + + // Eigen::Vector4d p_camera = dbg.T_camera_to_odom.inverse() * p_odom; + + // std::vector obj; + // obj.emplace_back(p_camera.x(), p_camera.y(), p_camera.z()); + + // std::vector proj; + + // cv::projectPoints( + // obj, + // cv::Vec3d(0, 0, 0), + // cv::Vec3d(0, 0, 0), + // camera_info.first, + // camera_info.second, + // proj + // ); + + // if (!proj.empty()) { + // const auto& pt = proj[0]; + + // if (std::isfinite(pt.x) && std::isfinite(pt.y)) { + // img_pts.emplace_back(cv::Point(int(pt.x), int(pt.y)), p.second); + // } + // } + // } + + // if (img_pts.size() >= 2) { + // double now = std::chrono::duration( + // std::chrono::steady_clock::now().time_since_epoch() + // ) + // .count(); + + // const double max_age = 1.0; + + // for (size_t i = 1; i < img_pts.size(); ++i) { + // double age = now - img_pts[i].second; + + // double t = std::clamp(age / max_age, 0.0, 1.0); + + // int r = int(255 * (1.0 - t)); + // int b = int(255 * t); + + // cv::Scalar color(b, 0, r); + + // cv::line( + // debug_img, + // img_pts[i - 1].first, + // img_pts[i].first, + // color, + // 2, + // cv::LINE_AA + // ); + // } + // } + // } + + cv::Point2f center(0.f, 0.f); + + for (auto pt: pts) + center += pt; + + center *= 1.0f / pts.size(); + + cv::Scalar color(255, 255, 255); + + for (int i = 0; i < 4; i++) + cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2); + + for (int i = 4; i < 8; i++) + cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2); + + for (int i = 0; i < 4; i++) + cv::line(debug_img, pts[i], pts[i + 4], color, 2); + + if (gimbal_cmd.fire_advice) { + int cross_len = 60; + + cv::line( + debug_img, + center + cv::Point2f(-cross_len, -cross_len), + center + cv::Point2f(+cross_len, +cross_len), + cv::Scalar(0, 0, 255), + 5 + ); + + cv::line( + debug_img, + center + cv::Point2f(-cross_len, +cross_len), + center + cv::Point2f(+cross_len, -cross_len), + cv::Scalar(0, 0, 255), + 5 + ); + } + + const double scale = 10.0; + + const double v_yaw = gimbal_cmd.v_yaw; + const double v_pitch = gimbal_cmd.v_pitch; + + const double dx = -scale * v_yaw; + const double dy = scale * v_pitch; + + const cv::Point2f start_pt = center; + const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy); + + const cv::Scalar color_x = + dbg.detect_color ? cv::Scalar(255, 50, 50) : cv::Scalar(50, 50, 255); + + cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2); + } + } + std::vector all_corners; + + auto visualizeTargetProjection = [&](auto_aim::Target armor_target) -> auto_aim::Armors { + auto_aim::Armors armor_data; + armor_data.timestamp = armor_target.timestamp_; + + if (armor_target.is_tracking) { + Eigen::Vector3d pos = armor_target.target_state_.pos(); + if (pos.norm() > 0.5) { + armor_data.armors.clear(); + const size_t a_n = armor_target.armor_num_; + armor_data.armors.reserve(a_n); + const auto now = wust_vl::common::utils::time_utils::now(); + armor_target.predictSimple(now); + const std::vector armors_posandyaw = + armor_target.getArmorPosAndYaw(); + for (size_t i = 0; i < a_n; ++i) { + const Eigen::Vector3d pos = { armors_posandyaw[i][0], + armors_posandyaw[i][1], + armors_posandyaw[i][2] }; + Eigen::Vector3d euler; + euler.z() = M_PI / 2.0; + euler.y() = (armor_target.tracked_id_ == auto_aim::ArmorNumber::OUTPOST) + ? -0.2618 + : 0.2618; + euler.x() = armors_posandyaw[i][3]; + const Eigen::Quaterniond ori = + utils::eulerToQuat(euler, utils::EulerOrder::ZYX); + armor_data.armors.emplace_back(auto_aim::Armor { + .type = armor_target.type_, + .pos = pos, + .ori = ori, + .is_ok = true, + .id = (int)(i), + }); + } + } + } + return armor_data; + }; + auto armor_data = visualizeTargetProjection(dbg.target); + transformArmorData(armor_data, dbg.T_camera_to_odom.inverse()); + for (size_t i = 0; i < armor_data.armors.size(); ++i) { + const auto& pts = armor_data.armors[i].toPtsDebug(camera_info.first, camera_info.second); + const auto& pos = armor_data.armors[i].pos; + const auto& ori = armor_data.armors[i].ori; + const auto& id = armor_data.armors[i].id; + cv::Scalar color; + if (dbg.detect_color) { + color = cv::Scalar(255, 0, 0); + } else { + color = cv::Scalar(0, 0, 255); + } + // 绘制前表面 + for (size_t j = 0; j < 4; ++j) { + cv::line(debug_img, pts[j], pts[(j + 1) % 4], color, 2); + } + + // 绘制后表面 + for (size_t j = 4; j < 8; ++j) { + cv::line(debug_img, pts[j], pts[4 + (j + 1) % 4], color, 2); + } + + // 绘制侧边 + for (size_t j = 0; j < 4; ++j) { + cv::line(debug_img, pts[j], pts[j + 4], color, 2); + } + + all_corners.insert(all_corners.end(), pts.begin(), pts.end()); + + const Eigen::Vector3d euler = ori.toRotationMatrix().eulerAngles(2, 1, 0); + const double yaw = euler[0]; + const double distance = + std::sqrt(pos.x() * pos.x() + pos.y() * pos.y() + pos.z() * pos.z()); + + const std::vector info_lines = { + fmt::format("Dis: {:.1f}cm", distance * 100), + fmt::format("X: {:.2f}", pos.x()), + fmt::format("Y: {:.2f}", pos.y()), + fmt::format("Z: {:.2f}", pos.z()), + fmt::format("Yaw: {:.2f}", yaw * 180.0 / M_PI), + fmt::format("ID: {:d}", id) + }; + + const cv::Point2f text_org = pts[0] + cv::Point2f(0, 200); + for (int k = 0; k < info_lines.size(); ++k) { + cv::putText( + debug_img, + info_lines[k], + text_org + cv::Point2f(0, -10 - 20 * k), + cv::FONT_HERSHEY_SIMPLEX, + 0.6, + cv::Scalar(50, 255, 255), + 1 + ); + } + } + + if (!all_corners.empty()) { + cv::Point2f avg(0.f, 0.f); + for (const auto& pt: all_corners) + avg += pt; + avg *= 1.0f / all_corners.size(); + cv::circle(debug_img, avg, 10, cv::Scalar(50, 255, 50), -1); + + const double scale = 50.0; + const double dy = scale * target.target_state_.vyaw(); + const cv::Point2f start_pt = avg; + const cv::Point2f end_pt = start_pt + cv::Point2f(0, dy); + cv::arrowedLine( + debug_img, + start_pt, + end_pt, + cv::Scalar(50, 255, 50), + 3, + cv::LINE_AA, + 0, + 0.1 + ); + cv::putText( + debug_img, + fmt::format("V_yaw: {:.2f}", target.target_state_.vyaw()), + avg + cv::Point2f(0, -20), + cv::FONT_HERSHEY_SIMPLEX, + 1.0, + cv::Scalar(50, 255, 50), + 2 + ); + } + + std::string state_str; + state_str = auto_aim_fsm_to_string(dbg.fsm); + + int baseline = 0; + cv::Size text_size = cv::getTextSize(state_str, cv::FONT_HERSHEY_SIMPLEX, 2.5, 2, &baseline); + + // 保证在图像内 + const int x = + std::clamp(debug_img.cols - text_size.width - 10, 0, debug_img.cols - text_size.width); + const int y = std::clamp(text_size.height + 10, text_size.height, debug_img.rows - 1); + + cv::putText( + debug_img, + state_str, + { x, y }, + cv::FONT_HERSHEY_SIMPLEX, + 2.5, + cv::Scalar(0, 0, 255), + 2 + ); + + const std::string id_str = + fmt::format("Attack: {}", armorNumberToString(dbg.target.tracked_id_)); + const cv::Size id_size = cv::getTextSize(id_str, cv::FONT_HERSHEY_SIMPLEX, 1.6, 2, &baseline); + + // 保证在图像内 + const int id_x = std::clamp(debug_img.cols - 300, 0, debug_img.cols - id_size.width - 10); + const int id_y = std::clamp(150, id_size.height, debug_img.rows - 1); + + cv::putText( + debug_img, + id_str, + { id_x, id_y }, + cv::FONT_HERSHEY_SIMPLEX, + 1.6, + cv::Scalar(255, 0, 255), + 2 + ); + + if (gimbal_cmd.fire_advice) { + std::string fire_str = "Fire!"; + cv::putText( + debug_img, + fire_str, + { debug_img.cols / 2 - 100, 200 }, + cv::FONT_HERSHEY_SIMPLEX, + 2.85, + cv::Scalar(0, 0, 255), + 2 + ); + } + + const std::string gimbal_str = fmt::format( + "Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}", + gimbal_cmd.pitch, + gimbal_cmd.yaw, + gimbal_cmd.enable_pitch_diff, + gimbal_cmd.enable_yaw_diff, + gimbal_cmd.v_yaw, + gimbal_cmd.v_pitch + ); + cv::putText( + debug_img, + gimbal_str, + { 10, debug_img.rows - 30 }, + cv::FONT_HERSHEY_SIMPLEX, + 0.8, + cv::Scalar(255, 255, 0), + 2 + ); + + double scale = 100.0; + double armor_len = 0.135; + + std::vector pts; + pts.reserve(armors.armors.size() + armor_data.armors.size()); + + auto collect_xy = [&](auto& list, bool use_target) { + for (auto& a: list) + pts.emplace_back( + use_target ? a.target_pos.x() : a.pos.x(), + use_target ? a.target_pos.y() : a.pos.y() + ); + }; + + collect_xy(armors.armors, true); + collect_xy(armor_data.armors, false); + + double max_abs_x = 1e-6, max_abs_y = 1e-6; + for (auto& p: pts) { + max_abs_x = std::max(max_abs_x, std::abs(p.x())); + max_abs_y = std::max(max_abs_y, std::abs(p.y())); + } + + const double margin = 200.0; + const double cx = debug_img.cols * 0.5; + const double cy = debug_img.rows * 0.5; + + scale = std::min({ (cx - margin) / max_abs_x, + (debug_img.cols - cx - margin) / max_abs_x, + (cy - margin) / max_abs_y, + (debug_img.rows - cy - margin) / max_abs_y, + 550.0 }); + + const cv::Point2d origin(cx, cy); + + auto to_img = [&](const Eigen::Vector3d& p) { + return cv::Point2d(origin.x + p.x() * scale, origin.y - p.y() * scale); + }; + + auto draw2dArmor = [&](const Eigen::Vector3d& pos, double yaw, const cv::Scalar& color) { + cv::Point2d C = to_img(pos); + cv::circle(debug_img, C, 3, color, -1, cv::LINE_AA); + + double nx = -sin(yaw), ny = cos(yaw); + double half_len_px = armor_len * 0.5 * scale; + + cv::Point2d P1(C.x + nx * half_len_px, C.y - ny * half_len_px); + cv::Point2d P2(C.x - nx * half_len_px, C.y + ny * half_len_px); + cv::line(debug_img, P1, P2, color, 2, cv::LINE_AA); + }; + + Eigen::Vector3d center(0, 0, 0); + if (!armor_data.armors.empty()) { + for (auto& a: armor_data.armors) + center += a.pos; + center /= armor_data.armors.size(); + } + + const cv::Point2d Cc = to_img(center); + if (!armor_data.armors.empty()) + cv::circle(debug_img, Cc, 5, cv::Scalar(255, 0, 0), -1, cv::LINE_AA); + + for (auto& a: armors.armors) + draw2dArmor(a.target_pos, a.yaw, cv::Scalar(0, 255, 255)); + + std::vector data_pts; + + for (auto& a: armor_data.armors) { + double yaw = a.ori.toRotationMatrix().eulerAngles(2, 1, 0)[0]; + draw2dArmor(a.pos, yaw, cv::Scalar(255, 255, 255)); + data_pts.push_back(to_img(a.pos)); + } + + for (auto& pt: data_pts) + cv::line(debug_img, Cc, pt, cv::Scalar(180, 180, 255), 1, cv::LINE_AA); + + for (auto& a: armors.armors) + cv::line(debug_img, Cc, to_img(a.target_pos), cv::Scalar(0, 150, 255), 1, cv::LINE_AA); + + cv::circle( + debug_img, + cv::Point2i(debug_img.cols / 2, debug_img.rows / 2), + 5, + cv::Scalar(255, 255, 255), + 2 + ); +} +void writeTargetLogToJson(const auto_aim::Target& armor_target) { + nlohmann::json j; + + // -------- armor_target 部分 -------- + nlohmann::json jt; + jt["type"] = armor_target.type_; + jt["tracking"] = armor_target.is_tracking; + jt["id"] = static_cast(armor_target.tracked_id_); + jt["armors_num"] = armor_target.armor_num_; + + const auto now = std::chrono::steady_clock::now(); + const auto age_ms_t = + std::chrono::duration_cast(now - armor_target.timestamp_) + .count(); + jt["timestamp_age_ms"] = age_ms_t; + + jt["position"] = { { "x", armor_target.target_state_.cx() }, + { "y", armor_target.target_state_.cy() }, + { "z", armor_target.target_state_.cz() } }; + + jt["velocity"] = { { "x", armor_target.target_state_.vcx() }, + { "y", armor_target.target_state_.vcy() }, + { "z", armor_target.target_state_.vcz() } }; + + jt["r"] = armor_target.target_state_.r(); + jt["l"] = armor_target.target_state_.l(); + jt["h"] = armor_target.target_state_.h(); + jt["yaw"] = armor_target.target_state_.yaw(); + jt["v_yaw"] = armor_target.target_state_.vyaw(); + j["armor_target"] = jt; + // -------- 写文件 -------- + std::ofstream file("/dev/shm/target_log.json"); + if (file.is_open()) { + file << j.dump(2); + } +} + +struct DebugLogs { +#define DEBUG_LOG_LIST(X) \ + X(double, 100, time) \ + X(double, 100, raw_yaw) \ + X(double, 100, raw_pitch) \ + X(double, 100, yaw) \ + X(double, 100, pitch) \ + X(double, 100, armor_dis) \ + X(double, 100, armor_x) \ + X(double, 100, armor_y) \ + X(double, 100, armor_z) \ + X(double, 100, armor_yaw) \ + X(double, 100, ypd_y) \ + X(double, 100, ypd_p) \ + X(double, 100, gimbal_yaw) \ + X(double, 100, gimbal_pitch) \ + X(double, 100, target_v_yaw) \ + X(double, 100, control_v_yaw) \ + X(double, 100, control_v_pitch) \ + X(double, 100, yaw_diff) \ + X(double, 100, fire) \ + X(double, 100, rune_dis) \ + X(double, 100, fly_time) \ + X(double, 100, control_a_yaw) \ + X(double, 100, control_a_pitch) +#define GEN_LOG(TYPE, SIZE, NAME) LogsStream NAME##_log { #NAME }; + +#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME) + DEBUG_LOG_LIST(X) +#undef X + + void clear() { +#define X(TYPE, SIZE, NAME) NAME##_log.clear(); + DEBUG_LOG_LIST(X) +#undef X + } +}; + +void debuglog(const AutoAimDebug& dbg_armor) { + static bool first_log = true; + static std::chrono::steady_clock::time_point start_time; + + static auto_aim::Armor last_armor_; + static double last_armor_yaw_ = 0.0; + static double last_ypd_y_ = 0.0; + static double last_ypd_p_ = 0.0; + static double last_distance_ = 0.0; + static DebugLogs log; + static GimbalCmd last_cmd_; + static double rune_dis = 0.0; + if (first_log) { + start_time = std::chrono::steady_clock::now(); + first_log = false; + } + const auto now = std::chrono::steady_clock::now(); + const auto_aim::Armors& armors = dbg_armor.armors; + const double t = std::chrono::duration(now - start_time).count(); + const auto_aim::Target& target = dbg_armor.target; + writeTargetLogToJson(target); + + double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0; + + if (!armors.armors.empty()) { + std::vector ok_armors; + for (const auto& armor: armors.armors) { + if (armor.number != auto_aim::ArmorNumber::OUTPOST) + ok_armors.push_back(armor); + } + + if (!ok_armors.empty()) { + const auto_aim::Armor& min_armor = *std::min_element( + ok_armors.begin(), + ok_armors.end(), + [](const auto_aim::Armor& a, const auto_aim::Armor& b) { + return a.distance_to_image_center < b.distance_to_image_center; + } + ); + + last_armor_ = min_armor; + + armor_distance = std::hypot( + min_armor.target_pos.x(), + min_armor.target_pos.y(), + min_armor.target_pos.z() + ); + auto orientationToYaw = [](const Eigen::Quaterniond& q) noexcept -> double { + Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false); + double yaw = euler[0]; + yaw = last_armor_yaw_ + angles::shortest_angular_distance(last_armor_yaw_, yaw); + last_armor_yaw_ = yaw; + return yaw; + }; + + armor_yaw = orientationToYaw(min_armor.target_ori); + + ypd_y = std::atan2(min_armor.target_pos.y(), min_armor.target_pos.x()); + ypd_y = last_ypd_y_ + angles::shortest_angular_distance(last_ypd_y_, ypd_y); + last_ypd_y_ = ypd_y; + + ypd_p = std::atan2( + min_armor.target_pos.z(), + std::hypot(min_armor.target_pos.x(), min_armor.target_pos.y()) + ); + last_ypd_p_ = ypd_p; + + last_distance_ = armor_distance; + } + } + GimbalCmd i_use; + if (dbg_armor.gimbal_cmd.appear) { + i_use = dbg_armor.gimbal_cmd; + } else { + i_use = last_cmd_; + } + last_cmd_ = i_use; + nlohmann::json j; + log.time_log.handleOnce(t, j); + log.raw_yaw_log.handleOnce(i_use.target_yaw, j); + log.raw_pitch_log.handleOnce(i_use.target_pitch, j); + log.yaw_log.handleOnce(i_use.yaw, j); + log.pitch_log.handleOnce(i_use.pitch, j); + log.armor_yaw_log.handleOnce(armor_yaw * 180.0 / M_PI, j); + log.armor_x_log.handleOnce(last_armor_.target_pos.x(), j); + log.armor_y_log.handleOnce(last_armor_.target_pos.y(), j); + log.armor_z_log.handleOnce(last_armor_.target_pos.z(), j); + log.ypd_y_log.handleOnce(last_ypd_y_ * 180.0 / M_PI, j); + log.ypd_p_log.handleOnce(last_ypd_p_ * 180.0 / M_PI, j); + log.armor_dis_log.handleOnce(last_distance_, j); + log.gimbal_pitch_log.handleOnce(dbg_armor.gimbal_py.first * 180.0 / M_PI, j); + log.gimbal_yaw_log.handleOnce(dbg_armor.gimbal_py.second * 180.0 / M_PI, j); + log.target_v_yaw_log.handleOnce(target.target_state_.vyaw(), j); + log.control_v_pitch_log.handleOnce(i_use.v_pitch, j); + log.control_v_yaw_log.handleOnce(i_use.v_yaw, j); + log.fire_log.handleOnce(i_use.fire_advice, j); + log.rune_dis_log.handleOnce(rune_dis, j); + log.fly_time_log.handleOnce(i_use.fly_time, j); + log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j); + log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j); + log.yaw_diff_log.handleOnce( + std::abs(dbg_armor.gimbal_py.second * 180.0 / M_PI - dbg_armor.gimbal_cmd.yaw), + j + ); + + std::ofstream file("/dev/shm/cmd_log.json"); + if (file.is_open()) { + file << j.dump(); + } +} +} // namespace wust_vision::auto_aim \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/debug.hpp b/wust_vision-main/tasks/auto_aim/debug.hpp new file mode 100644 index 0000000..4cd9598 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/debug.hpp @@ -0,0 +1,38 @@ +#pragma once +#include "tasks/auto_aim/armor_tracker/target.hpp" +#include "tasks/auto_aim/auto_aim_fsm.hpp" +#include "tasks/auto_aim/type.hpp" +#include "tasks/utils/debug_utils.hpp" +namespace wust_vision::auto_aim { + +struct AutoAimDebug { + wust_vl::video::ImageFrame img_frame; + auto_aim::Armors armors; + auto_aim::Target target; + GimbalCmd gimbal_cmd; + auto_aim::AutoAimFsm fsm; + AimTarget aim_target; + double latency_ms; + Eigen::Matrix4d T_camera_to_odom; + std::vector armor_objs; + int detect_color = 0; + cv::Rect expanded; + std::pair gimbal_py; +}; +void drawDebugArmorContent( + cv::Mat& debug_img, + const AutoAimDebug& dbg, + std::pair camera_info +); +void writeTargetLogToJson(const auto_aim::Target& armor_target); + +inline void drawDebugOverlayShm( + const AutoAimDebug& dbg, + std::pair camera_info, + bool auto_fps +) { + static ShmWriter shm { "/debug_frame" }; + drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugArmorContent, shm); +} +void debuglog(const AutoAimDebug& dbg_armor); +} // namespace wust_vision::auto_aim diff --git a/wust_vision-main/tasks/auto_aim/type.cpp b/wust_vision-main/tasks/auto_aim/type.cpp new file mode 100644 index 0000000..35df3b4 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/type.cpp @@ -0,0 +1,399 @@ +#include "type.hpp" +#include "tasks/utils/config.hpp" +#include "wust_vl/common/utils/logger.hpp" +#include +namespace wust_vision { +namespace auto_aim { + + Light::Light(const std::vector& contour): cv::RotatedRect(cv::minAreaRect(contour)) { + this->center = std::accumulate( + contour.begin(), + contour.end(), + cv::Point2f(0, 0), + [n = static_cast(contour.size())](const cv::Point2f& a, const cv::Point& b) { + return a + cv::Point2f(b.x, b.y) / n; + } + ); + + cv::Point2f p[4]; + this->points(p); + + std::sort(p, p + 4, [](const cv::Point2f& a, const cv::Point2f& b) { return a.y < b.y; }); + + top = (p[0] + p[1]) / 2; + bottom = (p[2] + p[3]) / 2; + + length = cv::norm(top - bottom); + width = cv::norm(p[0] - p[1]); + + axis = (top - bottom) / cv::norm(top - bottom); + + tilt_angle = + std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f; + } + + void Light::addOffset(const cv::Point2f& offset) noexcept { + this->center += offset; + top += offset; + bottom += offset; + } + void Light::transform(const Eigen::Matrix& transform_matrix) noexcept { + top = utils::transformPoint2D(transform_matrix, top); + bottom = utils::transformPoint2D(transform_matrix, bottom); + length = cv::norm(top - bottom); + cv::Point2f p[4]; + this->points(p); + + width = cv::norm( + utils::transformPoint2D(transform_matrix, p[0]) + - utils::transformPoint2D(transform_matrix, p[1]) + ); + const cv::Point2f p0 = center; + const cv::Point2f p1 = center + axis; + + const cv::Point2f p0_t = utils::transformPoint2D(transform_matrix, p0); + + const cv::Point2f p1_t = utils::transformPoint2D(transform_matrix, p1); + + axis = p1_t - p0_t; + axis /= cv::norm(axis); + + tilt_angle = + std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f; + center = utils::transformPoint2D(transform_matrix, center); + } + int formArmorColor(const ArmorColor& color) noexcept { + switch (color) { + case ArmorColor::RED: + return 0; + case ArmorColor::BLUE: + return 1; + case ArmorColor::NONE: + return 2; + case ArmorColor::PURPLE: + return 3; + } + return -1; + } + + std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept { + switch (number) { + case ArmorNumber::SENTRY: + return os << "SENTRY"; + case ArmorNumber::NO1: + return os << "NO1"; + case ArmorNumber::NO2: + return os << "NO2"; + case ArmorNumber::NO3: + return os << "NO3"; + case ArmorNumber::NO4: + return os << "NO4"; + case ArmorNumber::NO5: + return os << "NO5"; + case ArmorNumber::OUTPOST: + return os << "OUTPOST"; + case ArmorNumber::BASE: + return os << "BASE"; + case ArmorNumber::UNKNOWN: + return os << "UNKNOWN"; + default: + return os << "InvalidArmorNumber(" << static_cast(number) << ")"; + } + } + + int formArmorNumber(const ArmorNumber& number) noexcept { + switch (number) { + case ArmorNumber::SENTRY: + return 0; + case ArmorNumber::NO1: + return 1; + case ArmorNumber::NO2: + return 2; + case ArmorNumber::NO3: + return 3; + case ArmorNumber::NO4: + return 4; + case ArmorNumber::NO5: + return 5; + case ArmorNumber::OUTPOST: + return 6; + case ArmorNumber::BASE: + return 7; + case ArmorNumber::UNKNOWN: + return 8; + } + return -1; + } + ArmorNumber armorNumberFromString(const std::string& s) noexcept { + if (s == "SENTRY") + return ArmorNumber::SENTRY; + if (s == "BASE") + return ArmorNumber::BASE; + if (s == "OUTPOST") + return ArmorNumber::OUTPOST; + if (s == "NO1") + return ArmorNumber::NO1; + if (s == "NO2") + return ArmorNumber::NO2; + if (s == "NO3") + return ArmorNumber::NO3; + if (s == "NO4") + return ArmorNumber::NO4; + if (s == "NO5") + return ArmorNumber::NO5; + return ArmorNumber::UNKNOWN; + } + + std::string armorNumberToString(const ArmorNumber& num) noexcept { + switch (num) { + case ArmorNumber::SENTRY: + return "SENTRY"; + case ArmorNumber::BASE: + return "BASE"; + case ArmorNumber::OUTPOST: + return "OUTPOST"; + case ArmorNumber::NO1: + return "NO1"; + case ArmorNumber::NO2: + return "NO2"; + case ArmorNumber::NO3: + return "NO3"; + case ArmorNumber::NO4: + return "NO4"; + case ArmorNumber::NO5: + return "NO5"; + default: + return "UNKNOWN"; + } + } + namespace { + std::unordered_map armor_map; + std::unordered_map> tracker_to_armors; + bool loaded = false; + + void loadArmorMapOnce() { + if (loaded) + return; + + try { + YAML::Node config = YAML::LoadFile(AUTO_AIM_CONFIG)["armor_map"]; + + for (auto it = config.begin(); it != config.end(); ++it) { + const std::string key = it->first.as(); + const int tracker_id = it->second.as(); + + ArmorNumber armor_num = armorNumberFromString(key); + + armor_map[key] = tracker_id; + tracker_to_armors[tracker_id].emplace_back(armor_num); + } + loaded = true; + } catch (const std::exception& e) { + std::cerr << "[ArmorMap] Failed to load armor_map.yaml: " << e.what() << std::endl; + } + } + } // namespace + + int retypetotracker(const ArmorNumber& a) noexcept { + loadArmorMapOnce(); + + const std::string key = armorNumberToString(a); + const auto it = armor_map.find(key); + if (it != armor_map.end()) + return it->second; + + std::cerr << "[retypetotracker] Invalid ArmorNumber: " << static_cast(a) << std::endl; + return -1; + } + + bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept { + return retypetotracker(a) == retypetotracker(b); + } + + std::string armorTypeToString(const ArmorType& type) noexcept { + switch (type) { + case ArmorType::SMALL: + return "small"; + case ArmorType::LARGE: + return "large"; + default: + return "invalid"; + } + } + + std::vector ArmorObject::toPts() const noexcept { + if (is_ok) { + return { lights[0].top, lights[0].bottom, lights[1].bottom, lights[1].top }; + } else { + return { pts[0], pts[1], pts[2], pts[3] }; + } + } + bool ArmorObject::checkOkptsRight(double max_error) const noexcept { + double error = 0.0; + for (int i = 0; i < 4; i++) { + error += cv::norm(pts[i] - toPts()[i]); + } + return error < max_error; + } + std::array ArmorObject::sortCorners(const std::vector& pts + ) const noexcept { + std::array ordered; + + // 先按 x 坐标分成左右两组 + std::vector left, right; + std::vector sorted = pts; + + std::sort(sorted.begin(), sorted.end(), [](const cv::Point2f& a, const cv::Point2f& b) { + return a.x < b.x; + }); + + left.push_back(sorted[0]); + left.push_back(sorted[1]); + right.push_back(sorted[2]); + right.push_back(sorted[3]); + + // 左边两个点,按 y 分为上/下 + std::sort(left.begin(), left.end(), [](const cv::Point2f& a, const cv::Point2f& b) { + return a.y < b.y; + }); + ordered[1] = left[0]; // 左上 + ordered[0] = left[1]; // 左下 + + // 右边两个点,按 y 分为上/下 + std::sort(right.begin(), right.end(), [](const cv::Point2f& a, const cv::Point2f& b) { + return a.y < b.y; + }); + ordered[2] = right[0]; // 右上 + ordered[3] = right[1]; // 右下 + + return ordered; // 顺序: 左下, 左上, 右上, 右下 + } + std::vector ArmorObject::landmarks() const noexcept { + if constexpr (N_LANDMARKS == 4) { + if (is_ok) { + return { lights[0].bottom, lights[0].top, lights[1].top, lights[1].bottom }; + } else { + const auto ordered = sortCorners(pts); + return { ordered[0], ordered[1], ordered[2], ordered[3] }; + } + + } else { + if (is_ok) { + return { lights[0].bottom, lights[0].center, lights[0].top, + lights[1].top, lights[1].center, lights[1].bottom }; + } else { + const auto ordered = sortCorners(pts); + return { ordered[0], (ordered[0] + ordered[1]) / 2.0, ordered[1], + ordered[2], (ordered[2] + ordered[3]) / 2.0, ordered[3] }; + } + } + } + ArmorObject::ArmorObject(const Light& l1, const Light& l2) { + pts.resize(4); + if (l1.center.x < l2.center.x) { + lights.push_back(l1); + lights.push_back(l2); + pts[0] = l1.top; + pts[1] = l1.bottom; + pts[2] = l2.bottom; + pts[3] = l2.top; + } else { + lights.push_back(l2); + lights.push_back(l1); + pts[0] = l2.top; + pts[1] = l2.bottom; + pts[2] = l1.bottom; + pts[3] = l1.top; + } + is_ok = true; + } + + std::vector Armor::toPtsDebug( + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion + ) const noexcept { + std::vector image_points; + const std::vector* model_points; + static std::vector SMALL_ARMOR_3D_POINTS_BLOCK = { + { 0, 0.025, -0.066 }, // 左上前 + { 0, -0.025, -0.066 }, // 左下前 + { 0, -0.025, 0.066 }, // 右下前 + { 0, 0.025, 0.066 }, // 右上前 + { 0.015, 0.025, -0.066 }, // 左上后 + { 0.015, -0.025, -0.066 }, // 左下后 + { 0.015, -0.025, 0.066 }, // 右下后 + { 0.015, 0.025, 0.066 }, // 右上后 + }; + + static std::vector BIG_ARMOR_3D_POINTS_BLOCK = { + { 0, 0.025, -0.1125 }, { 0, -0.025, -0.1125 }, { 0, -0.025, 0.1125 }, + { 0, 0.025, 0.1125 }, { 0.015, 0.025, -0.1125 }, { 0.015, -0.025, -0.1125 }, + { 0.015, -0.025, 0.1125 }, { 0.015, 0.025, 0.1125 }, + }; + + if (type == "large") { + model_points = &BIG_ARMOR_3D_POINTS_BLOCK; + } else if (type == "small") { + model_points = &SMALL_ARMOR_3D_POINTS_BLOCK; + } + const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix(); + const cv::Mat rot_mat = + (cv::Mat_(3, 3) << tf_rot(0, 0), + tf_rot(0, 1), + tf_rot(0, 2), + tf_rot(1, 0), + tf_rot(1, 1), + tf_rot(1, 2), + tf_rot(2, 0), + tf_rot(2, 1), + tf_rot(2, 2)); + + // 旋转矩阵 -> 旋转向量 + cv::Mat rvec; + cv::Rodrigues(rot_mat, rvec); + + // 平移向量 + const cv::Mat tvec = + (cv::Mat_(3, 1) << target_pos.x(), target_pos.y(), target_pos.z()); + + // 反投影 + cv::projectPoints( + *model_points, + rvec, + tvec, + camera_intrinsic, + camera_distortion, + image_points + ); + return image_points; + } + + void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept { + for (auto& armor: armors.armors) { + transformArmorData(armor, T_camera_to_odom); + } + } + void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept { + try { + // 位置 + const Eigen::Vector3d pos_camera = armor.pos; + armor.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom); + + // 姿态 + const Eigen::Quaterniond + q_camera(armor.ori.w(), armor.ori.x(), armor.ori.y(), armor.ori.z()); + const Eigen::Quaterniond q_odom = + utils::transformOrientation(q_camera, T_camera_to_odom); + armor.target_ori = q_odom; + + // 提取 yaw + const Eigen::Vector3d euler = q_odom.toRotationMatrix().eulerAngles(2, 1, 0); // ZYX + armor.yaw = euler[0]; // yaw + + } catch (const std::exception& e) { + WUST_ERROR("tf") << "Error in camera-to-odom transform: " << e.what(); + } + } + +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_aim/type.hpp b/wust_vision-main/tasks/auto_aim/type.hpp new file mode 100644 index 0000000..5aa3fc3 --- /dev/null +++ b/wust_vision-main/tasks/auto_aim/type.hpp @@ -0,0 +1,172 @@ +#pragma once +#include "tasks/type_common.hpp" +#include "tasks/utils/utils.hpp" +namespace wust_vision { + +namespace auto_aim { + constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135 + constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55 + constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0; + constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55 + + constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180; + struct Light: public cv::RotatedRect { + Light() = default; + + explicit Light(const std::vector& contour); + + void addOffset(const cv::Point2f& offset) noexcept; + void transform(const Eigen::Matrix& transform_matrix) noexcept; + + cv::Point2f top, bottom; + int color = 0; + cv::Point2f axis; + double length = 0; + double width = 0; + float tilt_angle = 0; + }; + + enum class ArmorColor : int { BLUE = 0, RED, NONE, PURPLE }; + + int formArmorColor(const ArmorColor& color) noexcept; + enum class ArmorNumber : int { SENTRY = 0, NO1, NO2, NO3, NO4, NO5, OUTPOST, BASE, UNKNOWN }; + std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept; + + int formArmorNumber(const ArmorNumber& number) noexcept; + std::string armorNumberToString(const ArmorNumber& num) noexcept; + ArmorNumber armorNumberFromString(const std::string& s) noexcept; + int retypetotracker(const ArmorNumber& a) noexcept; + bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept; + enum class ArmorsNum { NORMAL_4 = 4, OUTPOST_3 = 3 }; + + enum class ArmorType { SMALL, LARGE, INVALID }; + std::string armorTypeToString(const ArmorType& type) noexcept; + + struct ArmorObject { + ArmorColor color; + ArmorNumber number; + std::vector pts; + cv::Rect box; + + cv::Mat number_img; + + double confidence; + + cv::Mat whole_binary_img; + cv::Mat whole_rgb_img; + cv::Mat whole_gray_img; + + std::vector lights; + cv::Point2f local_offset; + cv::Point2f center; + bool is_ok = false; + ArmorType type; + static constexpr const int N_LANDMARKS = 6; + static constexpr const int N_LANDMARKS_2 = N_LANDMARKS * 2; + + template + static std::vector buildObjectPoints(const double& w, const double& h) noexcept { + if constexpr (N_LANDMARKS == 4) { + return { + PointType(0, w / 2, -h / 2), // 右下 + PointType(0, w / 2, h / 2), // 右上 + PointType(0, -w / 2, h / 2), // 左上 + PointType(0, -w / 2, -h / 2) // 左下 + }; + } else { + return { + PointType(0, w / 2, -h / 2), // 右下 + PointType(0, w / 2, 0.0), // 右中 + PointType(0, w / 2, h / 2), // 右上 + + PointType(0, -w / 2, h / 2), // 左上 + PointType(0, -w / 2, 0.0), // 左中 + PointType(0, -w / 2, -h / 2) // 左下 + }; + } + } + template + static std::vector> buildSymPairs() noexcept { + if constexpr (N_LANDMARKS == 4) { + static const std::vector> pairs = { + { 0, 3 }, + { 1, 2 }, + // { 0, 2 }, + // { 1, 3 } + }; + return pairs; + } else { + static const std::vector> pairs = { + { 0, 5 }, + { 1, 4 }, + { 2, 3 }, + // { 0, 3 }, + // { 2, 5 } + + }; + return pairs; + } + } + std::vector toPts() const noexcept; + bool checkOkptsRight(double max_error) const noexcept; + std::array sortCorners(const std::vector& pts) const noexcept; + + // Landmarks start from bottom left in clockwise order + std::vector landmarks() const noexcept; + void addOffset(const cv::Point2f& offset) noexcept { + for (auto& pt: pts) { + pt += offset; + } + center += offset; + box.x += offset.x; + box.y += offset.y; + for (auto& l: lights) { + l.addOffset(offset); + } + } + void transform(const Eigen::Matrix& transform_matrix) noexcept { + for (auto& l: lights) { + l.transform(transform_matrix); + } + center = utils::transformPoint2D(transform_matrix, center); + box = utils::transformRect(transform_matrix, box); + for (auto& pt: pts) { + pt = utils::transformPoint2D(transform_matrix, pt); + } + } + ArmorObject(const Light& l1, const Light& l2); + ArmorObject() = default; + }; + + struct Armor { + public: + ArmorNumber number; + std::string type; + Eigen::Vector3d pos; + Eigen::Quaterniond ori; + Eigen::Vector3d target_pos; + Eigen::Quaterniond target_ori; + float distance_to_image_center; + float yaw; + std::chrono::steady_clock::time_point timestamp; + bool is_ok = false; + bool is_none_purple = false; + int id = -1; + std::vector toPtsDebug( + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion + ) const noexcept; + }; + struct Armors { + public: + std::vector armors; + std::chrono::steady_clock::time_point timestamp; + int id; + Eigen::Vector3d v; + }; + + void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept; + void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept; + +} // namespace auto_aim +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/auto_buff.cpp b/wust_vision-main/tasks/auto_buff/auto_buff.cpp new file mode 100644 index 0000000..0e5378c --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/auto_buff.cpp @@ -0,0 +1,360 @@ +#include "auto_buff.hpp" +#include "tasks/auto_buff/debug.hpp" +#include "tasks/auto_buff/rune_control/aimer.hpp" +#include "tasks/auto_buff/rune_detector/rune_detector.hpp" +#include "tasks/auto_buff/rune_tracker/rune_tracker.hpp" +#include "tasks/auto_buff/rune_where/rune_where.hpp" +#include "tasks/type_common.hpp" +#include "tasks/utils/utils.hpp" +namespace wust_vision { +namespace auto_buff { + + struct AutoBuff::Impl { + ~Impl() { + run_flag_ = false; + rune_queue_->stop(); + if (processing_thread_) { + processing_thread_->stop(); + wust_vl::common::concurrency::ThreadManager::instance().unregisterThread( + processing_thread_->getName() + ); + } + } + Impl( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ) { + debug_mode_ = debug; + auto_buff_config_parameter_ = wust_vl::common::utils::Parameter::create(); + auto_buff_config_parameter_->loadFromFile(config_path); + auto_exposure_cfg_ = AutoExposureCfg::create(); + aimer_ = auto_buff::Aimer::create(auto_buff_config_parameter_); + rune_tracker_ = RuneTracker::create(auto_buff_config_parameter_); + auto_buff_config_parameter_->registerGroup(*auto_exposure_cfg_); + auto_buff_config_parameter_->reloadFromOldPath(); + auto config = YAML::LoadFile(config_path); + wust_vl::common::utils::ParameterManager::instance().registerParameter( + *auto_buff_config_parameter_.get() + ); + tf_config_ = tf_config; + camera_info_ = camera_info; + + rune_where_ = auto_buff::RuneWhere::create(config["rune_where"], camera_info); + + rune_detector_ = RuneDetectorCV::make_detector(config["rune_detector"]); + rune_detector_->setCallback(std::bind( + &AutoBuff::Impl::runeDetectCallback, + this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); + rune_queue_ = + std::make_unique>( + 50, + 500 + ); + + latency_averager_ = + std::make_unique>(100); + } + void start() { + if (run_flag_) { + return; + } + run_flag_ = true; + processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create( + "AutoBuffProcessingThread", + [this](wust_vl::common::concurrency::MonitoredThread::Ptr self) { + this->processingLoop(self); + } + ); + wust_vl::common::concurrency::ThreadManager::instance().registerThread( + processing_thread_ + ); + } + void pushInput(CommonFrame& frame) { + img_recv_count_++; + auto bbox = rune_target_.expanded( + T_camera_to_odom_, + camera_info_.first, + camera_info_.second, + frame.img_frame.src_img.size() + ); + if (bbox.area() > 100) { + frame.expanded = bbox; + frame.offset = cv::Point2f(bbox.x, bbox.y); + } + expanded_ = frame.expanded; + rune_detector_->pushInput(frame, debug_mode_); + } + void runeDetectCallback( + const auto_buff::RuneFan& fan, + const CommonFrame& frame, + cv::Mat& debug_img + ) { + std::lock_guard lock(callback_mutex_); + Eigen::Vector3d v = Eigen::Vector3d::Zero(); + Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity(); + auto ctx = std::any_cast(frame.any_ctx); + std::pair gimbal_py; + if (ctx.motion_buffer) { + const auto delay = + std::chrono::microseconds(static_cast(ctx.communication_delay_μs)); + const auto t_query = fan.timestamp + delay; + auto apply_motion = [&](const auto& att) { + v << att.data.vx, att.data.vy, att.data.vz; + R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ()) + * Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY()) + * Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX()); + gimbal_py = std::make_pair(att.data.pitch, att.data.yaw); + }; + if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) { + apply_motion(*past_att); + } else if (auto last_att = ctx.motion_buffer->get_last()) { + apply_motion(*last_att); + } + } + autoExposureControl(frame.img_frame.src_img, ctx.camera); + Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform( + R_gimbal2odom, + tf_config_->R_camera2gimbal, + tf_config_->t_camera2gimbal + ); + T_camera_to_odom_ = T_camera_to_odom; + auto_buff::RuneFan copy_fan = rune_where_->where(fan, T_camera_to_odom); + copy_fan.is_big = + InfantryMode::toAttackMode(ctx.mode) == InfantryMode::AttackMode::BIG_RUNE; + rune_queue_->enqueue(copy_fan); + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + auto_buff_debug_.img_frame = frame.img_frame; + auto_buff_debug_.T_camera_to_odom = T_camera_to_odom_; + auto_buff_debug_.expanded = frame.expanded; + auto_buff_debug_.pnp_distance = + copy_fan.fans.empty() ? 0.0 : copy_fan.fans[0].pos.norm(); + auto_buff_debug_.gimbal_py = gimbal_py; + } + + detect_finish_count_++; + } + + void runeTargetCallback(const auto_buff::RuneFan& fan) { + if (fan.timestamp <= last_rune_target_time_) { + WUST_WARN(logger_) << "Received out-of-order auto_buff data, discarded."; + return; + } + last_rune_target_time_ = fan.timestamp; + + auto rune_target = rune_tracker_->track(fan); + + { + std::lock_guard lock(target_mutex_); + rune_target_ = rune_target; + } + auto now = std::chrono::steady_clock::now(); + auto latency_ns = std::chrono::duration_cast( + std::chrono::steady_clock::now() - fan.timestamp + ) + .count(); + auto latency_ms = wust_vl::common::utils::time_utils::durationMs(fan.timestamp, now); + latency_averager_->add(latency_ms); + auto_buff_debug_.latency_ms = latency_averager_->average(); + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + static double last_unwrapped_roll = 0.0; + static double last_raw_roll = 0.0; + const double raw_roll = rune_target.roll(); + const double raw_pred = rune_target.predictAngle(0.5); + const double obs_angle = last_unwrapped_roll + + angles::shortest_angular_distance(last_raw_roll, raw_roll); + const double pre_angle = + obs_angle + angles::shortest_angular_distance(raw_roll, raw_pred); + last_unwrapped_roll = obs_angle; + last_raw_roll = raw_roll; + auto_buff_debug_.obs_v = rune_target.v_roll(); + auto_buff_debug_.fitter_v = rune_target.getFitterSpd( + wust_vl::common::utils::time_utils::now() + + std::chrono::microseconds(int(0.2 * 1e6)) + ); + auto_buff_debug_.obs_angle = obs_angle; + auto_buff_debug_.pre_angle = pre_angle; + auto_buff_debug_.target = rune_target; + auto_buff_debug_.power_rune = rune_target.getPowerRune(); + } + } + GimbalCmd solve(double bullet_speed) { + GimbalCmd gimbal_cmd; + auto_buff::RuneTarget rune_target; + + { + std::lock_guard lock(target_mutex_); + rune_target = rune_target_; + } + if (rune_target.checkTargetAppear()) { + gimbal_cmd = aimer_->aim(rune_target, bullet_speed); + } + if (gimbal_cmd.fire_advice) { + fire_count_++; + } + if (debug_mode_) { + std::lock_guard lock(dbg_mutex_); + auto_buff_debug_.gimbal_cmd = gimbal_cmd; + auto_buff_debug_.aim_target = gimbal_cmd.aim_target; + } + timer_cout_++; + + return gimbal_cmd; + } + void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) { + while (!self->isAlive()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + while (self->isAlive() && run_flag_) { + if (!self->waitPoint()) + break; + self->heartbeat(); + printStats(); + auto_buff::RuneFan auto_buff; + // bool skip; + // if (rune_queue_->dequeue_wait(auto_buff, skip)) { + // runeTargetCallback(auto_buff); + // tracker_finish_count_++; + // if (skip) { + // WUST_DEBUG(logger_) << "OrderQueue skip"; + // } + // } + if (rune_queue_->try_dequeue(auto_buff)) { + runeTargetCallback(auto_buff); + tracker_finish_count_++; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + } + void doDebug() { + debug_mode_ = true; + AutoBuffDebug dbg; + { + std::lock_guard lock(dbg_mutex_); + dbg = auto_buff_debug_; + } + drawDebugOverlayShm(dbg, camera_info_, false); + debuglog(dbg); + } + void printStats() { + utils::XSecOnce( + [&] { + WUST_INFO(logger_) + << "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_ + << ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_ + << ", Lat: " << auto_buff_debug_.latency_ms << "ms" + << ", Fire: " << fire_count_; + img_recv_count_ = 0; + detect_finish_count_ = 0; + fire_count_ = 0; + tracker_finish_count_ = 0; + timer_cout_ = 0; + }, + 1.0 + ); + } + void + autoExposureControl(const cv::Mat& frame, std::shared_ptr camera) { + const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0; + utils::XSecOnce( + [&] { + if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) { + return; + } + if (auto* hik = dynamic_cast(camera->getDevice())) { + cv::Mat i_use = frame(expanded_); + if (expanded_.area() < 100 || i_use.empty()) { + i_use = frame; + } + const double brightness = utils::computeBrightness(i_use); + + const double diff = + brightness - auto_exposure_cfg_->target_brightness_param.get(); + const double exposure_min = auto_exposure_cfg_->exposure_min_param.get(); + const double exposure_max = auto_exposure_cfg_->exposure_max_param.get(); + double exposure_time = hik->getExposureTime(); + const double last_exposure_time = exposure_time; + if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get() + && exposure_time > 0.0) { + exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get(); + } else { + exposure_time -= auto_exposure_cfg_->decay_step_param.get(); + } + if (exposure_time < exposure_min) + exposure_time = exposure_min; + if (exposure_time > exposure_max) + exposure_time = exposure_max; + if (std::abs(exposure_time - last_exposure_time) > 10) { + hik->setExposureTime(exposure_time); + } + } + }, + dt + ); + } + std::mutex callback_mutex_; + RuneDetectorCV::Ptr rune_detector_; + RuneTracker::Ptr rune_tracker_; + auto_buff::Aimer::Ptr aimer_; + RuneWhere::Ptr rune_where_; + std::string logger_ = "auto_buff"; + std::unique_ptr> rune_queue_; + wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_; + AutoExposureCfg::Ptr auto_exposure_cfg_; + cv::Rect expanded_; + auto_buff::RuneTarget rune_target_; + bool run_flag_ = false; + int detect_finish_count_ = 0; + int img_recv_count_ = 0; + int tracker_finish_count_ = 0; + int timer_cout_ = 0; + int fire_count_; + std::chrono::steady_clock::time_point last_rune_target_time_; + bool debug_mode_ = false; + AutoBuffDebug auto_buff_debug_; + std::unique_ptr> latency_averager_; + TFConfig::Ptr tf_config_; + std::pair camera_info_; + wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter_; + Eigen::Matrix4d T_camera_to_odom_; + + std::mutex target_mutex_; + std::mutex dbg_mutex_; + }; + AutoBuff::AutoBuff( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ): + _impl(std::make_unique(config_path, tf_config, camera_info, debug)) {} + AutoBuff::~AutoBuff() { + _impl.reset(); + } + void AutoBuff::start() { + _impl->start(); + } + void AutoBuff::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); + } + + GimbalCmd AutoBuff::solve(double bullet_speed) { + return _impl->solve(bullet_speed); + } + + wust_vl::common::concurrency::MonitoredThread::Ptr AutoBuff::getThread() { + return _impl->processing_thread_; + } + void AutoBuff::doDebug() { + _impl->doDebug(); + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/auto_buff.hpp b/wust_vision-main/tasks/auto_buff/auto_buff.hpp new file mode 100644 index 0000000..b231311 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/auto_buff.hpp @@ -0,0 +1,38 @@ +#pragma once +#include "tasks/imodule.hpp" +#include "tasks/type_common.hpp" +#include "wust_vl/video/camera.hpp" +namespace wust_vision { +namespace auto_buff { + + class AutoBuff: public IModule { + public: + using Ptr = std::shared_ptr; + AutoBuff( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ); + static Ptr create( + const std::string& config_path, + TFConfig::Ptr tf_config, + const std::pair& camera_info, + bool debug + ) { + return std::make_shared(config_path, tf_config, camera_info, debug); + } + ~AutoBuff(); + void start() override; + void doDebug() override; + void pushInput(CommonFrame& frame) override; + GimbalCmd solve(double bullet_speed) override; + wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override; + struct Impl; + std::unique_ptr _impl; + }; + inline AutoBuff::Ptr toAutoBuff(IModule::Ptr module) { + return std::dynamic_pointer_cast(module); + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/debug.cpp b/wust_vision-main/tasks/auto_buff/debug.cpp new file mode 100644 index 0000000..5a8bd0e --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/debug.cpp @@ -0,0 +1,241 @@ +#include "debug.hpp" +#include "tasks/auto_buff/auto_buff.hpp" + +namespace wust_vision::auto_buff { +void drawDebugRuneContent( + cv::Mat& debug_img, + const AutoBuffDebug& dbg, + std::pair camera_info +) { + const auto& gimbal_cmd = dbg.gimbal_cmd; + double predict_angle = dbg.predict_angle; + auto aim_target = dbg.aim_target; + auto auto_buff = dbg.power_rune; + const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows); + const cv::Rect roi = dbg.expanded & img_rect; + cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2); + + const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms); + cv::putText( + debug_img, + latency_str, + cv::Point(10, 30), + cv::FONT_HERSHEY_SIMPLEX, + 0.8, + cv::Scalar(255, 255, 255), + 2 + ); + aim_target.tf(dbg.T_camera_to_odom.inverse()); + { + const auto pts = aim_target.toPts(camera_info.first, camera_info.second); + if (!pts.empty()) { + cv::Scalar color = cv::Scalar(255, 255, 255); + for (int i = 0; i < 4; i++) + cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2); + + // 后表面 + for (int i = 4; i < 8; i++) + cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2); + + // 侧边 + for (int i = 0; i < 4; i++) + cv::line(debug_img, pts[i], pts[i + 4], color, 2); + cv::Point2f center(0.f, 0.f); + for (auto pt: pts) { + center += pt; + } + center *= 1.0 / pts.size(); + + if (gimbal_cmd.fire_advice) { + int cross_len = 60; + cv::line( + debug_img, + center + cv::Point2f(-cross_len, -cross_len), + center + cv::Point2f(+cross_len, +cross_len), + cv::Scalar(0, 0, 255), + 5 + ); + cv::line( + debug_img, + center + cv::Point2f(-cross_len, +cross_len), + center + cv::Point2f(+cross_len, -cross_len), + cv::Scalar(0, 0, 255), + 5 + ); + } + + const double scale = 10.0; + const double v_yaw = gimbal_cmd.v_yaw; + const double v_pitch = gimbal_cmd.v_pitch; + const double dx = -scale * v_yaw; + const double dy = scale * v_pitch; + + const cv::Point2f start_pt = center; + const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy); + const cv::Scalar color_x = cv::Scalar(50, 50, 255); + cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2); + } + } + if (gimbal_cmd.fire_advice) { + const std::string fire_str = "Fire!"; + cv::putText( + debug_img, + fire_str, + { debug_img.cols / 2 - 100, 200 }, + cv::FONT_HERSHEY_SIMPLEX, + 2.85, + cv::Scalar(0, 0, 255), + 2 + ); + } + + const std::string gimbal_str = fmt::format( + "Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}", + gimbal_cmd.pitch, + gimbal_cmd.yaw, + gimbal_cmd.enable_pitch_diff, + gimbal_cmd.enable_yaw_diff, + gimbal_cmd.v_yaw, + gimbal_cmd.v_pitch + ); + cv::putText( + debug_img, + gimbal_str, + { 10, debug_img.rows - 30 }, + cv::FONT_HERSHEY_SIMPLEX, + 0.8, + cv::Scalar(255, 255, 0), + 2 + ); + auto_buff.tf(dbg.T_camera_to_odom.inverse()); + auto_buff.draw(debug_img, camera_info.first, camera_info.second); + cv::circle( + debug_img, + cv::Point2i(debug_img.cols / 2, debug_img.rows / 2), + 5, + cv::Scalar(255, 255, 255), + 2 + ); +} +void writeTargetLogToJson(const auto_buff::RuneTarget& rune_target) { + nlohmann::json j; + + const auto now = std::chrono::steady_clock::now(); + + nlohmann::json jr; + jr["tracking"] = true; + jr["id"] = static_cast(rune_target.last_id); + + const auto age_ms_r = + std::chrono::duration_cast(now - rune_target.timestamp_).count(); + jr["timestamp_age_ms"] = age_ms_r; + + jr["position"] = { { "x", rune_target.centerPos().x() }, + { "y", rune_target.centerPos().y() }, + { "z", rune_target.centerPos().z() } }; + + jr["roll"] = rune_target.roll() * 180.0 / M_PI; + jr["yaw"] = rune_target.yaw() * 180.0 / M_PI; + jr["v_roll"] = rune_target.v_roll() * 180.0 / M_PI; + + j["rune_target"] = jr; + + // -------- 写文件 -------- + std::ofstream file("/dev/shm/target_log.json"); + if (file.is_open()) { + file << j.dump(2); + } +} +struct DebugLogs { +#define DEBUG_LOG_LIST(X) \ + X(double, 100, time) \ + X(double, 100, raw_yaw) \ + X(double, 100, raw_pitch) \ + X(double, 100, yaw) \ + X(double, 100, pitch) \ + X(double, 100, ypd_y) \ + X(double, 100, ypd_p) \ + X(double, 100, rune_obs) \ + X(double, 100, rune_pre) \ + X(double, 100, rune_obsv) \ + X(double, 100, rune_fitv) \ + X(double, 100, gimbal_yaw) \ + X(double, 100, gimbal_pitch) \ + X(double, 100, target_v_yaw) \ + X(double, 100, control_v_yaw) \ + X(double, 100, control_v_pitch) \ + X(double, 100, yaw_diff) \ + X(double, 100, fire) \ + X(double, 100, rune_dis) \ + X(double, 100, fly_time) \ + X(double, 100, control_a_yaw) \ + X(double, 100, control_a_pitch) +#define GEN_LOG(TYPE, SIZE, NAME) LogsStream NAME##_log { #NAME }; + +#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME) + DEBUG_LOG_LIST(X) +#undef X + + void clear() { +#define X(TYPE, SIZE, NAME) NAME##_log.clear(); + DEBUG_LOG_LIST(X) +#undef X + } +}; + +void debuglog(const AutoBuffDebug& dbg_rune) { + static bool first_log = true; + static std::chrono::steady_clock::time_point start_time; + static DebugLogs log; + static GimbalCmd last_cmd_; + static double rune_dis = 0.0; + if (first_log) { + start_time = std::chrono::steady_clock::now(); + first_log = false; + } + + const auto now = std::chrono::steady_clock::now(); + const double t = std::chrono::duration(now - start_time).count(); + + const auto_buff::RuneTarget& rune_target = dbg_rune.target; + writeTargetLogToJson(rune_target); + + double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0; + if (dbg_rune.pnp_distance > 1.0) { + rune_dis = dbg_rune.pnp_distance; + } + + GimbalCmd i_use; + if (dbg_rune.gimbal_cmd.appear) { + i_use = dbg_rune.gimbal_cmd; + } else { + i_use = last_cmd_; + } + last_cmd_ = i_use; + nlohmann::json j; + log.time_log.handleOnce(t, j); + log.raw_yaw_log.handleOnce(i_use.target_yaw, j); + log.raw_pitch_log.handleOnce(i_use.target_pitch, j); + log.yaw_log.handleOnce(i_use.yaw, j); + log.pitch_log.handleOnce(i_use.pitch, j); + log.rune_obs_log.handleOnce(dbg_rune.obs_angle, j); + log.rune_pre_log.handleOnce(dbg_rune.pre_angle, j); + log.rune_fitv_log.handleOnce(dbg_rune.fitter_v * 180.0 / M_PI, j); + log.rune_obsv_log.handleOnce(dbg_rune.obs_v * 180.0 / M_PI, j); + log.gimbal_pitch_log.handleOnce(dbg_rune.gimbal_py.first * 180.0 / M_PI, j); + log.gimbal_yaw_log.handleOnce(dbg_rune.gimbal_py.second * 180.0 / M_PI, j); + log.control_v_pitch_log.handleOnce(i_use.v_pitch, j); + log.control_v_yaw_log.handleOnce(i_use.v_yaw, j); + log.fire_log.handleOnce(i_use.fire_advice, j); + log.rune_dis_log.handleOnce(rune_dis, j); + log.fly_time_log.handleOnce(i_use.fly_time, j); + log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j); + log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j); + + std::ofstream file("/dev/shm/cmd_log.json"); + if (file.is_open()) { + file << j.dump(); + } +} + +} // namespace wust_vision::auto_buff \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/debug.hpp b/wust_vision-main/tasks/auto_buff/debug.hpp new file mode 100644 index 0000000..967129a --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/debug.hpp @@ -0,0 +1,39 @@ +#pragma once +#include "tasks/auto_buff/auto_buff.hpp" +#include "tasks/auto_buff/rune_tracker/rune_target.hpp" +#include "tasks/utils/debug_utils.hpp" +#include +namespace wust_vision::auto_buff { +struct AutoBuffDebug { + wust_vl::video::ImageFrame img_frame; + auto_buff::RuneTarget target; + AimTarget aim_target; + auto_buff::PowerRune power_rune; + double predict_angle; + GimbalCmd gimbal_cmd; + Eigen::Matrix4d T_camera_to_odom; + double latency_ms; + double obs_angle; + double pre_angle; + double fitter_v; + double obs_v; + cv::Rect expanded; + double pnp_distance; + std::pair gimbal_py; +}; +void drawDebugRuneContent( + cv::Mat& debug_img, + const AutoBuffDebug& dbg, + std::pair camera_info +); +void writeTargetLogToJson(const auto_buff::RuneTarget& rune_target); +inline void drawDebugOverlayShm( + const AutoBuffDebug& dbg, + std::pair camera_info, + bool auto_fps +) { + static ShmWriter shm { "/debug_frame" }; + drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugRuneContent, shm); +} +void debuglog(const AutoBuffDebug& dbg); +} // namespace wust_vision::auto_buff \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_control/aimer.cpp b/wust_vision-main/tasks/auto_buff/rune_control/aimer.cpp new file mode 100644 index 0000000..1f0f5cd --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_control/aimer.cpp @@ -0,0 +1,237 @@ +#include "aimer.hpp" +#include "tasks/auto_buff/rune_tracker/rune_target.hpp" +#include "wust_vl/common/utils/manual_compensator.hpp" +namespace wust_vision { + +namespace auto_buff { + + struct Aimer::Impl { + public: + struct AimerConfig: wust_vl::common::utils::ParamGroup { + static constexpr const char* Logger = "Config: auto_buff::aimer"; + static constexpr const char* kKey = "aimer"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + AimerConfig() {} + static Ptr create() { + return std::make_shared(); + } + std::shared_ptr manual_compensator; + GEN_PARAM(double, prediction_delay); + GEN_PARAM(double, shooting_range_h); + GEN_PARAM(double, shooting_range_w); + GEN_PARAM(double, min_enable_pitch_deg); + GEN_PARAM(double, min_enable_yaw_deg); + bool first_load = false; + using OffsetEntry = wust_vl::common::utils::OffsetEntry; + + static std::vector + parseTrajectoryOffset(const YAML::Node& node, double& base_pitch, double& base_yaw) { + std::vector entries; + + if (!node || !node["trajectory_offset"]) { + return entries; + } + + for (const auto& n: node["trajectory_offset"]) { + entries.push_back(OffsetEntry { .d_min = n["d_min"].as(), + .d_max = n["d_max"].as(), + .h_min = n["h_min"].as(), + .h_max = n["h_max"].as(), + .pitch_off = n["pitch_off"].as(), + .yaw_off = n["yaw_off"].as() }); + } + + base_pitch = node["base_offset"]["pitch"].as(); + base_yaw = node["base_offset"]["yaw"].as(); + + return entries; + } + std::vector last_entries_; + double last_base_pitch_ = 0.0; + double last_base_yaw_ = 0.0; + bool first_load_ = true; + void loadSelf(const YAML::Node& node) override { + double base_pitch = 0.0; + double base_yaw = 0.0; + + auto entries = parseTrajectoryOffset(node, base_pitch, base_yaw); + + const bool trajectory_changed = first_load_ || entries != last_entries_ + || base_pitch != last_base_pitch_ || base_yaw != last_base_yaw_; + if (trajectory_changed) { + manual_compensator = + std::make_shared(); + + manual_compensator->setBasePitch(base_pitch); + manual_compensator->setBaseYaw(base_yaw); + + if (!manual_compensator->updateMapFlow(entries) || entries.empty()) { + std::cout << "manual_compensator init failed" << std::endl; + } + + last_entries_ = entries; + last_base_pitch_ = base_pitch; + last_base_yaw_ = base_yaw; + first_load_ = false; + } + shooting_range_h_param.load(node); + shooting_range_w_param.load(node); + min_enable_pitch_deg_param.load(node); + min_enable_yaw_deg_param.load(node); + prediction_delay_param.load(node); + } + }; + Impl(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + aimer_config_ = AimerConfig::create(); + trajectory_compensator_config_ = TrajectoryCompensatorConfig::create(); + auto_buff_config_parameter->registerGroup(*aimer_config_); + auto_buff_config_parameter->registerGroup(*trajectory_compensator_config_); + auto_buff_config_parameter->reloadFromOldPath(); + } + std::tuple calEnableDiff(Eigen::Vector3d aim_target_pos) const noexcept { + const double distance = aim_target_pos.norm(); + double shooting_range_yaw = + std::abs(atan2(aimer_config_->shooting_range_w_param.get() / 2, distance)); + double shooting_range_pitch = + std::abs(atan2(aimer_config_->shooting_range_h_param.get() / 2, distance)); + constexpr double yaw_factor = 1.0; + constexpr double pitch_factor = 1.0; + shooting_range_yaw = std::max( + shooting_range_yaw, + aimer_config_->min_enable_yaw_deg_param.get() * M_PI / 180 + ); + shooting_range_pitch = std::max( + shooting_range_pitch, + aimer_config_->min_enable_pitch_deg_param.get() * M_PI / 180 + ); + shooting_range_yaw *= yaw_factor; + shooting_range_pitch *= pitch_factor; + + return std::make_tuple(std::abs(shooting_range_yaw), std::abs(shooting_range_pitch)); + } + struct ControlPoint { + double yaw; + double pitch; + Eigen::Vector3d aim_pos; + }; + ControlPoint getControlPoint(RuneTarget target, double bullet_speed) const noexcept { + auto [aim_target_pos, _] = target.getHitPoint(); + ControlPoint cp; + double control_yaw = std::atan2(aim_target_pos.y(), aim_target_pos.x()); + double raw_pitch = std::atan2( + aim_target_pos.z(), + std::sqrt( + aim_target_pos.x() * aim_target_pos.x() + + aim_target_pos.y() * aim_target_pos.y() + ) + ); + try { + trajectory_compensator_config_->trajectory_compensator + ->compensate(aim_target_pos, raw_pitch, bullet_speed); + } catch (std::exception& e) { + std::cout << "compensate error: " << e.what() << std::endl; + } + + double control_pitch = raw_pitch; + const auto offs = aimer_config_->manual_compensator->angleHardCorrect( + aim_target_pos.head(2).norm(), + aim_target_pos.z() + ); + control_yaw = angles::normalize_angle((control_yaw + offs[1] * M_PI / 180.0)); + control_pitch = (control_pitch + offs[0] * M_PI / 180.0); + cp.pitch = control_pitch; + cp.yaw = control_yaw; + cp.aim_pos = aim_target_pos; + return cp; + } + GimbalCmd aim(RuneTarget target, double bullet_speed) { + GimbalCmd cmd; + + // 当前时间 + const auto now = wust_vl::common::utils::time_utils::now(); + + // 首次预测目标位置 + target.predictWithFitter(now); + auto [p0, q0] = target.getHitPoint(); + + // 迭代计算飞行时间 + bool converged = false; + double prev_fly_time = trajectory_compensator_config_->trajectory_compensator + ->getFlyingTime(p0, bullet_speed); + + std::vector iteration_target(10, target); + for (int iter = 0; iter < 10; ++iter) { + iteration_target[iter].predictWithFitter(prev_fly_time); + auto [pb, qb] = iteration_target[iter].getHitPoint(); + double iter_fly_time = trajectory_compensator_config_->trajectory_compensator + ->getFlyingTime(pb, bullet_speed); + if (std::abs(iter_fly_time - prev_fly_time) < 0.001) { + converged = true; + break; + } + prev_fly_time = iter_fly_time; + } + + const double predict_time = prev_fly_time + aimer_config_->prediction_delay_param.get(); + target.predictWithFitter(predict_time); + const auto cp = getControlPoint(target, bullet_speed); + // RuneTarget target_prev, target_next = target; + RuneTarget target_prev = target; + RuneTarget target_next = target; + + const double dt = 0.01; + target_prev.predictWithFitter(-dt); + auto cp_prev = getControlPoint(target_prev, bullet_speed); + + target_next.predictWithFitter(dt); + auto cp_next = getControlPoint(target_next, bullet_speed); + + double yaw_speed = (cp_next.yaw - cp_prev.yaw) / (2.0 * dt); + double pitch_speed = (cp_next.pitch - cp_prev.pitch) / (2.0 * dt); + double yaw_acc = (cp_next.yaw - 2.0 * cp.yaw + cp_prev.yaw) / (dt * dt); + double pitch_acc = (cp_next.pitch - 2.0 * cp.pitch + cp_prev.pitch) / (dt * dt); + + AimTarget aim_target; + aim_target.pos = cp.aim_pos; + + // 填充 GimbalCmd + cmd.distance = cp.aim_pos.norm(); + cmd.aim_target = aim_target; + cmd.yaw = cp.yaw * 180.0 / M_PI; + cmd.pitch = cp.pitch * 180.0 / M_PI; + // cmd.v_yaw = yaw_speed * 180.0 / M_PI; // 转为度/秒 + // cmd.v_pitch = pitch_speed * 180.0 / M_PI; + // cmd.a_yaw = yaw_acc * 180.0 / M_PI; // 转为度/秒² + // cmd.a_pitch = pitch_acc * 180.0 / M_PI; + cmd.v_yaw = 0.0; // 转为度/秒 + cmd.v_pitch = 0.0; + cmd.a_yaw = 0.0; // 转为度/秒² + cmd.a_pitch = 0.0; + const auto [enable_yaw, enable_pitch] = calEnableDiff(cp.aim_pos); + cmd.fire_advice = true; + cmd.enable_yaw_diff = enable_yaw; + cmd.enable_pitch_diff = enable_pitch; + cmd.target_yaw = cp.yaw * 180.0 / M_PI; + cmd.target_pitch = cp.pitch * 180.0 / M_PI; + cmd.fly_time = prev_fly_time; + cmd.appear = true; + + return cmd; + } + TrajectoryCompensatorConfig::Ptr trajectory_compensator_config_; + AimerConfig::Ptr aimer_config_; + }; + Aimer::Aimer(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + _impl = std::make_unique(auto_buff_config_parameter); + } + Aimer::~Aimer() { + _impl.reset(); + } + GimbalCmd Aimer::aim(const auto_buff::RuneTarget& target, double bullet_speed) { + return _impl->aim(target, bullet_speed); + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_control/aimer.hpp b/wust_vision-main/tasks/auto_buff/rune_control/aimer.hpp new file mode 100644 index 0000000..45d2eed --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_control/aimer.hpp @@ -0,0 +1,22 @@ +#pragma once +#include "tasks/auto_buff/rune_tracker/rune_target.hpp" +#include "tasks/type_common.hpp" +namespace wust_vision { +namespace auto_buff { + class Aimer { + public: + using Ptr = std::unique_ptr; + Aimer(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter); + static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + return std::make_unique(auto_buff_config_parameter); + } + ~Aimer(); + GimbalCmd aim(const auto_buff::RuneTarget& target, double bullet_speed); + + private: + struct Impl; + std::unique_ptr _impl; + }; + +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.cpp b/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.cpp new file mode 100644 index 0000000..6710427 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.cpp @@ -0,0 +1,497 @@ +#include "rune_detector.hpp" +#include "tasks/utils/utils.hpp" +#include +namespace wust_vision { +namespace auto_buff { + struct RuneDetectorCV::Impl { + public: + Impl(const YAML::Node& node) { + params_.load(node); + } + + void setCallback(DetectorCallback callback) { + callback_ = callback; + } + cv::Mat preProcess(const cv::Mat& src, bool use_red = false) { + // cv::Mat bin; + // cv::cvtColor(src, bin, cv::COLOR_RGB2GRAY); + // cv::threshold(bin, bin, params_.bin_threshold, 255, cv::THRESH_BINARY); + std::vector channels; + cv::split(src, channels); // BGR + + cv::Mat diff; + if (use_red) { + cv::subtract(channels[2], channels[0], diff); // R - B + } else { + cv::subtract(channels[0], channels[2], diff); // B - R + } + + cv::Mat bin; + cv::threshold(diff, bin, params_.bin_threshold, 255, cv::THRESH_BINARY); + // cv::imshow("bin", bin); + // cv::waitKey(1); + return bin; + } + inline auto_buff::RuneCenter getRuneCenter( + const std::vector>& contours, + const std::vector& hierarchy, + cv::Size image_size, + cv::Point2f offset, + cv::Mat& debug_img, + std::vector& used_flags + ) { + auto_buff::RuneCenter result; + struct Node { + cv::Point2f center; + int idx; + cv::RotatedRect rr; + }; + + std::vector nodes; + + for (int i = 0; i < contours.size(); i++) { + if (used_flags[i]) + continue; + if (hierarchy[i][3] != -1) + continue; + + double area = cv::contourArea(contours[i]); + if (area < params_.rune_center_min_area || area > params_.rune_center_max_area) + continue; + + cv::RotatedRect rr = cv::minAreaRect(contours[i]); + float w = rr.size.width; + float h = rr.size.height; + + if (w < 5 || h < 5) + continue; + + double ratio = (w > h ? w / h : h / w); + if (ratio - 1.0 > params_.rune_center_1x1ratio_tol) + continue; + + double rect_area = w * h; + if (rect_area <= 1e-5) + continue; + + double fill_ratio = area / rect_area; + if (fill_ratio < params_.rune_center_fill_ratio_min) + continue; + + nodes.push_back({ rr.center, i, rr }); + + if (!debug_img.empty()) { + cv::Point2f pts[4]; + rr.points(pts); + for (size_t k = 0; k < 4; k++) { + pts[k] += offset; + } + for (int k = 0; k < 4; k++) { + cv::line(debug_img, pts[k], pts[(k + 1) % 4], cv::Scalar(0, 255, 0), 2); + } + } + } + + if (nodes.empty()) + return result; + + cv::Point2f img_center(image_size.width * 0.5f, image_size.height * 0.5f); + + double best_dist = 1e18; + int best_idx = -1; + cv::RotatedRect best_rr; + + for (auto& n: nodes) { + double dx = n.center.x - img_center.x; + double dy = n.center.y - img_center.y; + double dist2 = dx * dx + dy * dy; + + if (dist2 < best_dist) { + best_dist = dist2; + best_idx = n.idx; + best_rr = n.rr; + } + } + + if (!debug_img.empty()) { + cv::circle( + debug_img, + img_center + offset, + 5, + cv::Scalar(0, 255, 255), + -1 + ); // 图像中心 + + cv::Point2f pts[4]; + best_rr.points(pts); + for (size_t k = 0; k < 4; k++) { + pts[k] += offset; + } + for (int k = 0; k < 4; k++) { + cv::line(debug_img, pts[k], pts[(k + 1) % 4], cv::Scalar(0, 0, 255), 2); + } + } + + return auto_buff::RuneCenter(best_rr); + } + + inline int findTopParent(int idx, const std::vector& hierarchy) { + int p = hierarchy[idx][3]; // parent + while (p != -1 && hierarchy[p][3] != -1) { + p = hierarchy[p][3]; // 一直追溯到最顶层 parent + } + return p; // 若 p == -1 表示 contour 本身就是顶层轮廓 + } + + inline std::vector markRuneTarget( + const std::vector>& contours, + const std::vector& hierarchy, + std::vector& used_flags + ) { + std::vector results; + if (hierarchy.empty()) + return results; + + struct Node { + int idx; + cv::Point2f center; + int parent_top_id; + }; + + std::vector candidates; + + for (int i = 0; i < contours.size(); i++) { + if (used_flags[i]) + continue; + + const auto& cnt = contours[i]; + + double contour_area = cv::contourArea(cnt); + if (contour_area < params_.rune_target_min_area + || contour_area > params_.rune_target_max_area) + continue; + + cv::Moments m = cv::moments(cnt); + if (m.m00 == 0) + continue; + + cv::Point2f center(m.m10 / m.m00, m.m01 / m.m00); + int top_parent = findTopParent(i, hierarchy); + candidates.push_back({ i, center, top_parent }); + } + + if (candidates.size() < 3) + return results; + + std::unordered_map> groups; + for (int i = 0; i < candidates.size(); i++) { + groups[candidates[i].parent_top_id].push_back(i); + } + + for (auto& [parent_top_id, idx_list]: groups) { + int M = idx_list.size(); + if (M < 3 || M > 7) + continue; + + std::vector cluster_id(M, -1); + int cluster_count = 0; + + for (int i = 0; i < M; i++) { + if (cluster_id[i] != -1) + continue; + + cluster_id[i] = cluster_count; + + std::queue q; + q.push(i); + + while (!q.empty()) { + int u = q.front(); + q.pop(); + + for (int v = 0; v < M; v++) { + if (cluster_id[v] != -1) + continue; + + auto& cu = candidates[idx_list[u]].center; + auto& cv = candidates[idx_list[v]].center; + + double dx = cu.x - cv.x; + double dy = cu.y - cv.y; + double dist = std::sqrt(dx * dx + dy * dy); + + if (dist <= params_.rune_target_cluster_radius) { + cluster_id[v] = cluster_count; + q.push(v); + } + } + } + cluster_count++; + } + + std::vector cluster_size(cluster_count, 0); + for (int id: cluster_id) + cluster_size[id]++; + + std::vector> cluster_points(cluster_count); + + for (int i = 0; i < M; i++) { + int cid = cluster_id[i]; + + if (cluster_size[cid] >= 3) { + int contour_index = candidates[idx_list[i]].idx; + used_flags[contour_index] = true; + cluster_points[cid].push_back(candidates[idx_list[i]].center); + } + } + + for (int cid = 0; cid < cluster_count; cid++) { + if (cluster_points[cid].size() < 3) + continue; + + cv::RotatedRect rr = cv::minAreaRect(cluster_points[cid]); + double w = rr.size.width; + double h = rr.size.height; + + if (w < 1 || h < 1) + continue; + + double ratio = (w > h ? w / h : h / w); + if (ratio > params_.rune_target_max_square_ratio) + continue; + + std::vector> dist_list; + dist_list.reserve(cluster_points[cid].size()); + + for (auto& p: cluster_points[cid]) { + double dx = p.x - rr.center.x; + double dy = p.y - rr.center.y; + double dist = dx * dx + dy * dy; + dist_list.emplace_back(dist, p); + } + + std::sort(dist_list.begin(), dist_list.end(), [](auto& a, auto& b) { + return a.first > b.first; + }); + + std::vector corner_points; + for (int i = 0; i < 4 && i < dist_list.size(); i++) + corner_points.push_back(dist_list[i].second); + + auto_buff::RunePan pan; + pan.center = rr.center; + pan.corners = corner_points; + if (corner_points.size() > 3) + pan.is_valid = true; + + results.push_back(pan); + } + } + + return results; + } + + inline void markInvalidContours( + cv::Mat& color, + cv::Mat& debug_img, + const std::vector>& contours, + std::vector& used_flags, + const cv::Rect& valid_rect, + bool filter_red, + double diff_thresh + ) { + used_flags.assign(contours.size(), false); + + for (int i = 0; i < contours.size(); i++) { + cv::Rect r = cv::boundingRect(contours[i]); + if (r.width < 5 || r.height < 5) + continue; + + cv::Rect rr = r & cv::Rect(0, 0, color.cols, color.rows); + if (rr.width < 2 || rr.height < 2) + continue; + + const cv::Mat roi = color(rr); + const cv::Scalar avg = cv::mean(roi); + + const double B = avg[0], G = avg[1], R = avg[2]; + + const double diff_RB = R - B; + const double diff_BR = B - R; + + const bool is_red = (diff_RB > diff_thresh); + const bool is_blue = (diff_BR > diff_thresh); + + bool invalid = false; + + if (filter_red) { + if (is_red) + invalid = true; + } else { + if (is_blue) + invalid = true; + } + cv::Rect inter = r & valid_rect; + bool inside_region = (inter.area() > 0); + + used_flags[i] = !invalid || !inside_region; + + if (!used_flags[i]) { + if (!debug_img.empty()) + cv::drawContours(debug_img, contours, i, cv::Scalar(255, 0, 0), 2); + } + } + } + static bool isUpscaled(const cv::Rect& roi, int model_w, int model_h) { + float scale = std::min(model_w / float(roi.width), model_h / float(roi.height)); + + return scale > 1.0f; + } + void pushInput(CommonFrame& frame, bool debug) { + frame.id = current_id_++; + auto_buff::RuneFan fan { + .is_valid = false, + .id = frame.id, + .timestamp = frame.img_frame.timestamp, + + }; + cv::Mat debug_img; + if (debug) { + debug_img = frame.img_frame.src_img.clone(); + } + cv::Mat roi = frame.img_frame.src_img(frame.expanded); + + cv::Mat processed_img = preProcess(roi, frame.detect_color); + + std::vector> contours; + std::vector hierarchy; + + cv::findContours( + processed_img, + contours, + hierarchy, + cv::RETR_TREE, + cv::CHAIN_APPROX_SIMPLE + ); + std::vector used_flags; + used_flags.assign(contours.size(), false); + markInvalidContours( + roi, + debug_img, + contours, + used_flags, + cv::Rect(0, 0, roi.cols, roi.rows), + frame.detect_color, + params_.color_diff_threshold + ); + auto rune_center = + getRuneCenter(contours, hierarchy, roi.size(), frame.offset, debug_img, used_flags); + std::vector rune_pans = + markRuneTarget(contours, hierarchy, used_flags); + for (auto& rune_pan: rune_pans) { + if (rune_center.is_valid) { + rune_pan.addReferRuneCenter(rune_center); + } + if (rune_pan.is_valid && rune_pan.has_refer) { + auto_buff::RuneFan::Simple simple; + simple.points2d.push_back(rune_center.center); + for (auto& pt: rune_pan.corners) { + simple.points2d.push_back(pt); + } + simple.points2d.push_back(rune_pan.center); + fan.fans.push_back(simple); + } + if (!debug_img.empty()) + rune_pan.draw(debug_img, frame.offset); + } + auto_buff::RuneFan tmp = fan; + for (int i = 0; i < tmp.fans.size(); i++) { + for (int j = 0; j < tmp.fans.size(); j++) { + if (i == j) + continue; + + fan.fans[i].addOther(tmp.fans[j]); + } + } + fan.addOffset(frame.offset); + if (callback_) { + callback_(fan, frame, debug_img); + } + } + + DetectorCallback callback_; + cv::Mat tmp_R_; + int current_id_ = 0; + struct Params { + double rune_center_min_area = 100.0; + double rune_center_max_area = 2000.0; + double rune_center_1x1ratio_tol = 0.7; + double rune_center_fill_ratio_min = 0.7; + + double rune_target_min_area = 100.0; + double rune_target_max_area = 3000.0; + double rune_target_max_square_ratio = 1.3; + double rune_target_cluster_radius = 70.0; + + double bin_threshold = 150.0; + double color_diff_threshold = 40.0; + + int target_width = 416; + int target_height = 416; + + void load(const YAML::Node& node) { + // center params + rune_center_min_area = node["rune_center_min_area"] + ? node["rune_center_min_area"].as() + : rune_center_min_area; + rune_center_max_area = node["rune_center_max_area"] + ? node["rune_center_max_area"].as() + : rune_center_max_area; + rune_center_1x1ratio_tol = node["rune_center_1x1ratio_tol"] + ? node["rune_center_1x1ratio_tol"].as() + : rune_center_1x1ratio_tol; + rune_center_fill_ratio_min = node["rune_center_fill_ratio_min"] + ? node["rune_center_fill_ratio_min"].as() + : rune_center_fill_ratio_min; + + // target params + rune_target_min_area = node["rune_target_min_area"] + ? node["rune_target_min_area"].as() + : rune_target_min_area; + rune_target_max_area = node["rune_target_max_area"] + ? node["rune_target_max_area"].as() + : rune_target_max_area; + rune_target_max_square_ratio = node["rune_target_max_square_ratio"] + ? node["rune_target_max_square_ratio"].as() + : rune_target_max_square_ratio; + rune_target_cluster_radius = node["rune_target_cluster_radius"] + ? node["rune_target_cluster_radius"].as() + : rune_target_cluster_radius; + + bin_threshold = + node["bin_threshold"] ? node["bin_threshold"].as() : bin_threshold; + + color_diff_threshold = node["color_diff_threshold"] + ? node["color_diff_threshold"].as() + : color_diff_threshold; + + target_width = node["target_width"] ? node["target_width"].as() : target_width; + target_height = + node["target_height"] ? node["target_height"].as() : target_height; + } + } params_; + }; + RuneDetectorCV::RuneDetectorCV(const YAML::Node& node) { + _impl = std::make_unique(node); + } + RuneDetectorCV::~RuneDetectorCV() { + _impl.reset(); + } + void RuneDetectorCV::pushInput(CommonFrame& frame, bool debug) { + _impl->pushInput(frame, debug); + } + void RuneDetectorCV::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.hpp b/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.hpp new file mode 100644 index 0000000..f140bfb --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_detector/rune_detector.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "tasks/auto_buff/type.hpp" +namespace wust_vision { +namespace auto_buff { + class RuneDetectorCV { + public: + using DetectorCallback = + std::function; + using Ptr = std::unique_ptr; + RuneDetectorCV(const YAML::Node& node); + static inline std::unique_ptr make_detector(const YAML::Node& node) { + return std::make_unique(node); + } + ~RuneDetectorCV(); + void pushInput(CommonFrame& frame, bool debug = false); + void setCallback(DetectorCallback callback); + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/motion_models/motion_modelrypd.hpp b/wust_vision-main/tasks/auto_buff/rune_tracker/motion_models/motion_modelrypd.hpp new file mode 100644 index 0000000..527436d --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/motion_models/motion_modelrypd.hpp @@ -0,0 +1,71 @@ +// Copyright 2025 Xiaojian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "KalmanHyLib/kalman_hybird_lib.hpp" + +namespace ypdrune_motion_model { + +constexpr int X_N = 6, Z_N = 5; +using VecZ = Eigen::Matrix; +using VecX = Eigen::Matrix; +enum class Meas : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, ORI_ROLL = 4, Z_N = 5 }; +enum class State : uint8_t { CX = 0, CY = 1, CZ = 2, YAW = 3, ROLL = 4, VROLL = 5, X_N = 6 }; +struct Predict { + Predict() = default; + explicit Predict(double dt): dt(dt) {} + template + void operator()(const T x0[X_N], T x1[X_N]) const { + for (int i = 0; i < X_N; ++i) { + x1[i] = x0[i]; + } + x1[(int)State::ROLL] += x0[(int)State::VROLL] * dt; + } + double dt; +}; +template +T normalize_angle_t(T angle) { + T two_pi = T(2.0 * M_PI); + return angle - two_pi * floor((angle + T(M_PI)) / two_pi); +} +struct Measure { + Measure() = default; + explicit Measure(int id): id(id) {} + template + void operator()(const T x[X_N], T z[Z_N]) const { + T xy_dist = ceres::sqrt( + x[(int)State::CX] * x[(int)State::CX] + x[(int)State::CY] * x[(int)State::CY] + ); + T dist = ceres::sqrt(xy_dist * xy_dist + x[(int)State::CZ] * x[(int)State::CZ]); + + // Observation model + z[(int)Meas::YPD_Y] = ceres::atan2(x[1], x[0]); // yaw + z[(int)Meas::YPD_P] = ceres::atan2(x[2], xy_dist); // pitch + z[(int)Meas::YPD_D] = dist; // distance + z[(int)Meas::ORI_YAW] = x[(int)State::YAW]; // orientation_yaw + z[(int)Meas::ORI_ROLL] = normalize_angle_t(x[(int)State::ROLL] + id * 2 * M_PI / 5); // roll + } + void h(const VecX& x, VecZ& z) const { + assert(x.size() == X_N); + assert(z.size() == Z_N); + operator()(x.data(), z.data()); + } + int id = 0; +}; + +using RuneESKF = kalman_hybird_lib::ErrorStateEKF; +} // namespace ypdrune_motion_model \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.cpp b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.cpp new file mode 100644 index 0000000..487c882 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.cpp @@ -0,0 +1,336 @@ +#include "rune_target.hpp" +#include +namespace wust_vision { +namespace auto_buff { + RuneTarget::RuneTarget( + const auto_buff::RuneFan& fan, + RuneTargetConfig::Ptr target_config, + double pre_v_roll + ) { + is_big_ = false; + start_time_ = fan.timestamp; + target_config_ = target_config; + fitter_.setWindow(target_config_->big_window_sec_param.get()); + auto f = MModel::Predict(0.005); + auto h = MModel::Measure(0); + auto u_q = [this]() { + Eigen::Matrix q; + return q; + }; + + auto u_r = [this](const Eigen::Matrix& z) { + Eigen::Matrix r; + return r; + }; + Eigen::DiagonalMatrix p0; + p0.setIdentity(); + esekf_ypd_ = MModel::RuneESKF(f, h, u_q, u_r, p0); + esekf_ypd_.setResidualFunc([](const Eigen::VectorXd& z_pred, const Eigen::VectorXd& z) { + Eigen::VectorXd r = z - z_pred; + r[(int)MModel::Meas::YPD_Y] = angles::shortest_angular_distance( + z_pred[(int)MModel::Meas::YPD_Y], + z[(int)MModel::Meas::YPD_Y] + ); // yaw + r[(int)MModel::Meas::ORI_YAW] = angles::shortest_angular_distance( + z_pred[(int)MModel::Meas::ORI_YAW], + z[(int)MModel::Meas::ORI_YAW] + ); // ori_yaw + r[(int)MModel::Meas::ORI_ROLL] = angles::shortest_angular_distance( + z_pred[(int)MModel::Meas::ORI_ROLL], + z[(int)MModel::Meas::ORI_ROLL] + ); // ori_roll + return r; + }); + esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get()); + esekf_ypd_.setInjectFunc([](const Eigen::Matrix& delta, + Eigen::Matrix& nominal) { + for (int i = 0; i < MModel::X_N; i++) { + if (i == (int)MModel::Meas::ORI_YAW || i == (int)MModel::Meas::ORI_ROLL) + continue; + nominal[i] += delta[i]; + } + nominal[(int)MModel::Meas::ORI_YAW] = angles::normalize_angle( + nominal[(int)MModel::Meas::ORI_YAW] + delta[(int)MModel::Meas::ORI_YAW] + ); + nominal[(int)MModel::Meas::ORI_ROLL] = angles::normalize_angle( + nominal[(int)MModel::Meas::ORI_ROLL] + delta[(int)MModel::Meas::ORI_ROLL] + ); + }); + + double xc = fan.fans.front().target_pos.x(); + double yc = fan.fans.front().target_pos.y(); + double zc = fan.fans.front().target_pos.z(); + double yaw = utils::orientationToYaw(fan.fans.front().target_ori); + double roll = utils::orientationToRoll(fan.fans.front().target_ori); + target_state_ = Eigen::VectorXd::Zero(MModel::X_N); + target_state_ << xc, yc, zc, yaw, roll, pre_v_roll; + esekf_ypd_.setState(target_state_); + fitter_.update(0, 0); + last_time_ = 0; + is_inited = true; + last_t_ = fan.timestamp; + timestamp_ = fan.timestamp; + } + Eigen::Matrix + RuneTarget::computeMeasurementCovariance(const Eigen::Matrix& z) const { + Eigen::Matrix r; + // clang-format off + r << target_config_->yp_r_param.get() , 0 , 0 , 0 , 0, + 0 , target_config_->yp_r_param.get() , 0 , 0 , 0, + 0 , 0 , target_config_->dis_r_param.get() , 0 , 0, + 0 , 0 , 0 , target_config_->yaw_r_param.get() , 0, + 0 , 0 , 0 , 0 , target_config_->roll_r_param.get(); + // clang-format on + return r; + } + Eigen::Matrix RuneTarget::computeProcessNoise(double dt + ) const { + Eigen::Matrix q; + double t = dt; + double v1 = target_config_->q_roll_param.get(); + double q_roll_roll = pow(t, 4) / 4 * v1, q_roll_vroll = pow(t, 3) / 2 * v1, + q_vroll_vroll = pow(t, 2) * v1; + double q_xyz = target_config_->q_xyz_param.get(); + double q_yaw = target_config_->q_yaw_param.get(); + // clang-format off + // xc yc zc yaw roll v_roll + q << q_xyz, 0, 0, 0, 0, 0, + 0, q_xyz, 0, 0, 0, 0, + 0, 0, q_xyz, 0, 0, 0, + 0, 0, 0, q_yaw, 0, 0, + 0, 0, 0, 0, q_roll_roll, q_roll_vroll, + 0, 0, 0, 0, q_roll_vroll, q_vroll_vroll; + + // clang-format on + return q; + } + void RuneTarget::predict(std::chrono::steady_clock::time_point t) { + double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t); + + predict(dt); + + last_t_ = t; + } + void RuneTarget::predict(double dt) { + dt_ = dt; + + esekf_ypd_.setPredictFunc(MModel::Predict { dt }); + auto u_q = [dt, this]() { return computeProcessNoise(dt); }; + + esekf_ypd_.setUpdateQ(u_q); + + target_state_ = esekf_ypd_.predict(); + } + bool RuneTarget::update(const auto_buff::RuneFan& fans) { + timestamp_ = fans.timestamp; + if (fans.fans.empty()) { + return false; + } + + update_ids.clear(); + auto matched = match(fans.fans); + bool has_match = false; + for (auto [id, fan]: matched) { + measurement_ = getMeasure(fan); + update_ids.push_back(id); + auto yu_rv2 = [this](const Eigen::Matrix& z) { + return this->computeMeasurementCovariance(z); + }; + esekf_ypd_.setUpdateR(yu_rv2); + esekf_ypd_.setMeasureFunc(MModel::Measure { id }); + + esekf_ypd_.update(measurement_); + if (!is_big_) + last_id = id; + has_match = true; + } + bool no_change = true; + for (auto id: update_ids) { + if (id != last_id) + no_change = false; + } + if (!no_change && update_ids.size() > 1) + last_id = update_ids[0]; + // if (update_ids.size() > 1) + // is_big_ = true; + double tostart = + wust_vl::common::utils::time_utils::durationSec(start_time_, fans.timestamp); + fitter_.update(tostart, v_roll()); + fitter_.setAngleRef(tostart, roll()); + fitter_.fitAsync(); + last_time_ = tostart; + + return has_match; + } + cv::Rect RuneTarget::expanded( + Eigen::Matrix4d T_camera_to_odom, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const cv::Size& image_size + ) { + double dt = wust_vl::common::utils::time_utils::durationSec( + timestamp_, + wust_vl::common::utils::time_utils::now() + ); + if (!is_inited || dt > target_config_->lost_time_thres_param.get()) { + return cv::Rect(0, 0, 0, 0); + } + + const float car_box_half = 1.0; + + static std::vector CAR_BOX; + CAR_BOX = { { 0, car_box_half, -car_box_half }, + { 0, -car_box_half, -car_box_half }, + { 0, -car_box_half, car_box_half }, + { 0, car_box_half, car_box_half } }; + + Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse(); + Eigen::Vector4d pos_odom(centerPos().x(), centerPos().y(), centerPos().z(), 1.0); + Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom; + + if (pos_cam.z() <= 0.2) { + return cv::Rect(0, 0, 0, 0); + } + + cv::Mat tvec = (cv::Mat_(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z()); + + Eigen::Vector3d euler; + euler.z() = M_PI / 2.0; + euler.y() = 0; + euler.x() = std::atan2(pos_odom.y(), pos_odom.x()); + + Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX); + auto target_ori = utils::transformOrientation(ori, T_odom_to_camera); + Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix(); + + cv::Mat rot_mat = + (cv::Mat_(3, 3) << tf_rot(0, 0), + tf_rot(0, 1), + tf_rot(0, 2), + tf_rot(1, 0), + tf_rot(1, 1), + tf_rot(1, 2), + tf_rot(2, 0), + tf_rot(2, 1), + tf_rot(2, 2)); + + cv::Mat rvec; + cv::Rodrigues(rot_mat, rvec); + + std::vector pts_2d; + cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d); + + cv::Rect rect = cv::boundingRect(pts_2d); + + cv::Rect img_rect(0, 0, image_size.width, image_size.height); + if ((rect & img_rect).area() <= 0) { + return cv::Rect(0, 0, 0, 0); + } + + int base_side = std::max(rect.width, rect.height); + int max_side = std::max(image_size.width, image_size.height); + + double lost_dt = target_config_->lost_time_thres_param.get(); + double dt_clamped = std::max(0.0, std::min(dt, lost_dt)); + + int side = static_cast(base_side + (max_side - base_side) * (dt_clamped / lost_dt)); + + if (dt >= lost_dt) { + side = max_side; + } + + int cx = rect.x + rect.width / 2; + int cy = rect.y + rect.height / 2; + cv::Rect square(cx - side / 2, cy - side / 2, side, side); + + square &= img_rect; + + return square; + } + std::vector> + RuneTarget::match(const std::vector& fans) { + std::vector> result; + const int n_obs = (int)(fans.size()); + const int armors_num = 5; + const double GATE = target_config_->match_gate_param.get(); + const double max_cost = 1e9; + std::vector> cost(n_obs, std::vector(armors_num, max_cost + 1)); + std::vector meas_list(n_obs); + for (int j = 0; j < n_obs; ++j) { + meas_list[j] = getMeasure(fans[j]); + } + for (int j = 0; j < n_obs; ++j) { + for (int id = 0; id < armors_num; ++id) { + MModel::Measure measure(id); + MModel::VecZ z_pred; + measure.h(target_state_, z_pred); + + MModel::VecZ nu = meas_list[j] - z_pred; + nu[(int)MModel::Meas::YPD_Y] = + angles::normalize_angle(nu[(int)MModel::Meas::YPD_Y]); + nu[(int)MModel::Meas::YPD_P] = + angles::normalize_angle(nu[(int)MModel::Meas::YPD_P]); + nu[(int)MModel::Meas::ORI_YAW] = + angles::normalize_angle(nu[(int)MModel::Meas::ORI_YAW]); + nu[(int)MModel::Meas::ORI_ROLL] = + angles::normalize_angle(nu[(int)MModel::Meas::ORI_ROLL]); + auto R = computeMeasurementCovariance(z_pred); + + double d2 = nu.transpose() * R.ldlt().solve(nu); + + // 门控 + if (std::isfinite(d2) && d2 < GATE) { + cost[j][id] = d2; + } + } + } + std::vector used_obs(n_obs, false); + std::vector used_id(armors_num, false); + + while (true) { + double best = max_cost; + int best_j = -1; + int best_id = -1; + + for (int j = 0; j < n_obs; ++j) { + if (used_obs[j]) + continue; + for (int id = 0; id < armors_num; ++id) { + if (used_id[id]) + continue; + if (cost[j][id] < best) { + best = cost[j][id]; + best_j = j; + best_id = id; + } + } + } + + if (best_j < 0 || best_id < 0) { + break; + } + + used_obs[best_j] = true; + used_id[best_id] = true; + result.push_back(std::make_pair(best_id, fans[best_j])); + } + + // for (auto fan: fans) { + // int id; + // auto min_angle_error = 1e10; + // const auto angles = getAngles(); + // for (int i = 0; i < angles.size(); i++) { + // auto angle_error = std::abs(angles::normalize_angle( + // angles::normalize_angle(orientationToRoll(fan.target_ori)) - angles[i] + // )); + // if (angle_error < min_angle_error) { + // min_angle_error = angle_error; + // id = i; + // } + // } + // result.push_back(std::make_pair(id, fan)); + // } + return result; + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.hpp b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.hpp new file mode 100644 index 0000000..66a4933 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_target.hpp @@ -0,0 +1,230 @@ +#pragma once + +#include "spd_fitter.hpp" +#include "tasks/auto_buff/rune_tracker/motion_models/motion_modelrypd.hpp" +#include "tasks/auto_buff/type.hpp" +#include "tasks/utils/utils.hpp" +#include +namespace wust_vision { +namespace auto_buff { + namespace MModel = ypdrune_motion_model; + struct RuneTargetConfig: wust_vl::common::utils::ParamGroup { + static constexpr const char* kKey = "rune_tracker"; + const char* key() const override { + return kKey; + } + GEN_PARAM(int, esekf_iter_num); + GEN_PARAM(double, lost_time_thres); + GEN_PARAM(int, tracking_thres); + GEN_PARAM(double, max_dis_diff); + GEN_PARAM(double, match_gate); + GEN_PARAM(double, q_roll); + GEN_PARAM(double, qyaw_output); + GEN_PARAM(double, q_xyz); + GEN_PARAM(double, q_yaw); + GEN_PARAM(double, yp_r); + GEN_PARAM(double, dis_r); + GEN_PARAM(double, yaw_r); + GEN_PARAM(double, roll_r); + GEN_PARAM(double, big_window_sec); + using Ptr = std::shared_ptr; + RuneTargetConfig() {} + static Ptr create() { + return std::make_shared(); + } + void loadSelf(const YAML::Node& node) override { + esekf_iter_num_param.load(node); + lost_time_thres_param.load(node); + tracking_thres_param.load(node); + + max_dis_diff_param.load(node); + match_gate_param.load(node); + q_roll_param.load(node); + q_xyz_param.load(node); + q_yaw_param.load(node); + yp_r_param.load(node); + dis_r_param.load(node); + roll_r_param.load(node); + big_window_sec_param.load(node); + } + }; + + class RuneTarget { + public: + RuneTarget() = default; + RuneTarget( + const auto_buff::RuneFan& fan, + RuneTargetConfig::Ptr target_config, + double pre_v_roll + ); + bool is_big_ = false; + double last_ypd_y = 0; + bool is_inited = false; + int last_id; + std::vector update_ids; + RuneTargetConfig::Ptr target_config_; + std::chrono::steady_clock::time_point last_t_; + std::chrono::steady_clock::time_point timestamp_; + std::chrono::steady_clock::time_point start_time_; + double dt_; + double last_time_ = 0; + SinSpeedFitter fitter_; + MModel::RuneESKF esekf_ypd_; + Eigen::Matrix measurement_ = + Eigen::Matrix::Zero(); + Eigen::Matrix target_state_ = + Eigen::Matrix::Zero(); + cv::Rect expanded( + Eigen::Matrix4d T_camera_to_odom, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const cv::Size& image_size + ); + bool update(const auto_buff::RuneFan& fan); + void predict(std::chrono::steady_clock::time_point t); + void predict(double dt); + Eigen::Matrix + computeMeasurementCovariance(const Eigen::Matrix& z) const; + Eigen::Matrix computeProcessNoise(double dt) const; + + inline bool checkTargetAppear() { + bool appear = is_inited + && wust_vl::common::utils::time_utils::durationSec( + timestamp_, + wust_vl::common::utils::time_utils::now() + ) < target_config_->lost_time_thres_param.get(); + return appear; + } + double predictAngle(std::chrono::steady_clock::time_point t) const { + double to_start = wust_vl::common::utils::time_utils::durationSec(start_time_, t); + return fitter_.predictAngle(to_start); + } + double predictAngle(double dt) const { + return fitter_.predictAngle(last_time_ + dt); + } + void predictWithFitter(double dt) { + if (is_big_) { + double to_start = last_time_ + dt; + double angle = fitter_.predictAngle(to_start); + double speed = fitter_.predictSpeed(to_start); + auto state = esekf_ypd_.getState(); + state[(int)MModel::State::ROLL] = angles::normalize_angle(angle); + state[(int)MModel::State::VROLL] = speed; + esekf_ypd_.setState(state); + } else { + predict(dt); + } + } + void predictWithFitter(std::chrono::steady_clock::time_point t) { + double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t); + + predictWithFitter(dt); + + last_t_ = t; + } + double getFitterSpd(std::chrono::steady_clock::time_point t) { + double to_start = wust_vl::common::utils::time_utils::durationSec(start_time_, t); + return fitter_.predictSpeed(to_start); + } + + Eigen::Vector3d centerPos() const { + return { target_state_((int)MModel::State::CX), + target_state_((int)MModel::State::CY), + target_state_((int)MModel::State::CZ) }; + } + std::vector getAngles() { + std::vector angles; + for (int i = 0; i < 5; i++) { + auto angle = angles::normalize_angle( + target_state_[(int)MModel::State::ROLL] + i * 2 * M_PI / 5 + ); + angles.push_back(angle); + } + return angles; + } + bool diverged() const { + return diverged(target_state_); + } + bool diverged(Eigen::VectorXd target_state) const { + return false; + } + double roll() const { + return target_state_[(int)MModel::State::ROLL]; + } + double curr_roll() const { + return roll() + last_id * 2 * M_PI / 5; + } + double real_roll(int id) const { + return roll() + id * 2 * M_PI / 5; + } + double yaw() const { + return target_state_[(int)MModel::State::YAW]; + } + double v_roll() const { + return target_state_[(int)MModel::State::VROLL]; + } + std::vector> + match(const std::vector& fans); + std::vector> getAllPose() const { + std::vector> poses; + for (int i = 0; i < 5; i++) { + poses.emplace_back(getPose(i)); + } + return poses; + } + std::pair getPose(int id) const { + Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(id)); + auto q = utils::eulerToQuat(euler, utils::EulerOrder::ZYX); + return computeBladeTipPose(centerPos(), q, id); + } + + std::pair + computeBladeTipPose(const Eigen::Vector3d& center_pos, const Eigen::Quaterniond& q, int id) + const { + // tip 的局部坐标(沿 local X 方向) + Eigen::Vector3d local_tip(0.0, 0.0, RUNE_R2PANCENTER); + + Eigen::Vector3d tip_pos = center_pos + q * local_tip; + + Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(id)); + return { tip_pos, utils::eulerToQuat(euler, utils::EulerOrder::ZYX) }; + } + std::pair getHitPoint() const { + return getPose(last_id); + } + auto_buff::PowerRune getPowerRune() const { + auto_buff::PowerRune power_rune; + if (!is_inited) { + return power_rune; + } + power_rune.center.pos = centerPos(); + Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(last_id)); + auto q = Eigen::Quaterniond(); + power_rune.center.ori = q; + auto all_pose = getAllPose(); + for (int i = 0; i < all_pose.size(); i++) { + auto_buff::PowerRune::Pose pose; + pose.pos = all_pose[i].first; + pose.ori = all_pose[i].second; + power_rune.fans.push_back(pose); + } + power_rune.hit_id = last_id; + return power_rune; + } + Eigen::Matrix getMeasure(const auto_buff::RuneFan::Simple& fan) { + auto p = fan.target_pos; + + double measured_yaw = utils::orientationToYaw(fan.target_ori); + double measured_roll = utils::orientationToRoll(fan.target_ori); + double ypd_y = std::atan2(p.y(), p.x()); + ypd_y = this->last_ypd_y + angles::shortest_angular_distance(this->last_ypd_y, ypd_y); + this->last_ypd_y = ypd_y; + double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y())); + double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z()); + Eigen::Matrix measure; + measure << ypd_y, ypd_p, ypd_d, measured_yaw, measured_roll; + return measure; + } + }; +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.cpp b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.cpp new file mode 100644 index 0000000..7916a02 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.cpp @@ -0,0 +1,121 @@ +#include "rune_tracker.hpp" +namespace wust_vision { +namespace auto_buff { + struct RuneTracker::Impl { + public: + Impl(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + tracker_state = LOST; + target_ = auto_buff::RuneTarget(); + target_config_ = RuneTargetConfig::create(); + auto_buff_config_parameter->registerGroup(*target_config_); + auto_buff_config_parameter->reloadFromOldPath(); + } + auto_buff::RuneTarget track(const auto_buff::RuneFan& fan) { + double dt = std::chrono::duration(fan.timestamp - last_time_).count(); + last_time_ = fan.timestamp; + lost_thres_ = + std::abs(static_cast(target_config_->lost_time_thres_param.get() / dt)); + bool found; + if (tracker_state == LOST) { + found = initTarget(fan); + } else { + found = updateTarget(fan); + } + updateFsm(found); + return target_; + } + void updateFsm(bool found) { + switch (tracker_state) { + case DETECTING: + if (found) { + if (++detect_count_ > target_config_->tracking_thres_param.get()) { + detect_count_ = 0; + tracker_state = TRACKING; + } + } else { + detect_count_ = 0; + tracker_state = LOST; + } + break; + + case TRACKING: + if (!found) { + tracker_state = TEMP_LOST; + lost_count_ = 1; + } + break; + + case TEMP_LOST: + if (!found) { + if (++lost_count_ > lost_thres_) { + lost_count_ = 0; + tracker_state = LOST; + } + } else { + lost_count_ = 0; + tracker_state = TRACKING; + } + break; + + default: + break; + } + + // target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST); + + if (found) + ++found_count_; + // if (target_.is_tracking) { + // pre_v_roll_ = target_.v_roll(); + // } + } + bool initTarget(const auto_buff::RuneFan& fan) { + if (!fan.is_valid || fan.fans.empty()) { + return false; + } + target_ = auto_buff::RuneTarget(fan, target_config_, pre_v_roll_); + tracker_state = DETECTING; + return true; + } + bool updateTarget(const auto_buff::RuneFan& fan) { + if (!fan.is_valid || fan.fans.empty()) { + return false; + } + auto fan_copy = fan; + std::erase_if(fan_copy.fans, [this](const auto_buff::RuneFan::Simple& f) { + bool pose_check = std::abs((f.target_pos - target_.centerPos()).norm()) + < target_config_->max_dis_diff_param.get() + && f.target_pos.norm() > 1.0; + + return !pose_check; + }); + target_.predict(fan_copy.timestamp); + return target_.update(fan_copy); + } + enum State { + LOST, + DETECTING, + TRACKING, + TEMP_LOST, + } tracker_state = LOST; + auto_buff::RuneTarget target_; + int detect_count_ = 0; + int lost_count_ = 0; + int found_count_ = 0; + double pre_v_roll_ = 0; + int lost_thres_ = 0; + std::chrono::steady_clock::time_point last_time_; + RuneTargetConfig::Ptr target_config_; + }; + RuneTracker::RuneTracker(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + _impl = std::make_unique(auto_buff_config_parameter); + } + RuneTracker::~RuneTracker() { + _impl.reset(); + } + auto_buff::RuneTarget RuneTracker::track(const auto_buff::RuneFan& fan) { + return _impl->track(fan); + } + +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.hpp b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.hpp new file mode 100644 index 0000000..4cbfe86 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/rune_tracker.hpp @@ -0,0 +1,20 @@ +#pragma once +#include "rune_target.hpp" +namespace wust_vision { +namespace auto_buff { + class RuneTracker { + public: + using Ptr = std::unique_ptr; + RuneTracker(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter); + static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) { + return std::make_unique(auto_buff_config_parameter); + } + ~RuneTracker(); + auto_buff::RuneTarget track(const auto_buff::RuneFan& fan); + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_tracker/spd_fitter.hpp b/wust_vision-main/tasks/auto_buff/rune_tracker/spd_fitter.hpp new file mode 100644 index 0000000..7d2ab09 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_tracker/spd_fitter.hpp @@ -0,0 +1,265 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace wust_vision { +namespace auto_buff { + class SinSpeedFitter { + public: + struct P { + double a, w, t0; + }; + static constexpr double a_min_ = 0.780; + static constexpr double a_max_ = 1.045; + static constexpr double w_min_ = 1.884; + static constexpr double w_max_ = 2.000; + + SinSpeedFitter() {} + void setWindow(double w) { + window_sec_ = w; + } + + SinSpeedFitter(const SinSpeedFitter& other) { + std::scoped_lock lock(other.mtx_); + params_ = other.params_; + times_ = other.times_; + speeds_ = other.speeds_; + has_angle_ref_ = other.has_angle_ref_; + angle_ref_time_ = other.angle_ref_time_; + angle_ref_value_ = other.angle_ref_value_; + sign_ = other.sign_; + fitting_ = false; + } + + SinSpeedFitter& operator=(const SinSpeedFitter& other) { + if (this != &other) { + std::scoped_lock lock(mtx_, other.mtx_); + params_ = other.params_; + times_ = other.times_; + speeds_ = other.speeds_; + has_angle_ref_ = other.has_angle_ref_; + angle_ref_time_ = other.angle_ref_time_; + angle_ref_value_ = other.angle_ref_value_; + sign_ = other.sign_; + fitting_ = false; + } + return *this; + } + + void update(double time_s, double speed_rads) { + std::scoped_lock lock(mtx_); + + auto it = std::lower_bound(times_.begin(), times_.end(), time_s); + size_t idx = std::distance(times_.begin(), it); + times_.insert(it, time_s); + speeds_.insert(speeds_.begin() + idx, speed_rads); + + const double window_sec = 5.0; + if (!times_.empty()) { + double latest = times_.back(); + while (!times_.empty() && latest - times_.front() > window_sec) { + times_.erase(times_.begin()); + speeds_.erase(speeds_.begin()); + } + } + } + + void fit(bool verbose = false) { + std::scoped_lock lock(mtx_); + fitImpl(verbose); + } + void fitAsync(bool verbose = false) { + if (fitting_.exchange(true)) { + if (verbose) + std::cout << "[SinSpeedFitter] Previous async fit still running, skip.\n"; + return; + } + + std::vector t_copy, s_copy; + P params_snapshot; + { + std::scoped_lock lock(mtx_); + t_copy = times_; + s_copy = speeds_; + params_snapshot = params_; + } + + std::thread([this, t_copy, s_copy, params_snapshot, verbose]() { + fitImpl(verbose, &t_copy, &s_copy, ¶ms_snapshot); + fitting_ = false; + }).detach(); + } + + double predictSpeed(double t) const { + std::scoped_lock lock(mtx_); + double a = params_.a; + double w = params_.w; + double b = 2.090 - a; + return sign_ * (a * std::sin(w * (t - params_.t0)) + b); + } + + double predictAngle(double t) const { + std::scoped_lock lock(mtx_); + if (!has_angle_ref_) + return 0.0; + double a = params_.a; + double w = params_.w; + double b = 2.090 - a; + double theta = sign_ * (-a / w * std::cos(w * (t - params_.t0)) + b * (t - params_.t0)); + double theta_ref = sign_ + * (-a / w * std::cos(w * (angle_ref_time_ - params_.t0)) + + b * (angle_ref_time_ - params_.t0)); + return angle_ref_value_ + (theta - theta_ref); + } + + void setAngleRef(double time_s, double angle_rad) { + std::scoped_lock lock(mtx_); + angle_ref_time_ = time_s; + angle_ref_value_ = angle_rad; + has_angle_ref_ = true; + } + + const P& params() const { + return params_; + } + int sign() const { + return sign_; + } + bool isFitting() const { + return fitting_.load(); + } + + private: + struct SinResidual { + SinResidual(double t, double s, int sign): t_(t), s_(s), sign_(sign) {} + template + bool operator()(const T* const p, T* residual) const { + const T& a = p[0]; + const T& w = p[1]; + const T& t0 = p[2]; + T b = T(2.090) - a; + T pred = T(sign_) * (a * sin(w * (T(t_) - t0)) + b); + residual[0] = T(s_) - pred; + return true; + } + double t_, s_; + int sign_; + }; + + bool fitImpl( + bool verbose, + const std::vector* t_ptr = nullptr, + const std::vector* s_ptr = nullptr, + const P* params_snapshot = nullptr + ) { + const auto& t_in = t_ptr ? *t_ptr : times_; + const auto& s_in = s_ptr ? *s_ptr : speeds_; + + if (t_in.size() < 3) { + if (verbose) + std::cerr << "[SinSpeedFitter] need >=3 samples\n"; + return false; + } + + std::vector> tmp; + tmp.reserve(t_in.size()); + for (size_t i = 0; i < t_in.size(); ++i) + tmp.emplace_back(t_in[i], s_in[i]); + std::sort(tmp.begin(), tmp.end()); + + std::vector t_unique, s_unique; + t_unique.reserve(tmp.size()); + s_unique.reserve(tmp.size()); + double last_t = std::numeric_limits::quiet_NaN(); + for (auto& [t, s]: tmp) { + if (t_unique.empty() || std::abs(t - last_t) > 1e-9) { + t_unique.push_back(t); + s_unique.push_back(s); + last_t = t; + } + } + + P params_initial = params_snapshot ? *params_snapshot : params_; + + double err_pos = fit_with_sign(+1, t_unique, s_unique, params_initial, verbose); + double err_neg = fit_with_sign(-1, t_unique, s_unique, params_initial, verbose); + + std::scoped_lock lock(mtx_); + sign_ = (err_pos <= err_neg) ? +1 : -1; + return true; + } + + double fit_with_sign( + int sgn, + const std::vector& t_unique, + const std::vector& s_unique, + P params_initial, + bool verbose + ) { + ceres::Problem problem; + double params[3] = { params_initial.a, params_initial.w, params_initial.t0 }; + for (size_t i = 0; i < t_unique.size(); ++i) { + problem.AddResidualBlock( + new ceres::AutoDiffCostFunction( + new SinResidual(t_unique[i], s_unique[i], sgn) + ), + nullptr, + params + ); + } + + problem.SetParameterLowerBound(params, 0, a_min_); + problem.SetParameterUpperBound(params, 0, a_max_); + problem.SetParameterLowerBound(params, 1, w_min_); + problem.SetParameterUpperBound(params, 1, w_max_); + + ceres::Solver::Options options; + options.linear_solver_type = ceres::DENSE_QR; + options.max_num_iterations = 100; + options.minimizer_progress_to_stdout = verbose; + + ceres::Solver::Summary summary; + ceres::Solve(options, &problem, &summary); + + double err_sum = 0.0; + for (size_t i = 0; i < t_unique.size(); ++i) { + double pred = sgn + * (params[0] * std::sin(params[1] * (t_unique[i] - params[2])) + + (2.090 - params[0])); + double e = s_unique[i] - pred; + err_sum += e * e; + } + + if (verbose) + std::cout << (sgn > 0 ? "[+] " : "[-] ") << summary.BriefReport() + << " error=" << err_sum << std::endl; + + std::scoped_lock lock(mtx_); + params_.a = params[0]; + params_.w = params[1]; + params_.t0 = params[2]; + return err_sum; + } + + private: + mutable std::mutex mtx_; + P params_ { 1.0, 1.9, 0.0 }; + std::vector times_; + std::vector speeds_; + bool has_angle_ref_ = false; + double angle_ref_time_ = 0.0; + double angle_ref_value_ = 0.0; + int sign_ = 1; + std::atomic fitting_ { false }; + + double window_sec_ = 1.0; + }; +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_where/rune_where.cpp b/wust_vision-main/tasks/auto_buff/rune_where/rune_where.cpp new file mode 100644 index 0000000..0146e6a --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_where/rune_where.cpp @@ -0,0 +1,213 @@ +#include "rune_where.hpp" +#include "tasks/utils/utils.hpp" +#include +#include + +namespace wust_vision { +namespace auto_buff { + struct RuneWhere::Impl { + public: + Impl(const YAML::Node& config, const std::pair& camera_info) { + camera_info_ = camera_info; + } + + struct Params { + enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode; + OptMode fromString(const std::string& mode) { + if (mode == "golden" || mode == "GOLDEN") { + return OptMode::GOLDEN; + } else if (mode == "none" || mode == "NONE") { + return OptMode::NONE; + } else { + return OptMode::NONE; + } + } + + int golden_search_side_deg = 60; + void load(const YAML::Node& node) { + opt_mode = fromString(node["roll_opt"]["mode"].as()); + golden_search_side_deg = node["roll_opt"]["golden_search_side_deg"].as(); + } + } params_; + auto_buff::RuneFan + where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom) const noexcept { + const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0); + for (auto& fan: f.fans) { + cv::Mat rvec, tvec; + cv::solvePnP( + fan.getObjs(), + fan.landmarks(), + camera_info_.first, + camera_info_.second, + rvec, + tvec, + false, + cv::SOLVEPNP_IPPE //平移更稳定,(旋转这里纯靠后面优化) + ); + cv::Mat R_cv; + cv::Rodrigues(rvec, R_cv); + Eigen::Matrix3d R = utils::cvToEigen(R_cv); + Eigen::Vector3d t = utils::cvToEigen(tvec); + if (params_.opt_mode != Params::OptMode::NONE) { + R = solveBa_R(fan, t, R, R_imu_cam); + } + + fan.ori = Eigen::Quaterniond(R); + fan.pos = t; + Eigen::Vector3d pos_camera = fan.pos; + fan.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom); + + const Eigen::Quaterniond + q_camera(fan.ori.w(), fan.ori.x(), fan.ori.y(), fan.ori.z()); + const Eigen::Quaterniond q_odom = + utils::transformOrientation(q_camera, T_camera_to_odom); + fan.target_ori = q_odom; + + f.is_valid = true; + } + return f; + } + std::vector reprojection( + double roll, + const std::vector& object_points, + const std::vector& landmarks, + const Eigen::Matrix3d& Rci, + double pitch, + double yaw, + const Eigen::Vector3d& t + ) const noexcept { + const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ()); + const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY()); + const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX()); + const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix(); + + cv::Mat rvec, R_cv; + cv::eigen2cv(R, R_cv); + cv::Rodrigues(R_cv, rvec); + + const cv::Mat tvec = (cv::Mat_(3, 1) << t.x(), t.y(), t.z()); + + std::vector pts_2d; + pts_2d.reserve(object_points.size()); + cv::projectPoints( + object_points, + rvec, + tvec, + camera_info_.first, + camera_info_.second, + pts_2d + ); + + std::vector image_points; + image_points.reserve(pts_2d.size()); + + for (const auto& p: pts_2d) { + image_points.emplace_back(p.x, p.y); + } + + return image_points; + } + double reprojectionErrorRoll( + double roll, + const std::vector& obj, + const std::vector& lm, + const Eigen::Matrix3d& Rci, + double pitch, + double yaw, + const Eigen::Vector3d& t + ) const noexcept { + const auto image_points = reprojection(roll, obj, lm, Rci, pitch, yaw, t); + double cost = 0.0; + + for (size_t i = 0; i < image_points.size(); ++i) { + Eigen::Vector2d obs(lm[i].x, lm[i].y); + cost += (image_points[i] - obs).squaredNorm(); + } + return cost; + } + + double goldenRoll( + double init, + const std::vector& obj, + const std::vector& lm, + const Eigen::Matrix3d& Rci, + double pitch, + double yaw, + const Eigen::Vector3d& t + ) const noexcept { + constexpr double phi = 1.618033988749894848; // golden ratio + double l = init - params_.golden_search_side_deg * M_PI / 180.0; + double r = init + params_.golden_search_side_deg * M_PI / 180.0; + + double r1 = r - (r - l) / phi; + double r2 = l + (r - l) / phi; + + double f1 = reprojectionErrorRoll(r1, obj, lm, Rci, pitch, yaw, t); + double f2 = reprojectionErrorRoll(r2, obj, lm, Rci, pitch, yaw, t); + + while (r - l > 0.0001) { // 约 0.0057 度 + if (f1 < f2) { + r = r2; + r2 = r1; + f2 = f1; + r1 = r - (r - l) / phi; + f1 = reprojectionErrorRoll(r1, obj, lm, Rci, pitch, yaw, t); + } else { + l = r1; + r1 = r2; + f1 = f2; + r2 = l + (r - l) / phi; + f2 = reprojectionErrorRoll(r2, obj, lm, Rci, pitch, yaw, t); + } + } + + return 0.5 * (l + r); // final best roll + } + + Eigen::Matrix3d solveBa_R( + const auto_buff::RuneFan::Simple& rune_fan, + const Eigen::Vector3d& t_camera_armor, + const Eigen::Matrix3d& R_camera_armor, + const Eigen::Matrix3d& R_imu_camera + ) const noexcept { + Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor; + Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose(); + double initial_roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2)); + double roll = initial_roll; + Eigen::Vector3d t_imu_armor = R_imu_camera * t_camera_armor; + double yaw = std::atan2(t_imu_armor.y(), t_imu_armor.x()); + double pitch = 0; + auto cv_points = rune_fan.getObjs(); + const auto& landmarks = rune_fan.landmarks(); + if (params_.opt_mode == Params::OptMode::GOLDEN) { + roll = goldenRoll( + roll, + cv_points, + landmarks, + R_camera_imu, + pitch, + yaw, + t_camera_armor + ); + } + const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ()); + const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY()); + const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX()); + + const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix(); + return R_result; + } + + std::pair camera_info_; + }; + RuneWhere::RuneWhere(const YAML::Node& config, const std::pair& camera_info) { + _impl = std::make_unique(config, camera_info); + } + RuneWhere::~RuneWhere() { + _impl.reset(); + } + auto_buff::RuneFan RuneWhere::where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom) { + return _impl->where(f, T_camera_to_odom); + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/rune_where/rune_where.hpp b/wust_vision-main/tasks/auto_buff/rune_where/rune_where.hpp new file mode 100644 index 0000000..5147adf --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/rune_where/rune_where.hpp @@ -0,0 +1,39 @@ +// Created by Labor 2023.8.25 +// Maintained by Labor, Chengfu Zou +// Copyright (C) FYT Vision Group. All rights reserved. +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "tasks/auto_buff/type.hpp" +namespace wust_vision { +namespace auto_buff { + class RuneWhere { + public: + using Ptr = std::unique_ptr; + RuneWhere(const YAML::Node& config, const std::pair& camera_info); + static Ptr + create(const YAML::Node& config, const std::pair& camera_info) { + return std::make_unique(config, camera_info); + } + ~RuneWhere(); + auto_buff::RuneFan where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom); + + private: + struct Impl; + std::unique_ptr _impl; + }; + +} // namespace auto_buff +} // namespace wust_vision diff --git a/wust_vision-main/tasks/auto_buff/type.cpp b/wust_vision-main/tasks/auto_buff/type.cpp new file mode 100644 index 0000000..26cdb8e --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/type.cpp @@ -0,0 +1,339 @@ +#include "type.hpp" +#include "tasks/utils/utils.hpp" +namespace wust_vision { +namespace auto_buff { + void RunePan::draw(cv::Mat& img, const cv::Point2f& offset) const { + if (!is_valid || corners.size() < 3) + return; + + std::vector sorted_corners = corners; + for (auto& pt: sorted_corners) { + pt += offset; + } + // 画边 + for (size_t i = 0; i < sorted_corners.size(); ++i) { + cv::line( + img, + sorted_corners[i], + sorted_corners[(i + 1) % sorted_corners.size()], + cv::Scalar(0, 255, 255), + 2 + ); + } + + // 画中心点 + cv::circle(img, center, 3, cv::Scalar(255, 0, 255), -1); + if (has_refer) { + // 画角点编号 + for (size_t i = 0; i < sorted_corners.size(); ++i) { + cv::Point2f p = sorted_corners[i]; + + // 绘制角点位置 + cv::circle(img, p, 3, cv::Scalar(0, 0, 255), -1); + + // 让数字稍微往右下偏移,避免盖到角点 + cv::Point2f text_pos = p + cv::Point2f(5, -5); + + cv::putText( + img, + std::to_string(i), + text_pos, + cv::FONT_HERSHEY_SIMPLEX, + 0.5, + cv::Scalar(0, 255, 0), + 1 + ); + } + } + } + double RunePan::getArea() const { + if (corners.size() < 3) + return 0.0; + std::vector sorted_corners = corners; + std::sort( + sorted_corners.begin(), + sorted_corners.end(), + [this](const cv::Point2f& a, const cv::Point2f& b) { + double angA = std::atan2(a.y - center.y, a.x - center.x); + double angB = std::atan2(b.y - center.y, b.x - center.x); + return angA < angB; + } + ); + return cv::contourArea(sorted_corners); + } + + void RunePan::addReferRuneCenter(const RuneCenter& rc) { + if (!rc.is_valid || !is_valid) + return; + if (corners.size() != 4) + return; + + cv::Point2f down_vec = rc.center - center; + float norm = std::sqrt(down_vec.x * down_vec.x + down_vec.y * down_vec.y); + if (norm < 1e-6f) + return; + has_refer = true; + float angle_ref = std::atan2(down_vec.y, down_vec.x); + + // 获取4个点在旋转后的角度 + struct Node { + float ang; + cv::Point2f p; + }; + std::vector arr; + arr.reserve(4); + + for (auto& p: corners) { + cv::Point2f v = p - center; + + // 旋转坐标,使 down_vec 对齐 angle=0 + float ang = std::atan2(v.y, v.x) - angle_ref; + + // 归一化到 (-π, π] + while (ang <= -CV_PI) + ang += 2 * CV_PI; + while (ang > CV_PI) + ang -= 2 * CV_PI; + + arr.push_back({ ang, p }); + } + + // 按角度排序(从 -π 到 π) + std::sort(arr.begin(), arr.end(), [](const Node& a, const Node& b) { + return a.ang < b.ang; + }); + + // 准备象限变量并标记 + cv::Point2f lu(0, 0), ru(0, 0), rd(0, 0), ld(0, 0); + bool has_lu = false, has_ru = false, has_rd = false, has_ld = false; + + for (const auto& n: arr) { + float a = n.ang; + + if (a > CV_PI / 2 && a <= CV_PI) { + lu = n.p; + has_lu = true; + } else if (a > 0 && a <= CV_PI / 2) { + ru = n.p; + has_ru = true; + } else if (a > -CV_PI / 2 && a <= 0) { + rd = n.p; + has_rd = true; + } else { // a > -CV_PI && a <= -CV_PI/2 + ld = n.p; + has_ld = true; + } + } + + std::array ordered; + + if (has_lu && has_ru && has_rd && has_ld) { + ordered[0] = lu; + ordered[1] = ru; + ordered[2] = rd; + ordered[3] = ld; + corners.assign(ordered.begin(), ordered.end()); + return; + } + + float target = 3.0f * CV_PI / 4.0f; // 135° + int best_idx = 0; + float best_diff = std::numeric_limits::max(); + for (int i = 0; i < (int)arr.size(); ++i) { + float d = std::fabs(angles::shortest_angular_distance(target, arr[i].ang) + ); // 如果没有 angles::shortest_angular_distance,可以用下面替代 + if (d < best_diff) { + best_diff = d; + best_idx = i; + } + } + + for (int i = 0; i < 4; ++i) { + int idx = (best_idx + i) % 4; + ordered[i] = arr[idx].p; + } + + corners.assign(ordered.begin(), ordered.end()); + } + + void RuneFan::Simple::addOther(const Simple& other) { + auto l1 = points2d[0] - points2d[5]; + auto l2 = other.points2d[0] - other.points2d[5]; + float a1 = std::atan2(l1.y, l1.x); + float a2 = std::atan2(l2.y, l2.x); + + float d = a1 - a2; + d = normalizeAngle0to2pi(d); + + int id = 0; + double min_err = 1e9; + for (int i = 0; i < angle_diffs.size(); i++) { + double err = std::abs(angle_diffs[i] - d); + if (err < min_err) { + min_err = err; + id = i; + } + } + + if (id < 1) { + return; + } + has_other++; + points2d.push_back(other.points2d[1]); + points2d.push_back(other.points2d[2]); + points2d.push_back(other.points2d[3]); + points2d.push_back(other.points2d[4]); + + double roll = -angle_diffs[id]; + + points3d.push_back(rotateX(points3d[1], roll)); + points3d.push_back(rotateX(points3d[2], roll)); + points3d.push_back(rotateX(points3d[3], roll)); + points3d.push_back(rotateX(points3d[4], roll)); + } + + void RuneFan::Simple::drawLandmarks(cv::Mat& image) const { + std::vector lm = landmarks(); + for (size_t i = 0; i < lm.size(); i++) { + cv::circle(image, lm[i], 3, cv::Scalar(255, 255, 255), -1); + if (i == 0) { + cv::putText( + image, + "R", + lm[i], + cv::FONT_HERSHEY_SIMPLEX, + 1.5, + cv::Scalar(40, 255, 40), + 2 + ); + } else { + cv::putText( + image, + std::to_string(i), + lm[i], + cv::FONT_HERSHEY_SIMPLEX, + 0.5, + cv::Scalar(255, 255, 255), + 2 + ); + } + } + } + + void RuneFan::addOffset(const cv::Point2f& offset) { + for (auto& fan: fans) { + for (auto& point: fan.points2d) { + point += offset; + } + } + } + void RuneFan::transform(const Eigen::Matrix& transform_matrix) { + for (auto& fan: fans) { + for (auto& pt: fan.points2d) { + pt = utils::transformPoint2D(transform_matrix, pt); + } + } + } + + void PowerRune::Pose::tf(Eigen::Matrix4d T_camera_to_odom) { + Eigen::Vector4d pos_camera(pos.x(), pos.y(), pos.z(), 1.0); + Eigen::Vector4d pos_odom = T_camera_to_odom * pos_camera; + + pos.x() = pos_odom.x(); + pos.y() = pos_odom.y(); + pos.z() = pos_odom.z(); + Eigen::Matrix3d R_camera_to_odom = T_camera_to_odom.block<3, 3>(0, 0); + Eigen::Quaterniond q_camera(ori.w(), ori.x(), ori.y(), ori.z()); + Eigen::Matrix3d R_ori_camera = q_camera.normalized().toRotationMatrix(); + + Eigen::Matrix3d R_ori_odom = R_camera_to_odom * R_ori_camera; + Eigen::Quaterniond q_odom(R_ori_odom); + + ori.w() = q_odom.w(); + ori.x() = q_odom.x(); + ori.y() = q_odom.y(); + ori.z() = q_odom.z(); + } + + std::vector PowerRune::Pose::toPts( + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const std::vector& obj_points + ) const { + std::vector pts; + if (pos.norm() < 0.5) { + return pts; + } + + cv::Mat tvec = (cv::Mat_(3, 1) << pos.x(), pos.y(), pos.z()); + Eigen::Matrix3d tf_rot = ori.toRotationMatrix(); + cv::Mat rot_mat = + (cv::Mat_(3, 3) << tf_rot(0, 0), + tf_rot(0, 1), + tf_rot(0, 2), + tf_rot(1, 0), + tf_rot(1, 1), + tf_rot(1, 2), + tf_rot(2, 0), + tf_rot(2, 1), + tf_rot(2, 2)); + + // 旋转矩阵 -> 旋转向量 + cv::Mat rvec; + cv::Rodrigues(rot_mat, rvec); + + cv::projectPoints(obj_points, rvec, tvec, camera_intrinsic, camera_distortion, pts); + + return pts; + } + void PowerRune::Pose::draw( + cv::Mat& img, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const std::vector& obj_points, + cv::Scalar color + ) const { + auto pts = toPts(camera_intrinsic, camera_distortion, obj_points); + if (!pts.empty()) { + for (int i = 0; i < 4; i++) + cv::line(img, pts[i], pts[(i + 1) % 4], color, 2); + + // 后表面 + for (int i = 4; i < 8; i++) + cv::line(img, pts[i], pts[4 + (i + 1) % 4], color, 2); + + // 侧边 + for (int i = 0; i < 4; i++) + cv::line(img, pts[i], pts[i + 4], color, 2); + cv::Point2f center(0.f, 0.f); + for (auto pt: pts) { + center += pt; + } + center *= 1.0 / pts.size(); + } + } + + void PowerRune::tf(Eigen::Matrix4d T_camera_to_odom) { + center.tf(T_camera_to_odom); + for (auto& fan: fans) + fan.tf(T_camera_to_odom); + } + void + PowerRune::draw(cv::Mat& img, const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) + const { + center.draw(img, camera_intrinsic, camera_distortion); + for (int i = 0; i < fans.size(); i++) { + if (i == hit_id) + fans[i].draw( + img, + camera_intrinsic, + camera_distortion, + FAN_BLOCK, + cv::Scalar(40, 255, 40) + ); + else + fans[i].draw(img, camera_intrinsic, camera_distortion, FAN_BLOCK); + } + } +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_buff/type.hpp b/wust_vision-main/tasks/auto_buff/type.hpp new file mode 100644 index 0000000..ef3c8b2 --- /dev/null +++ b/wust_vision-main/tasks/auto_buff/type.hpp @@ -0,0 +1,124 @@ +#pragma once +#include "tasks/type_common.hpp" + +namespace wust_vision { +namespace auto_buff { + constexpr double RUNE_PAN_BOX_DIS = 0.16; + constexpr double RUNE_R2PANCENTER = 0.75; + struct RuneCenter { + cv::Point2f center; + cv::RotatedRect rr; + bool is_valid = false; + RuneCenter() = default; + RuneCenter(cv::RotatedRect rect): rr(rect) { + center = rect.center; + is_valid = rr.size.area() > 0; + } + }; + + class RunePan { + public: + cv::Point2f center; + std::vector corners; + bool is_valid = false; + bool has_refer = false; + + void draw(cv::Mat& img, const cv::Point2f& offset) const; + double getArea() const; + void addReferRuneCenter(const RuneCenter& rc); + }; + + struct RuneFan { + public: + bool is_valid = false; + int id; + bool is_big = false; + std::chrono::steady_clock::time_point timestamp; + struct Simple { + int has_other = 0; + std::vector angle_diffs = { 0, + 2 * M_PI / 5, + 2 * M_PI / 5 * 2, + 2 * M_PI / 5 * 3, + 2 * M_PI / 5 * 4 }; + std::vector points2d; + std::vector points3d = { + { 0.0f, 0.0f, 0.0f }, // P0 + { 0.0f, RUNE_PAN_BOX_DIS / 2.0f, RUNE_R2PANCENTER + RUNE_PAN_BOX_DIS / 2.0f }, // P1 + { 0.0f, RUNE_PAN_BOX_DIS / 2.0f, RUNE_R2PANCENTER - RUNE_PAN_BOX_DIS / 2.0f }, // P2 + { 0.0f, + -RUNE_PAN_BOX_DIS / 2.0f, + RUNE_R2PANCENTER - RUNE_PAN_BOX_DIS / 2.0f }, // P3 + { 0.0f, + -RUNE_PAN_BOX_DIS / 2.0f, + RUNE_R2PANCENTER + RUNE_PAN_BOX_DIS / 2.0f }, // P4 + { 0.0f, 0.0f, RUNE_R2PANCENTER } // P5 + }; + inline cv::Point3f rotateX(const cv::Point3f& p, double roll) { + double c = std::cos(roll); + double s = std::sin(roll); + return { p.x, float(p.y * c - p.z * s), float(p.y * s + p.z * c) }; + } + inline double normalizeAngle0to2pi(double a) { + a = std::fmod(a, 2 * M_PI); + if (a < 0) + a += 2 * M_PI; + return a; + } + + Eigen::Vector3d pos; + Eigen::Quaterniond ori; + Eigen::Vector3d target_pos; + Eigen::Quaterniond target_ori; + void addOther(const Simple& other); + std::vector landmarks() const { + return points2d; + } + void drawLandmarks(cv::Mat& image) const; + std::vector getObjs() const { + return points3d; + } + }; + std::vector fans; + void addOffset(const cv::Point2f& offset); + void transform(const Eigen::Matrix& transform_matrix); + }; + static std::vector FAN_BLOCK = { + { -0.05f, -0.20f, -0.15f }, // 0: 左下前 + { 0.05f, -0.20f, -0.15f }, // 1: 右下前 + { 0.05f, 0.20f, -0.15f }, // 2: 右上前 + { -0.05f, 0.20f, -0.15f }, // 3: 左上前 + { -0.05f, -0.20f, 0.15f }, // 4: 左下后 + { 0.05f, -0.20f, 0.15f }, // 5: 右下后 + { 0.05f, 0.20f, 0.15f }, // 6: 右上后 + { -0.05f, 0.20f, 0.15f } // 7: 左上后 + }; + + struct PowerRune { + bool is_valid = false; + struct Pose { + Eigen::Vector3d pos; + Eigen::Quaterniond ori; + void tf(Eigen::Matrix4d T_camera_to_odom); + std::vector toPts( + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const std::vector& obj_points = AIM_TARGET_BLOCK + ) const; + void draw( + cv::Mat& img, + const cv::Mat& camera_intrinsic, + const cv::Mat& camera_distortion, + const std::vector& obj_points = AIM_TARGET_BLOCK, + cv::Scalar color = cv::Scalar(255, 255, 255) + ) const; + }; + Pose center; + std::vector fans; + int hit_id; + void tf(Eigen::Matrix4d T_camera_to_odom); + void + draw(cv::Mat& img, const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) const; + }; +} // namespace auto_buff +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/auto_guidance.cpp b/wust_vision-main/tasks/auto_guidance/auto_guidance.cpp new file mode 100644 index 0000000..983df55 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/auto_guidance.cpp @@ -0,0 +1,198 @@ +#include "auto_guidance.hpp" +#include "tasks/auto_guidance/guidance_detector/detector_base.hpp" +#include "tasks/auto_guidance/guidance_detector/detector_factory.hpp" +#include "tasks/auto_guidance/guidance_tracker/guidance_tracker.hpp" +#include "tasks/utils/utils.hpp" +#include "wust_vl/common/concurrency/queues.hpp" +#include "wust_vl/common/utils/logger.hpp" +#include "wust_vl/common/utils/timer.hpp" +namespace wust_vision { +namespace auto_guidance { + struct AutoGuidance::Impl { + ~Impl() { + lights_queue_->stop(); + if (processing_thread_) { + processing_thread_->stop(); + wust_vl::common::concurrency::ThreadManager::instance().unregisterThread( + processing_thread_->getName() + ); + } + } + void init(const YAML::Node& config, const std::pair& camera_info) { + camera_info_ = camera_info; + std::string backend = config["backend"].as(); + std::cout << "backend: " << backend << std::endl; + auto detector_cfg = config["detector"]; + detector_ = DetectorFactory::createDetector(backend, detector_cfg, debug_); + detector_->setCallback(std::bind( + &AutoGuidance::Impl::detectCallback, + this, + std::placeholders::_1, + std::placeholders::_2 + )); + tracker_ = GuidanceTracker::create(config["tracker"]); + lights_queue_ = + std::make_unique>(100, 500); + latency_averager_ = + std::make_unique>(100); + } + void pushInput(CommonFrame& frame) { + img_recv_count_++; + if (detector_) { + detector_->pushInput(frame); + } + } + + void detectCallback(const std::vector& objs, const CommonFrame& frame) { + detect_finish_count_++; + GreenLights lights; + lights.lights = objs; + lights.timestamp = frame.img_frame.timestamp; + lights.id = frame.id; + for (auto& light: lights.lights) { + light.solvePnP(camera_info_.first, camera_info_.second); + light.timestamp = frame.img_frame.timestamp; + light.image_size = frame.img_frame.src_img.size(); + } + green_lights_ = lights; + lights_queue_->enqueue(lights); + + if (debug_) { + std::lock_guard lock(dbg_mutex_); + dbg_.lights = lights; + dbg_.img_frame = frame.img_frame; + } + } + + void lightsCallback(const GreenLights& lights) { + if (lights.timestamp <= tracker_->getLastTime()) { + WUST_WARN(logger_) << "Received out-of-order armor data, discarded."; + return; + } + GuidanceTarget target; + target = tracker_->track(lights); + + { + std::lock_guard lock(target_mutex_); + guidance_target_ = target; + } + auto now = std::chrono::steady_clock::now(); + + auto latency_ms = wust_vl::common::utils::time_utils::durationMs(lights.timestamp, now); + latency_averager_->add(latency_ms); + dbg_.latency_ms = latency_averager_->average(); + if (debug_) { + std::lock_guard lock(dbg_mutex_); + dbg_.target = target; + } + } + void start() { + processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create( + "AutoAimProcessingThread", + [this](wust_vl::common::concurrency::MonitoredThread::Ptr self) { + this->processingLoop(self); + } + ); + + wust_vl::common::concurrency::ThreadManager::instance().registerThread( + processing_thread_ + ); + run_flag_ = true; + } + void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) { + while (!self->isAlive()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + while (self->isAlive()) { + if (!self->waitPoint()) + break; + self->heartbeat(); + printStats(); + GreenLights lights; + bool skip; + // if (lights_queue_->dequeue_wait(lights, skip)) { + // lightsCallback(lights); + // tracker_finish_count_++; + // if (skip) { + // WUST_DEBUG(logger_) << "OrderQueue skip"; + // } + // } + if (!lights_queue_->try_dequeue(lights)) { + std::this_thread::sleep_for(std::chrono::milliseconds(3)); + continue; + } + lightsCallback(lights); + tracker_finish_count_++; + } + } + GuidanceTarget getTarget() { + timer_count_++; + + std::lock_guard lock(target_mutex_); + + return guidance_target_; + } + void printStats() { + utils::XSecOnce( + [&] { + WUST_INFO(logger_) + << "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_ + << ", Fin: " << tracker_finish_count_ << ", Lat: " << dbg_.latency_ms + << "ms" + << ", Tc:" << timer_count_; + + img_recv_count_ = 0; + detect_finish_count_ = 0; + tracker_finish_count_ = 0; + timer_count_ = 0; + }, + 1.0 + ); + } + std::unique_ptr detector_; + std::string logger_ = "auto_guidance"; + std::chrono::steady_clock::time_point last_stat_time_steady_ = + std::chrono::steady_clock::now(); + GuidanceTracker::Ptr tracker_; + bool run_flag_ = false; + int detect_finish_count_ = 0; + int img_recv_count_ = 0; + int tracker_finish_count_ = 0; + int timer_count_ = 0; + bool debug_ = false; + GuidanceTarget guidance_target_; + GreenLights green_lights_; + std::shared_ptr processing_thread_; + std::unique_ptr> lights_queue_; + std::unique_ptr> latency_averager_; + std::pair camera_info_; + AutoGuidanceDebug dbg_; + std::mutex target_mutex_; + std::mutex dbg_mutex_; + }; + AutoGuidance::AutoGuidance(): _impl(std::make_unique()) {} + AutoGuidance::~AutoGuidance() { + _impl.reset(); + } + void + AutoGuidance::init(const YAML::Node& config, const std::pair& camera_info) { + _impl->init(config, camera_info); + } + void AutoGuidance::start() { + _impl->start(); + } + void AutoGuidance::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); + } + void AutoGuidance::setDebug(bool debug) { + _impl->debug_ = debug; + } + AutoGuidanceDebug AutoGuidance::getDebug() { + std::lock_guard lock(_impl->dbg_mutex_); + return _impl->dbg_; + } + GuidanceTarget AutoGuidance::getTarget() { + return _impl->getTarget(); + } +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/auto_guidance.hpp b/wust_vision-main/tasks/auto_guidance/auto_guidance.hpp new file mode 100644 index 0000000..07b424f --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/auto_guidance.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "debug.hpp" +#include "tasks/type_common.hpp" +namespace wust_vision { +namespace auto_guidance { + class AutoGuidance { + public: + static inline std::unique_ptr create() { + return std::make_unique(); + } + AutoGuidance(); + ~AutoGuidance(); + void init(const YAML::Node& config, const std::pair& camera_info); + void start(); + void pushInput(CommonFrame& frame); + void setDebug(bool debug); + GuidanceTarget getTarget(); + AutoGuidanceDebug getDebug(); + struct Impl; + std::unique_ptr _impl; + }; + +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/debug.cpp b/wust_vision-main/tasks/auto_guidance/debug.cpp new file mode 100644 index 0000000..9674bcc --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/debug.cpp @@ -0,0 +1,295 @@ +#include "debug.hpp" +#include +#include +#include +#include +#include +#include +namespace wust_vision { +namespace auto_guidance { + void drawAutoGuidanceDebugContent(cv::Mat& debug_img, const AutoGuidanceDebug& dbg) { + auto target = dbg.target; + auto lights = dbg.lights; + + lights.drawFront(debug_img); + + if (target.is_tracking_) { + auto now = std::chrono::steady_clock::now(); + target.predict(now); + target.draw(debug_img); + } + + std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms); + cv::putText( + debug_img, + latency_str, + cv::Point(10, 30), + cv::FONT_HERSHEY_SIMPLEX, + 0.8, + cv::Scalar(255, 255, 255), + 2 + ); + + cv::circle( + debug_img, + cv::Point2i(debug_img.cols / 2, debug_img.rows / 2), + 5, + cv::Scalar(255, 255, 255), + 2 + ); + double cx_norm = target.center().x / target.image_size_.width * 2.0 - 1.0; + double diff_center_norm = (target.is_tracking_) ? cx_norm : 0; + { + std::string diff_str = fmt::format("diff_cx_norm: {:.3f}", diff_center_norm); + + int margin = 10; + int font_face = cv::FONT_HERSHEY_SIMPLEX; + double font_scale = 0.7; + int thickness = 2; + + int baseline = 0; + cv::Size text_size = + cv::getTextSize(diff_str, font_face, font_scale, thickness, &baseline); + + // 右上角文本左下角坐标 + int x = debug_img.cols - margin - text_size.width; + int y = margin + text_size.height; + + // 背景框,可选 + cv::rectangle( + debug_img, + cv::Rect( + x - 5, + y - text_size.height - 5, + text_size.width + 10, + text_size.height + 10 + ), + cv::Scalar(0, 0, 0), + cv::FILLED, + cv::LINE_AA + ); + + cv::putText( + debug_img, + diff_str, + cv::Point(x, y), + font_face, + font_scale, + cv::Scalar(0, 255, 255), + thickness, + cv::LINE_AA + ); + } + + if (target.is_tracking_) { + const auto& s = target.target_state_; + + std::string line1 = + fmt::format("pos: {:.1f} {:.1f} {:.1f} {:.1f}", s(0), s(2), s(4), s(6)); + std::string line2 = + fmt::format("vel: {:.2f} {:.2f} {:.2f} {:.2f}", s(1), s(3), s(5), s(7)); + + int x = 10; + int y = debug_img.rows - 30; // 左下角位置 + int dy = 28; + + cv::putText( + debug_img, + line1, + cv::Point(x, y), + cv::FONT_HERSHEY_SIMPLEX, + 0.75, + cv::Scalar(0, 255, 0), + 2 + ); + + cv::putText( + debug_img, + line2, + cv::Point(x, y + dy), + cv::FONT_HERSHEY_SIMPLEX, + 0.75, + cv::Scalar(0, 255, 0), + 2 + ); + + double vx = s(1); + double vy = s(3); + + cv::Point2f p0 = target.center(); + + double scale = 0.5; + + cv::Point2f p1(p0.x + vx * scale, p0.y + vy * scale); + + cv::arrowedLine( + debug_img, + p0, + p1, + cv::Scalar(0, 255, 0), + 2, + cv::LINE_AA, + 0, + 0.25 // 箭头比例 + ); + } + } + + void drawDebugOverlayWrite(const AutoGuidanceDebug& dbg, bool auto_fps) { + static auto last_show_time = std::chrono::steady_clock::now(); + + if (dbg.img_frame.src_img.empty()) + return; + cv::Mat src_img = dbg.img_frame.src_img; + auto now = std::chrono::steady_clock::now(); + const double min_interval_ms = 1000.0 / 30.0; + if (std::chrono::duration(now - last_show_time).count() + < min_interval_ms + && auto_fps) + return; + last_show_time = now; + + // 图像构造 + cv::Mat debug_img; + src_img.convertTo(debug_img, -1, 1, 0); + cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB); + if (debug_img.empty()) + return; + + // 封装后的绘图函数 + drawAutoGuidanceDebugContent(debug_img, dbg); + cv::cvtColor(debug_img, debug_img, cv::COLOR_RGB2BGR); + // 编码写入共享内存路径 + std::vector buf; + cv::imencode(".jpg", debug_img, buf); + std::ofstream ofs("/dev/shm/debug_frame.jpg.tmp", std::ios::binary); + ofs.write(reinterpret_cast(buf.data()), buf.size()); + ofs.close(); + std::rename("/dev/shm/debug_frame.jpg.tmp", "/dev/shm/debug_frame.jpg"); + } + void drawDebugOverlayShm(const AutoGuidanceDebug& dbg, bool auto_fps) { + static auto last_show_time = std::chrono::steady_clock::now(); + const char* shm_name = "/debug_frame"; + const size_t shm_max_size = 2 * 1024 * 1024; // 2MB 最大图像编码缓存 + + if (dbg.img_frame.src_img.empty()) + return; + cv::Mat src_img = dbg.img_frame.src_img; + + auto now = std::chrono::steady_clock::now(); + const double min_interval_ms = 1000.0 / 30.0; + if (std::chrono::duration(now - last_show_time).count() + < min_interval_ms + && auto_fps) + return; + last_show_time = now; + + // 复制并转RGB + cv::Mat debug_img; + src_img.convertTo(debug_img, -1, 1, 0); + cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB); + if (debug_img.empty()) + return; + + // 绘制内容 + drawAutoGuidanceDebugContent(debug_img, dbg); + // 编码为 JPG + std::vector buf; + cv::imencode(".jpg", debug_img, buf); + size_t img_size = buf.size(); + + if (img_size > shm_max_size) { + std::cerr << "[drawDebugOverlayWrite] 图像过大: " << img_size << " bytes\n"; + return; + } + + // 创建/打开共享内存 + int fd = shm_open(shm_name, O_CREAT | O_RDWR, 0666); + if (fd == -1) { + perror("shm_open failed"); + return; + } + + // 设置共享内存大小 + if (ftruncate(fd, shm_max_size) == -1) { + perror("ftruncate failed"); + close(fd); + return; + } + + // 映射共享内存 + void* ptr = mmap(nullptr, shm_max_size, PROT_WRITE, MAP_SHARED, fd, 0); + if (ptr == MAP_FAILED) { + perror("mmap failed"); + close(fd); + return; + } + + // 写入图像数据 + uint32_t size = static_cast(img_size); + std::memcpy(ptr, &size, 4); // 前4字节写入长度 + std::memcpy(static_cast(ptr) + 4, buf.data(), img_size); + + // 关闭映射和文件描述符 + munmap(ptr, shm_max_size); + close(fd); + } + void drawDebugOverlayShow(const AutoGuidanceDebug& dbg, bool auto_fps) { + static auto last_show_time = std::chrono::steady_clock::now(); + + if (dbg.img_frame.src_img.empty()) + return; + cv::Mat src_img = dbg.img_frame.src_img; + auto now = std::chrono::steady_clock::now(); + const double min_interval_ms = 1000.0 / 30.0; + if (std::chrono::duration(now - last_show_time).count() + < min_interval_ms + && auto_fps) + return; + last_show_time = now; + + // 图像构造 + cv::Mat debug_img; + src_img.convertTo(debug_img, -1, 1, 0); + cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB); + if (debug_img.empty()) + return; + + // 封装后的绘图函数 + drawAutoGuidanceDebugContent(debug_img, dbg); + + cv::imshow("debug_armor", debug_img); + cv::waitKey(1); + } + void debuglog(const GuidanceTarget& target) { + static bool first_log = true; + static std::chrono::steady_clock::time_point start_time; + static DebugLogs log; + if (first_log) { + start_time = std::chrono::steady_clock::now(); + first_log = false; + } + auto now = std::chrono::steady_clock::now(); + double t = std::chrono::duration(now - start_time).count(); + log.time_log.push_back(t); + double cx_norm = target.center().x / target.image_size_.width * 2.0 - 1.0; + log.cx_log.push_back(cx_norm); + auto trim = [](std::vector& v) { + if (v.size() > 1000) + v.erase(v.begin()); + }; + + trim(log.time_log); + trim(log.cx_log); + nlohmann::json j; + { + j["time"] = log.time_log; + j["yaw"] = log.cx_log; + } + std::ofstream file("/dev/shm/cmd_log.json"); + if (file.is_open()) { + file << j.dump(); + } + } +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/debug.hpp b/wust_vision-main/tasks/auto_guidance/debug.hpp new file mode 100644 index 0000000..0b82293 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/debug.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "tasks/auto_guidance/guidance_tracker/guidance_target.hpp" +#include "wust_vl/video/icamera.hpp" +namespace wust_vision { +namespace auto_guidance { + + struct AutoGuidanceDebug { + wust_vl::video::ImageFrame img_frame; + double latency_ms; + GuidanceTarget target; + GreenLights lights; + cv::Mat mask; + }; + struct DebugLogs { + std::vector time_log; + std::vector cx_log; + }; + void debuglog(const GuidanceTarget& target); + void drawDebugOverlayShm(const AutoGuidanceDebug& dbg, bool auto_fps); + void drawDebugOverlayWrite(const AutoGuidanceDebug& dbg, bool auto_fps); + void drawDebugOverlayShow(const AutoGuidanceDebug& dbg, bool auto_fps); + void drawAutoGuidanceDebugContent(cv::Mat& debug_img, const AutoGuidanceDebug& dbg); +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_base.hpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_base.hpp new file mode 100644 index 0000000..29316c7 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_base.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "tasks/auto_guidance/type.hpp" +#include "tasks/type_common.hpp" +namespace wust_vision { +namespace auto_guidance { + + class detector_base { + public: + virtual ~detector_base() = default; + virtual void pushInput(CommonFrame& frame) = 0; + + using DetectorCallback = + std::function&, const CommonFrame&)>; + + virtual void setCallback(DetectorCallback cb) = 0; + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_factory.hpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_factory.hpp new file mode 100644 index 0000000..f045352 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/detector_factory.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.hpp" +#ifdef USE_OPENVINO + #include "openvino/guidance_detector_openvino.hpp" + +#endif +namespace wust_vision { +namespace auto_guidance { + class DetectorFactory { + public: + static std::unique_ptr + createDetector(const std::string& backend, const YAML::Node& config, bool debug) { +#if defined(USE_OPENVINO) + if (backend == "openvino") { + return std::make_unique(config); + } +#endif + + if (backend == "opencv") { + return std::make_unique(config, debug); + } + + throw std::runtime_error("Unsupported detector backend (or not compiled): " + backend); + } + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.cpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.cpp new file mode 100644 index 0000000..64b0001 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.cpp @@ -0,0 +1,100 @@ +#include "green_light_infer.hpp" +namespace wust_vision { +namespace auto_guidance { + GreenLightInfer::GreenLightInfer(const Params& params) { + params_ = params; + } + std::vector GreenLightInfer::postProcess( + const cv::Mat& output_buffer, + const Eigen::Matrix& transform_matrix + ) { + std::vector Lights; + std::vector confidences; + std::vector boxes; + + const int num_boxes = output_buffer.rows; + const int attr = output_buffer.cols; + + for (int i = 0; i < num_boxes; ++i) { + float confidence = output_buffer.at(i, 4); + if (confidence < params_.conf_threshold) + continue; + + cv::Mat class_scores = output_buffer.row(i).colRange(5, 5 + 9); + cv::Mat color_scores = output_buffer.row(i).colRange(5 + 9, 5 + 9 + 4); + + double maxClassConfidence; + cv::Point classIdPoint; + cv::minMaxLoc(class_scores, nullptr, &maxClassConfidence, nullptr, &classIdPoint); + if (maxClassConfidence < params_.conf_threshold) + continue; + + if (classIdPoint.x != 8) + continue; + + float cx = output_buffer.at(i, 0); + float cy = output_buffer.at(i, 1); + float w = output_buffer.at(i, 2); + float h = output_buffer.at(i, 3); + + // === coordinate transform === + Eigen::Vector3f pt(cx, cy, 1.0f); + Eigen::Vector3f pt_trans = transform_matrix * pt; + + float cx_t = pt_trans(0); + float cy_t = pt_trans(1); + + // compute scale for bbox + float scale_x = std::sqrt(transform_matrix.row(0).head<2>().squaredNorm()); + float scale_y = std::sqrt(transform_matrix.row(1).head<2>().squaredNorm()); + + float w_t = w * scale_x; + float h_t = h * scale_y; + + cv::Rect2d bbox(cx_t - w_t / 2.0f, cy_t - h_t / 2.0f, w_t, h_t); + + GreenLight light; + light.id = classIdPoint.x; + light.score = confidence; + light.center_point = cv::Point2f(cx_t, cy_t); + light.box = bbox; + + Lights.emplace_back(light); + confidences.emplace_back(confidence); + boxes.emplace_back(bbox); + } + + std::vector nms_result; + cv::dnn::NMSBoxes( + boxes, + confidences, + params_.conf_threshold, + params_.nms_threshold, + nms_result, + 0.5f, + params_.top_k + ); + auto IoU = [](const cv::Rect2d& a, const cv::Rect2d& b) { + double inter = (a & b).area(); + double uni = a.area() + b.area() - inter; + return inter / uni; + }; + + std::vector final_result; + for (int i = 0; i < nms_result.size(); i++) { + bool keep = true; + for (int j = 0; j < final_result.size(); j++) { + if (IoU(final_result[j].box, Lights[nms_result[i]].box) > 0.3) { + keep = false; + break; + } + } + if (keep) + final_result.push_back(Lights[nms_result[i]]); + } + + return final_result; + } + +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.hpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.hpp new file mode 100644 index 0000000..a90c248 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/green_light_infer.hpp @@ -0,0 +1,35 @@ +#pragma once +#include "tasks/auto_guidance/type.hpp" +namespace wust_vision { +namespace auto_guidance { + class GreenLightInfer { + public: + using GreenLightInferPtr = std::unique_ptr; + struct Params { + int input_w; + int input_h; + float conf_threshold; + float nms_threshold; + int top_k; + bool use_norm; + } params_; + GreenLightInfer(const Params& params); + static inline GreenLightInferPtr makeGreenLightInfer(const Params& params) { + return std::make_unique(params); + } + std::vector postProcess( + const cv::Mat& output_buffer, + const Eigen::Matrix& transform_matrix + ); + int getInputW() { + return params_.input_w; + } + int getInputH() { + return params_.input_h; + } + bool getUseNorm() { + return params_.use_norm; + } + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.cpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.cpp new file mode 100644 index 0000000..8f46d52 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.cpp @@ -0,0 +1,162 @@ +#include "guidance_detector_opencv.hpp" +bool initializing = true; + +int lowH = 35, highH = 85; +int lowS = 50, highS = 255; +int lowV = 80, highV = 255; + +static void onTrackbar(int, void*) { + if (initializing) + return; // 初始化阶段不更新 + + lowH = cv::getTrackbarPos("LowH", "mask"); + highH = cv::getTrackbarPos("HighH", "mask"); + lowS = cv::getTrackbarPos("LowS", "mask"); + highS = cv::getTrackbarPos("HighS", "mask"); + lowV = cv::getTrackbarPos("LowV", "mask"); + highV = cv::getTrackbarPos("HighV", "mask"); +} + +void initGUI() { + cv::namedWindow("mask", cv::WINDOW_NORMAL); // 允许调整大小 + cv::resizeWindow("mask", 400, 600); // 设置初始窗口大小 + + cv::createTrackbar("LowH", "mask", nullptr, 179, onTrackbar); + cv::createTrackbar("HighH", "mask", nullptr, 179, onTrackbar); + cv::createTrackbar("LowS", "mask", nullptr, 255, onTrackbar); + cv::createTrackbar("HighS", "mask", nullptr, 255, onTrackbar); + cv::createTrackbar("LowV", "mask", nullptr, 255, onTrackbar); + cv::createTrackbar("HighV", "mask", nullptr, 255, onTrackbar); + + cv::setTrackbarPos("LowH", "mask", lowH); + cv::setTrackbarPos("HighH", "mask", highH); + cv::setTrackbarPos("LowS", "mask", lowS); + cv::setTrackbarPos("HighS", "mask", highS); + cv::setTrackbarPos("LowV", "mask", lowV); + cv::setTrackbarPos("HighV", "mask", highV); + + initializing = false; +} +namespace wust_vision { +namespace auto_guidance { + struct GuidanceDetectorOpenCV::Impl { + public: + Impl(const YAML::Node& config_gobal, bool debug) { + debug_ = debug; + const auto config = config_gobal["opencv"]; + lowH = config["HSV"]["lowH"].as(); + highH = config["HSV"]["highH"].as(); + lowS = config["HSV"]["lowS"].as(); + highS = config["HSV"]["highS"].as(); + highV = config["HSV"]["highV"].as(); + lowV = config["HSV"]["lowV"].as(); + max_area_ = config["contours"]["max_area"].as(); + min_area_ = config["contours"]["min_area"].as(); + min_aspect_ratio = config["contours"]["min_aspect_ratio"].as(); + min_fill_ratio_ = config["contours"]["min_fill_ratio"].as(); + use_gui_ = config["gui"].as(); + if (debug_ && use_gui_) { + initGUI(); + } + } + + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + + void processCallback(const CommonFrame& frame) { + std::vector lights; + cv::Mat img = frame.img_frame.src_img.clone(); + + cv::Mat hsv; + cv::cvtColor(img, hsv, cv::COLOR_BGR2HSV); + + cv::Scalar lower_green(lowH, lowS, lowV); + cv::Scalar upper_green(highH, highS, highV); + + cv::Mat mask; + cv::inRange(hsv, lower_green, upper_green, mask); + + cv::Mat kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(5, 5)); + cv::morphologyEx(mask, mask, cv::MORPH_OPEN, kernel); + cv::morphologyEx(mask, mask, cv::MORPH_CLOSE, kernel); + + std::vector> contours; + cv::findContours(mask, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + + std::vector valid_indices; + + for (size_t i = 0; i < contours.size(); i++) { + const double area = cv::contourArea(contours[i]); + const double perimeter = cv::arcLength(contours[i], true); + if (perimeter == 0) + continue; + + const double circularity = 4 * CV_PI * area / (perimeter * perimeter); + + const cv::RotatedRect rRect = cv::minAreaRect(contours[i]); + const double width = rRect.size.width; + const double height = rRect.size.height; + + if (width <= 0 || height <= 0) + continue; + + const double rect_area = width * height; + const double fill_ratio = area / rect_area; + const double aspect_ratio = std::min(width, height) / std::max(width, height); + + if (area > min_area_ && area < max_area_ && fill_ratio > min_fill_ratio_ + && aspect_ratio > min_aspect_ratio) + { + cv::Point2f center; + float radius; + cv::minEnclosingCircle(contours[i], center, radius); + GreenLight gl; + gl.center_point = center; + gl.box = cv::boundingRect(contours[i]); + gl.score = circularity; + gl.radius = radius; + lights.push_back(gl); + } + } + static auto last = std::chrono::steady_clock::now(); + const auto now = std::chrono::steady_clock::now(); + const double dt = std::chrono::duration(now - last).count(); + + if (debug_ && dt > 33.3 && use_gui_) { // 30Hz 刷新 + cv::imshow("mask", mask); + cv::waitKey(1); // 非阻塞 + last = now; + } + + if (infer_callback_) { + infer_callback_(lights, frame); + } + } + void pushInput(CommonFrame& frame) { + frame.id = current_id_++; + processCallback(frame); + } + DetectorCallback infer_callback_; + int current_id_ = 0; + double min_area_ = 100; + double max_area_ = 10000; + double min_fill_ratio_ = 0.5; + double min_aspect_ratio = 0.7; + bool debug_ = false; + bool use_gui_ = false; + }; + GuidanceDetectorOpenCV::GuidanceDetectorOpenCV(const YAML::Node& config_gobal, bool debug) { + _impl = std::make_unique(config_gobal, debug); + } + GuidanceDetectorOpenCV::~GuidanceDetectorOpenCV() { + _impl.reset(); + } + void GuidanceDetectorOpenCV::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); + } + void GuidanceDetectorOpenCV::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.hpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.hpp new file mode 100644 index 0000000..db9a46c --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "tasks/auto_guidance/guidance_detector/detector_base.hpp" +namespace wust_vision { +namespace auto_guidance { + class GuidanceDetectorOpenCV: public detector_base { + public: + GuidanceDetectorOpenCV(const YAML::Node& config, bool debug); + ~GuidanceDetectorOpenCV(); + void pushInput(CommonFrame& frame) override; + + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.cpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.cpp new file mode 100644 index 0000000..d27c2cf --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.cpp @@ -0,0 +1,116 @@ +#ifdef USE_OPENVINO + + #include "guidance_detector_openvino.hpp" + #include "tasks/auto_guidance/guidance_detector/green_light_infer.hpp" + #include "tasks/utils/utils.hpp" + #include "wust_vl/ml_net/openvino/openvino_net.hpp" +namespace wust_vision { +namespace auto_guidance { + struct GuidanceDetectorOpenVino::Impl { + public: + Impl(const YAML::Node& config_gobal) { + auto config = config_gobal["openvino"]; + std::string model_path = utils::expandEnv(config["model_path"].as()); + std::string device_name = config["device_type"].as(); + int top_k = config["top_k"].as(); + float nms_threshold = config["nms_threshold"].as(); + float conf_threshold = config["conf_threshold"].as(); + green_light_infer_ = GreenLightInfer::makeGreenLightInfer(GreenLightInfer::Params { + .input_w = 640, + .input_h = 384, + .conf_threshold = conf_threshold, + .nms_threshold = nms_threshold, + .top_k = top_k, + .use_norm = true }); + openvino_net_ = std::make_unique(); + auto ppp_init_fun = [this](ov::preprocess::PrePostProcessor& ppp) { + ppp.input() + .tensor() + .set_element_type(ov::element::u8) + .set_layout("NHWC") + .set_color_format(ov::preprocess::ColorFormat::BGR); + + ppp.input() + .preprocess() + .convert_element_type(ov::element::f32) + .convert_color(ov::preprocess::ColorFormat::RGB) + .scale(255.f); + + ppp.input().model().set_layout("NCHW"); + + ppp.output(0).tensor().set_element_type(ov::element::f32); + + ppp.output(1).tensor().set_element_type(ov::element::f32); + + ppp.output(2).tensor().set_element_type(ov::element::f32); + }; + wust_vl::ml_net::OpenvinoNet::Params params; + params.model_path = model_path; + params.device_name = device_name; + params.mode = config["use_throughputmode"].as() + ? ov::hint::PerformanceMode::THROUGHPUT + : ov::hint::PerformanceMode::LATENCY; + openvino_net_->init(params, ppp_init_fun); + } + ~Impl() { + openvino_net_.reset(); + green_light_infer_.reset(); + } + + void setCallback(DetectorCallback callback) { + infer_callback_ = callback; + } + + void processCallback(const CommonFrame& frame) const { + Eigen::Matrix3f transform_matrix; + const cv::Mat resized_img = utils::letterbox( + frame.img_frame.src_img, + transform_matrix, + green_light_infer_->getInputW(), + green_light_infer_->getInputH() + ); + + const auto input_info = openvino_net_->getInputInfo(); + const auto input_tensor = + ov::Tensor(input_info.first, input_info.second, resized_img.data); + + auto infer_request = openvino_net_->createInferRequest(); + infer_request.set_input_tensor(input_tensor); + infer_request.infer(); + const auto output = infer_request.get_output_tensor(0); + + // Process output data + const auto output_shape = output.get_shape(); + const float* ptr = output.data(); + cv::Mat + output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast(ptr)); + const auto objs_result = + green_light_infer_->postProcess(output_buffer, transform_matrix); + if (infer_callback_) { + infer_callback_(objs_result, frame); + } + } + void pushInput(CommonFrame& frame) { + frame.id = current_id_++; + processCallback(frame); + } + std::unique_ptr openvino_net_; + std::unique_ptr green_light_infer_; + DetectorCallback infer_callback_; + int current_id_ = 0; + }; + GuidanceDetectorOpenVino::GuidanceDetectorOpenVino(const YAML::Node& config_gobal) { + _impl = std::make_unique(config_gobal); + } + GuidanceDetectorOpenVino::~GuidanceDetectorOpenVino() { + _impl.reset(); + } + void GuidanceDetectorOpenVino::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); + } + void GuidanceDetectorOpenVino::setCallback(DetectorCallback callback) { + _impl->setCallback(callback); + } +} // namespace auto_guidance +} // namespace wust_vision +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.hpp b/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.hpp new file mode 100644 index 0000000..081d8a8 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_detector/openvino/guidance_detector_openvino.hpp @@ -0,0 +1,19 @@ +#pragma once +#include "tasks/auto_guidance/guidance_detector/detector_base.hpp" + +namespace wust_vision { +namespace auto_guidance { + class GuidanceDetectorOpenVino: public detector_base { + public: + GuidanceDetectorOpenVino(const YAML::Node& config); + ~GuidanceDetectorOpenVino(); + void pushInput(CommonFrame& frame) override; + + void setCallback(DetectorCallback callback) override; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.cpp b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.cpp new file mode 100644 index 0000000..e637950 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.cpp @@ -0,0 +1,143 @@ +#include "guidance_target.hpp" +namespace wust_vision { +namespace auto_guidance { + GuidanceTarget::GuidanceTarget() { + target_state_ = Eigen::VectorXd::Zero(imgbox_model::X_N); + } + GuidanceTarget::GuidanceTarget(const GreenLight& light, TargetConfig target_config) { + target_config_ = target_config; + auto yfv2 = imgbox_model::Predict(0.01); + auto yhv2 = imgbox_model::Measure(); + auto yu_qv2 = [this]() { return computeProcessNoise(0.01); }; + + auto yu_rv2 = [this](const Eigen::Matrix& z) { + return this->computeMeasurementCovariance(z); + }; + Eigen::DiagonalMatrix p0; + p0.diagonal() << 1000, 1000, 1000, 1000, 64000, 64000, 64000, 64000; + esekf_ = imgbox_model::BBox8ESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0); + + esekf_.setResidualFunc([this]( + const Eigen::Matrix& z_pred, + const Eigen::Matrix& z + ) { + Eigen::Matrix r = z - z_pred; + return r; + }); + esekf_.setIterationNum(target_config_.iter_num); + esekf_.setInjectFunc([this]( + const Eigen::Matrix& delta, + Eigen::Matrix& nominal + ) { + for (int i = 0; i < imgbox_model::X_N; i++) { + nominal[i] += delta[i]; + } + }); + + double cx = light.center_point.x; + double cy = light.center_point.y; + double w = light.box.width; + double h = light.box.height; + target_state_ << cx, 0, cy, 0, w, 0, h, 0; + esekf_.setState(target_state_); + last_t_ = light.timestamp; + position_ = light.position; + timestamp_ = light.timestamp; + image_size_ = light.image_size; + is_inited_ = true; + } + Eigen::Matrix + GuidanceTarget::computeMeasurementCovariance( + const Eigen::Matrix& z + ) const { + Eigen::Matrix r; + // clang-format off + r < + GuidanceTarget::computeProcessNoise(double dt) const { + Eigen::Matrix q; + + double t = dt; + double q_pp = pow(t, 4) / 4.0 * target_config_.q_xy; + double q_pv = pow(t, 3) / 2.0 * target_config_.q_xy; + double q_vv = pow(t, 2) * target_config_.q_xy; + + double q_ss = pow(t, 4) / 4.0 * target_config_.q_wh; + double q_sv = pow(t, 3) / 2.0 * target_config_.q_wh; + double q_vvs = pow(t, 2) * target_config_.q_wh; + + // clang-format off + // cx vx cy vy w vw h vh + q << q_pp, q_pv, 0, 0, 0, 0, 0, 0, + q_pv, q_vv, 0, 0, 0, 0, 0, 0, + 0, 0, q_pp, q_pv, 0, 0, 0, 0, + 0, 0, q_pv, q_vv, 0, 0, 0, 0, + 0, 0, 0, 0, q_ss, q_sv, 0, 0, + 0, 0, 0, 0, q_sv, q_vvs, 0, 0, + 0, 0, 0, 0, 0, 0, q_ss, q_sv, + 0, 0, 0, 0, 0, 0, q_sv, q_vvs; + + // clang-format on + return q; + } + void GuidanceTarget::predict(std::chrono::steady_clock::time_point t) { + double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t); + + predict(dt); + + last_t_ = t; + } + void GuidanceTarget::predict(double dt) { + dt_ = dt; + + esekf_.setPredictFunc(imgbox_model::Predict { dt }); + auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); }; + + esekf_.setUpdateQ(yu_qv2); + + target_state_ = esekf_.predict(); + } + + bool GuidanceTarget::update(const GreenLights& lights) { + auto ls = lights.lights; + timestamp_ = lights.timestamp; + + auto yu_rv2 = [this](const Eigen::Matrix& z) { + return this->computeMeasurementCovariance(z); + }; + esekf_.setUpdateR(yu_rv2); + int best_id = -1; + double min_error = std::numeric_limits::max(); + for (int i = 0; i < ls.size(); i++) { + double centor_error = cv::norm(ls[i].center_point - center()); + double pos_error = (ls[i].position - position_).norm(); + if (centor_error < min_error && pos_error < target_config_.max_dis_diff) { + min_error = centor_error; + best_id = i; + } + } + if (best_id == -1) { + return false; + } + measurement_ = Eigen::Vector4d( + ls[best_id].center_point.x, + ls[best_id].center_point.y, + ls[best_id].box.width, + ls[best_id].box.height + ); + esekf_.setMeasureFunc(imgbox_model::Measure()); + target_state_ = esekf_.update(measurement_); + position_ = ls[best_id].position; + image_size_ = ls[best_id].image_size; + last_t_ = timestamp_; + return true; + } + +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.hpp b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.hpp new file mode 100644 index 0000000..f6a5542 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_target.hpp @@ -0,0 +1,89 @@ +#pragma once +#include "motion_models/imgbox_model.hpp" +#include "tasks/auto_guidance/type.hpp" +#include +#include +namespace wust_vision { +namespace auto_guidance { + struct TargetConfig { + void load(const YAML::Node& config) { + xy_r = config["xy_r"].as(); + wh_r = config["wh_r"].as(); + q_xy = config["q_xy"].as(); + q_wh = config["q_wh"].as(); + iter_num = config["iter_num"].as(); + max_dis_diff = config["max_dis_diff"].as(); + } + double xy_r = 0.05; + double wh_r = 0.05; + double q_xy = 10; + double q_wh = 10; + int iter_num = 2; + double max_dis_diff = 2.0; + }; + class GuidanceTarget { + public: + GuidanceTarget(); + GuidanceTarget(const GreenLight& light, TargetConfig target_config); + GuidanceTarget& operator=(const GuidanceTarget&) = default; + bool update(const GreenLights& lights); + + void predict(std::chrono::steady_clock::time_point t); + void predict(double dt); + Eigen::Matrix + computeMeasurementCovariance(const Eigen::Matrix& z) const; + Eigen::Matrix computeProcessNoise(double dt + ) const; + std::chrono::steady_clock::time_point last_t_; + std::chrono::steady_clock::time_point timestamp_; + double dt_; + cv::Size2d image_size_; + imgbox_model::BBox8ESEKF esekf_; + Eigen::Matrix measurement_ = + Eigen::Matrix::Zero(); + Eigen::Matrix target_state_ = + Eigen::Matrix::Zero(); + Eigen::Vector3d position_; + TargetConfig target_config_; + bool is_inited_ = false; + bool is_tracking_ = false; + bool checkappear() { + return is_tracking_ + && wust_vl::common::utils::time_utils::durationSec( + timestamp_, + wust_vl::common::utils::time_utils::now() + ) + < 3.0; + } + cv::Point2d center() const { + return cv::Point2d(target_state_(0), target_state_(2)); + } + cv::Rect2d box() const { + return cv::Rect2d( + target_state_(0) - target_state_(4) / 2, + target_state_(2) - target_state_(6) / 2, + target_state_(4), + target_state_(6) + ); + } + void draw(cv::Mat& img) const { + cv::rectangle(img, box(), cv::Scalar(255, 50, 0), 2); + cv::circle(img, center(), 3, cv::Scalar(255, 255, 255), -1); + cv::line( + img, + cv::Point(center().x, center().y), + cv::Point(img.cols / 2.0, center().y), + cv::Scalar(0, 0, 255), + 2 + ); + cv::line( + img, + cv::Point2f(img.cols / 2.0, 0), + cv::Point2f(img.cols / 2.0, img.rows), + cv::Scalar(255, 255, 255), + 2 + ); + } + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.cpp b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.cpp new file mode 100644 index 0000000..2b8dccc --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.cpp @@ -0,0 +1,109 @@ +#include "guidance_tracker.hpp" +namespace wust_vision { +namespace auto_guidance { + struct GuidanceTracker::Impl { + public: + Impl(const YAML::Node& config) { + target_config_.load(config["target"]); + tracking_thres_ = config["tracking_thres"].as(5); + lost_dt_ = config["lost_time_thres"].as(); + } + + GuidanceTarget track(const GreenLights& lights) { + double dt = std::chrono::duration(lights.timestamp - last_time_).count(); + last_time_ = lights.timestamp; + lost_thres_ = std::abs(static_cast(lost_dt_ / dt)); + bool found; + if (tracker_state == LOST) { + found = initTarget(lights); + } else { + found = updateTarget(lights); + } + updateFsm(found); + + return guidance_target_; + } + void updateFsm(bool found) { + if (tracker_state == DETECTING) { + if (found) { + detect_count_++; + if (detect_count_ > tracking_thres_) { + detect_count_ = 0; + tracker_state = TRACKING; + } + } else { + detect_count_ = 0; + tracker_state = LOST; + } + } else if (tracker_state == TRACKING) { + if (!found) { + tracker_state = TEMP_LOST; + lost_count_++; + } + } else if (tracker_state == TEMP_LOST) { + if (!found) { + lost_count_++; + if (lost_count_ > lost_thres_) { + lost_count_ = 0; + tracker_state = LOST; + } + } else { + tracker_state = TRACKING; + lost_count_ = 0; + } + } + if (tracker_state == LOST || tracker_state == DETECTING) { + guidance_target_.is_tracking_ = false; + } else { + guidance_target_.is_tracking_ = true; + } + } + bool initTarget(const GreenLights& lights) { + int best_id = -1; + double max_score = -1e9; + for (int i = 0; i < lights.lights.size(); i++) { + if (lights.lights[i].score > max_score) { + max_score = lights.lights[i].score; + best_id = i; + } + } + if (best_id == -1) { + return false; + } + tracker_state = DETECTING; + guidance_target_ = GuidanceTarget(lights.lights[best_id], target_config_); + return true; + } + bool updateTarget(const GreenLights& lights) { + guidance_target_.predict(lights.timestamp); + return guidance_target_.update(lights); + } + enum State { + LOST, + DETECTING, + TRACKING, + TEMP_LOST, + } tracker_state = LOST; + GuidanceTarget guidance_target_; + int tracking_thres_; + int lost_thres_; + int detect_count_ = 0; + int lost_count_ = 0; + double lost_dt_; + std::chrono::steady_clock::time_point last_time_; + TargetConfig target_config_; + }; + GuidanceTracker::GuidanceTracker(const YAML::Node& config) { + _impl = std::make_unique(config); + } + GuidanceTracker::~GuidanceTracker() { + _impl.reset(); + } + GuidanceTarget GuidanceTracker::track(const GreenLights& lights) { + return _impl->track(lights); + } + std::chrono::steady_clock::time_point GuidanceTracker::getLastTime() const { + return _impl->last_time_; + } +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.hpp b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.hpp new file mode 100644 index 0000000..18144d5 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_tracker/guidance_tracker.hpp @@ -0,0 +1,21 @@ +#pragma once +#include "tasks/auto_guidance/guidance_tracker/guidance_target.hpp" +namespace wust_vision { +namespace auto_guidance { + class GuidanceTracker { + public: + using Ptr = std::unique_ptr; + GuidanceTracker(const YAML::Node& config); + ~GuidanceTracker(); + static inline Ptr create(const YAML::Node& config) { + return std::make_unique(config); + } + GuidanceTarget track(const GreenLights& lights); + std::chrono::steady_clock::time_point getLastTime() const; + + private: + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_guidance/guidance_tracker/motion_models/imgbox_model.hpp b/wust_vision-main/tasks/auto_guidance/guidance_tracker/motion_models/imgbox_model.hpp new file mode 100644 index 0000000..9f4bdbb --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/guidance_tracker/motion_models/imgbox_model.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "KalmanHyLib/kalman_hybird_lib.hpp" +#include + +namespace imgbox_model { + +static constexpr int X_N = 8; // cx vx cy vy w vw h vh +static constexpr int Z_N = 4; // measured cx cy w h + +// ========================== Predict Model ========================== +struct Predict { + Predict() = default; + explicit Predict(double dt): dt_(dt) {} + + template + void operator()(const T x0[X_N], T x1[X_N]) { + for (int i = 0; i < X_N; i++) { + x1[i] = x0[i]; + } + x1[0] += x0[1] * dt_; // cx + x1[2] += x0[3] * dt_; // cy + x1[4] += x0[5] * dt_; // w + x1[6] += x0[7] * dt_; // h + } + + double dt_; +}; + +// ========================== Measurement Model ========================== +struct Measure { + Measure() = default; + template + void operator()(const T x[X_N], T z[Z_N]) const { + z[0] = x[0]; // cx + z[1] = x[2]; // cy + z[2] = x[4]; // w + z[3] = x[6]; // h + } +}; + +using BBox8EKF = kalman_hybird_lib::ExtendedKalmanFilter; +using BBox8ESEKF = kalman_hybird_lib::ErrorStateEKF; + +} // namespace imgbox_model diff --git a/wust_vision-main/tasks/auto_guidance/type.hpp b/wust_vision-main/tasks/auto_guidance/type.hpp new file mode 100644 index 0000000..037dd72 --- /dev/null +++ b/wust_vision-main/tasks/auto_guidance/type.hpp @@ -0,0 +1,107 @@ +#pragma once +#include "Eigen/Dense" +#include "opencv2/opencv.hpp" +namespace wust_vision { +namespace auto_guidance { + struct GreenLight { + int id = -1; + double score = 0.; + cv::Rect2d box; // bounding box in pixel coordinates + cv::Point2d center_point; // center in pixel coordinates + double radius; + Eigen::Vector3d position; + std::chrono::steady_clock::time_point timestamp; + cv::Size2d image_size; + // PnP 估计位姿 + bool solvePnP(const cv::Mat& K, const cv::Mat& distCoeffs) { + constexpr float half_w = 0.07; + // 真实世界点,单位米 + std::vector objectPoints = { { -half_w, -half_w, 0.f }, + { half_w, -half_w, 0.f }, + { half_w, half_w, 0.f }, + { -half_w, half_w, 0.f } }; + + // 像素点 + std::vector imagePoints = { + cv::Point2f(box.x, box.y), + cv::Point2f(box.x + box.width, box.y), + cv::Point2f(box.x + box.width, box.y + box.height), + cv::Point2f(box.x, box.y + box.height) + }; + + cv::Mat rvec, tvec; + + bool ok = cv::solvePnP( + objectPoints, + imagePoints, + K, + distCoeffs, + rvec, + tvec, + false, + cv::SOLVEPNP_ITERATIVE + ); + + if (!ok) + return false; + + // 转换到 Eigen 向量 + position = Eigen::Vector3d(tvec.at(0), tvec.at(1), tvec.at(2)); + + return true; + } + }; + + struct GreenLights { + public: + std::vector lights; + std::chrono::steady_clock::time_point timestamp; + int id; + void drawFront(cv::Mat& img) const { + for (const auto& light: lights) { + cv::rectangle(img, light.box, cv::Scalar(0, 255, 255), 2); + cv::circle(img, light.center_point, light.radius, cv::Scalar(0, 255, 0), 2); + cv::circle(img, light.center_point, 3, cv::Scalar(255, 0, 0), -1); + + cv::putText( + img, + std::to_string(light.score), + light.center_point + + cv::Point2d(light.box.width / 2.0, -light.box.height / 2.0), + cv::FONT_HERSHEY_SIMPLEX, + 0.5, + cv::Scalar(255, 0, 0), + 2 + ); + cv::putText( + img, + std::to_string(light.position.norm()), + light.center_point + cv::Point2d(light.box.width / 2.0, light.box.height / 2.0), + cv::FONT_HERSHEY_SIMPLEX, + 0.5, + cv::Scalar(255, 255, 255), + 2 + ); + } + } + void drawBack(cv::Mat& img) const { + for (const auto& light: lights) { + cv::line( + img, + cv::Point(light.center_point.x, light.center_point.y), + cv::Point(img.cols / 2.0, light.center_point.y), + cv::Scalar(0, 0, 255), + 2 + ); + } + cv::line( + img, + cv::Point2f(img.cols / 2.0, 0), + cv::Point2f(img.cols / 2.0, img.rows), + cv::Scalar(255, 255, 255), + 2 + ); + } + }; +} // namespace auto_guidance +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_sniper/auto_sniper.cpp b/wust_vision-main/tasks/auto_sniper/auto_sniper.cpp new file mode 100644 index 0000000..c538441 --- /dev/null +++ b/wust_vision-main/tasks/auto_sniper/auto_sniper.cpp @@ -0,0 +1,344 @@ +#include "3rdparty/angles.h" +//#ifdef USE_ROS2 +#ifdef FUCK + #include "auto_sniper.hpp" + #include "ros2/tf.hpp" + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + + #include + #include + #include + #include + + #include "k1_solver.hpp" + #include "offset_helper.hpp" + #include "tasks/type_common.hpp" + #include "tasks/utils/config.hpp" + #include "voxel_map.hpp" + #include + #include + #include + #include +namespace wust_vision::auto_sniper { + +struct AutoSniper::Impl { + Impl( + rclcpp::Node& node, + std::shared_ptr> motion_buffer + ) { + node_ = &node; + motion_buffer_ = motion_buffer; + tf_ = TF::create(*node_); + auto config = YAML::LoadFile(AUTO_SNIPER_CONFIG); + auto map_config = config["map"]; + auto min_pos_v = map_config["min_pos"].as>(); + auto min_pos = Eigen::Vector3d(min_pos_v[0], min_pos_v[1], min_pos_v[2]); + auto max_pos_v = map_config["max_pos"].as>(); + auto max_pos = Eigen::Vector3d(max_pos_v[0], max_pos_v[1], max_pos_v[2]); + voxel_map_ = std::make_shared>( + map_config["voxel_size"].as(), + min_pos, + max_pos, + true + ); + auto solver_config = config["solver"]; + vis_cloud_ = std::make_shared(); + k1_solver_ = K1BallisticSolver::create( + solver_config["k1"].as(), + solver_config["g"].as() + ); + target_armor_z_ = solver_config["target_armor_z"].as(); + offset_helper_ = OffsetHelper::create(config["offset_helper"]); + pointcloud_sub_ = node_->create_subscription( + "/cloud_registered", + rclcpp::SensorDataQoS(), + std::bind(&AutoSniper::Impl::pointCloudCallback, this, std::placeholders::_1) + ); + + odometry_sub_ = node_->create_subscription( + "/Odometry", + 10, + std::bind(&AutoSniper::Impl::odomCallback, this, std::placeholders::_1) + ); + traj_pub_ = + node_->create_publisher("bullet_trajectory", 10); + } + void start() { + if (run_flag_) { + return; + } + run_flag_ = true; + vis_thread_ = wust_vl::common::concurrency::MonitoredThread::create( + "AutoSniperVisualizer", + [this](wust_vl::common::concurrency::MonitoredThread::Ptr self) { + this->visualizeLoop(self); + } + ); + wust_vl::common::concurrency::ThreadManager::instance().registerThread(vis_thread_); + } + void doDebug() {} + void pushInput(CommonFrame& frame) {} + wust_vl::common::concurrency::MonitoredThread::Ptr getThread() { + return vis_thread_; + } + GimbalCmd solve(double bullet_speed) noexcept { + GimbalCmd cmd; + if (!target_pos_in_map_.has_value()) { + cmd.appear = false; + return cmd; + } + Eigen::Vector3d target_pos_in_self = self_in_map_.inverse() * target_pos_in_map_.value(); + target_pos_in_self.z() = target_armor_z_; + auto pitch = k1_solver_->solvePitch(target_pos_in_self, bullet_speed); + if (!pitch.has_value()) { + cmd.appear = false; + std::cout << "no pitch" << std::endl; + return cmd; + } + double yaw = + angles::normalize_angle(std::atan2(target_pos_in_self.y(), target_pos_in_self.x())); + auto traj = k1_solver_->computeTrajectory( + self_in_map_.translation(), + target_pos_in_map_.value(), + bullet_speed, + 0.01 + ); + double yaw_deg = angles::to_degrees(yaw); + double pitch_deg = angles::to_degrees(pitch.value()); + publishTrajectoryMarker(traj); + + auto control_pitch = pitch_deg + offset_helper_->getPitchOffset(target_pos_in_self.norm()); + + auto control_yaw = yaw_deg + offset_helper_->getYawOffset(target_pos_in_self.norm()); + if (auto last_att = motion_buffer_->get_last()) { + control_yaw += angles::to_degrees(last_att->data.yaw); + } + cmd.appear = true; + cmd.target_pitch = control_pitch; + cmd.target_yaw = control_yaw; + cmd.appear = true; + cmd.yaw = control_yaw; + cmd.pitch = control_pitch; + cmd.distance = target_pos_in_self.norm(); + cmd.enable_pitch_diff = 0.5; + cmd.enable_yaw_diff = 0.5; + static int count = 0; + count++; + if (count % 100 == 0) { + std::cout << cmd.yaw << " " << cmd.pitch << std::endl; + } + return cmd; + } + void publishTrajectoryMarker(const std::vector& traj) { + static auto last_pub_time = std::chrono::steady_clock::now(); + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(now - last_pub_time); + if (elapsed.count() < 33) { + return; + } + last_pub_time = now; + + if (traj.empty()) + return; + visualization_msgs::msg::Marker marker; + marker.header.frame_id = target_frame_; + marker.header.stamp = node_->now(); + marker.ns = "bullet"; + marker.id = 0; + marker.type = visualization_msgs::msg::Marker::LINE_STRIP; + marker.action = visualization_msgs::msg::Marker::ADD; + marker.scale.x = 0.1; + marker.color.r = 0.0; + marker.color.g = 1.0; + marker.color.b = 1.0; + marker.color.a = 1.0; + marker.lifetime = rclcpp::Duration::from_seconds(1.0); + for (auto& p: traj) { + geometry_msgs::msg::Point pt; + pt.x = p.x(); + pt.y = p.y(); + pt.z = p.z(); + marker.points.push_back(pt); + } + + traj_pub_->publish(marker); + } + void odomCallback(const nav_msgs::msg::Odometry::SharedPtr msg) { + auto T = tf_->getTransform(target_frame_, "gimbal_yaw", msg->header.stamp); + + if (!T.has_value()) { + return; + } + self_in_map_ = T.value().cast(); + // Eigen::Isometry3f + // Eigen::Vector4f p( + // msg->pose.pose.position.x, + // msg->pose.pose.position.y, + // msg->pose.pose.position.z, + // 1.0f + // ); + + // auto p_target = T.value() * p; + // self_pos_ = Eigen::Vector3d(p_target.x(), p_target.y(), p_target.z()); + } + + void visualizeLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) { + while (!self->isAlive()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + auto vis_cloud = std::make_shared(); + std::atomic picking = false; + auto future = std::async(std::launch::async, [&, self]() { + while (self->isAlive() && run_flag_) { + picking = true; + pickBlocking(vis_cloud); + picking = false; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + while (self->isAlive() && run_flag_) { + self->heartbeat(); + if (!picking) { + std::lock_guard lock(cloud_mutex_); + *vis_cloud = *vis_cloud_; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + } + if (future.valid()) { + future.wait(); + } + } + + void pickBlocking(std::shared_ptr pcd) { + open3d::visualization::VisualizerWithEditing vis; + vis.CreateVisualizerWindow("Pick points (close window when done)", 1280, 720); + vis.AddGeometry(pcd); + vis.Run(); + + auto picked = vis.GetPickedPoints(); + if (!picked.empty()) { + Eigen::Vector3d picked_pos = pcd->points_[picked.back()]; + target_pos_in_map_ = Eigen::Vector3d(picked_pos); + + RCLCPP_INFO_STREAM( + rclcpp::get_logger("awm"), + " Target pos: " << picked_pos.transpose() + << " self pos: " << self_in_map_.translation().transpose() + ); + } + + vis.DestroyVisualizerWindow(); + } + + void pointCloudCallback(const sensor_msgs::msg::PointCloud2::SharedPtr msg) { + auto T = tf_->getTransform(target_frame_, msg->header.frame_id, msg->header.stamp); + + if (!T.has_value()) { + return; + } + + const size_t size = msg->width * msg->height; + + sensor_msgs::PointCloud2ConstIterator iter_x(*msg, "x"); + sensor_msgs::PointCloud2ConstIterator iter_y(*msg, "y"); + sensor_msgs::PointCloud2ConstIterator iter_z(*msg, "z"); + + std::vector new_points; + new_points.reserve(size); + + for (size_t i = 0; i < size; ++i) { + Eigen::Vector4f p(*iter_x, *iter_y, *iter_z, 1.0f); + + p = T.value() * p; + + if (std::isfinite(p.x()) && std::isfinite(p.y()) && std::isfinite(p.z())) { + new_points.emplace_back(p.head<3>().cast()); + } + + ++iter_x; + ++iter_y; + ++iter_z; + } + for (const auto& p: new_points) { + auto idx = voxel_map_->worldToIndex(p); + if (idx > 0) { + voxel_map_->grid[idx].v = 1; + } + } + std::vector pointcloud; + for (int i = 0; i < voxel_map_->grid.size(); i++) { + if (voxel_map_->grid[i].v == 1) { + pointcloud.emplace_back(voxel_map_->indexToWorld(i)); + } + } + { + std::lock_guard lock(cloud_mutex_); + + vis_cloud_->points_ = pointcloud; + } + } + + std::string target_frame_ = "map"; + + rclcpp::Node* node_; + std::optional target_pos_in_map_ = std::nullopt; + rclcpp::Subscription::SharedPtr pointcloud_sub_; + rclcpp::Publisher::SharedPtr traj_pub_; + rclcpp::Subscription::SharedPtr odometry_sub_; + TF::Ptr tf_; + std::shared_ptr vis_cloud_; + struct Cell { + uint8_t v = 0; + }; + SlidingVoxelMap<3, Cell>::Ptr voxel_map_; + K1BallisticSolver::Ptr k1_solver_; + OffsetHelper::Ptr offset_helper_; + wust_vl::common::concurrency::MonitoredThread::Ptr vis_thread_; + bool run_flag_ = false; + std::mutex cloud_mutex_; + double target_armor_z_ = 0.0; + Eigen::Isometry3d self_in_map_ = Eigen::Isometry3d::Identity(); + std::shared_ptr> motion_buffer_; +}; +AutoSniper::AutoSniper( + rclcpp::Node& node, + std::shared_ptr> motion_buffer +) { + _impl = std::make_unique(node, motion_buffer); +} +AutoSniper::~AutoSniper() { + _impl.reset(); +} +void AutoSniper::start() { + _impl->start(); +} +void AutoSniper::doDebug() { + _impl->doDebug(); +} +void AutoSniper::pushInput(CommonFrame& frame) { + _impl->pushInput(frame); +} + +wust_vl::common::concurrency::MonitoredThread::Ptr AutoSniper::getThread() { + return _impl->getThread(); +} +GimbalCmd AutoSniper::solve(double bullet_speed) { + return _impl->solve(bullet_speed); +} +} // namespace wust_vision::auto_sniper +#endif \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_sniper/auto_sniper.hpp b/wust_vision-main/tasks/auto_sniper/auto_sniper.hpp new file mode 100644 index 0000000..a6ad5a2 --- /dev/null +++ b/wust_vision-main/tasks/auto_sniper/auto_sniper.hpp @@ -0,0 +1,32 @@ +#pragma once +#include "tasks/imodule.hpp" +#include "tasks/type_common.hpp" +#include +#include +namespace wust_vision { +namespace auto_sniper { + class AutoSniper: public IModule { + public: + using Ptr = std::shared_ptr; + AutoSniper( + rclcpp::Node& node, + std::shared_ptr> motion_buffer + ); + static Ptr create( + rclcpp::Node& node, + std::shared_ptr> motion_buffer + ) { + return std::make_shared(node, motion_buffer); + } + ~AutoSniper(); + void start() override; + void doDebug() override; + void pushInput(CommonFrame& frame) override; + GimbalCmd solve(double bullet_speed) override; + wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override; + struct Impl; + std::unique_ptr _impl; + }; +} // namespace auto_sniper + +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_sniper/k1_solver.hpp b/wust_vision-main/tasks/auto_sniper/k1_solver.hpp new file mode 100644 index 0000000..978577e --- /dev/null +++ b/wust_vision-main/tasks/auto_sniper/k1_solver.hpp @@ -0,0 +1,106 @@ +#include +#include +#include +#include + +class K1BallisticSolver { +public: + using Ptr = std::unique_ptr; + K1BallisticSolver(double k1 = 0.05, double g = 9.81): k1_(k1), g_(g) {} + static Ptr create(double k1 = 0.05, double g = 9.81) { + return std::make_unique(k1, g); + } + std::optional solvePitch(const Eigen::Vector3d& target_pos, double v0) const { + double x = std::hypot(target_pos.x(), target_pos.y()); + double z = target_pos.z(); + + if (x < 1e-6) + return std::nullopt; + + auto heightError = [&](double pitch) { + double cos_theta = std::cos(pitch); + double sin_theta = std::sin(pitch); + + double denom = v0 * cos_theta; + if (denom <= 1e-6) + return 1e6; + + double t = -std::log(1.0 - k1_ * x / denom) / k1_; + + if (!std::isfinite(t)) + return 1e6; + + double z_pred = + ((v0 * sin_theta + g_ / k1_) / k1_) * (1.0 - std::exp(-k1_ * t)) - (g_ / k1_) * t; + + return z_pred - z; + }; + + double left = -0.3; + double right = 0.6; + + for (int i = 0; i < 60; ++i) { + double mid = 0.5 * (left + right); + double err = heightError(mid); + + if (err > 0) + right = mid; + else + left = mid; + } + + return 0.5 * (left + right); + } + std::vector computeTrajectory( + const Eigen::Vector3d& start, + const Eigen::Vector3d& target, + double v0, + double dt = 0.01 // 每个离散点的时间间隔 + ) { + std::vector traj; + + Eigen::Vector3d diff = target - start; + auto pitch_opt = solvePitch(diff, v0); + if (!pitch_opt.has_value()) + return traj; + + double pitch = pitch_opt.value(); + double yaw = std::atan2(diff.y(), diff.x()); + + double k1 = k1_; + double g = g_; + + double vx = v0 * std::cos(pitch) * std::cos(yaw); + double vy = v0 * std::cos(pitch) * std::sin(yaw); + double vz = v0 * std::sin(pitch); + + double t = 0.0; + Eigen::Vector3d pos = start; + + while (true) { + double exp_kt = std::exp(-k1 * t); + pos.x() = start.x() + (vx / k1) * (1 - exp_kt); + pos.y() = start.y() + (vy / k1) * (1 - exp_kt); + pos.z() = start.z() + ((vz + g / k1) / k1) * (1 - exp_kt) - (g / k1) * t; + + traj.push_back(pos); + + t += dt; + + double dx = std::abs(pos.x() - start.x()); + double dy = std::abs(pos.y() - start.y()); + double horizontal_dist = std::sqrt(dx * dx + dy * dy); + double target_dist = std::sqrt(diff.x() * diff.x() + diff.y() * diff.y()); + + if (horizontal_dist >= target_dist) { + break; + } + } + + return traj; + } + +private: + double k1_; + double g_; +}; \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_sniper/offset_helper.hpp b/wust_vision-main/tasks/auto_sniper/offset_helper.hpp new file mode 100644 index 0000000..eac9d41 --- /dev/null +++ b/wust_vision-main/tasks/auto_sniper/offset_helper.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include + +namespace wust_vision::auto_sniper { + +class OffsetHelper { +public: + using Ptr = std::shared_ptr; + struct OffsetPoint { + double distance; + double yaw; + double pitch; + }; + OffsetHelper(const YAML::Node& config) { + data_.clear(); + for (auto& v: config["offset_table"]) { + OffsetPoint p; + p.distance = v["distance"].as(); + p.yaw = v["yaw"].as(); + p.pitch = v["pitch"].as(); + + data_.push_back(p); + } + order_ = config["order"].as(); + yaw_base_offset = config["yaw_base_offset"].as(); + pitch_base_offset = config["pitch_base_offset"].as(); + fit(); + } + static Ptr create(const YAML::Node& config) { + return std::make_shared(config); + } + void fit() { + int n = data_.size(); + Eigen::MatrixXd A(n, order_ + 1); + + Eigen::VectorXd y_yaw(n); + Eigen::VectorXd y_pitch(n); + + for (int i = 0; i < n; ++i) { + double x = data_[i].distance; + + double v = 1; + for (int j = 0; j <= order_; ++j) { + A(i, j) = v; + v *= x; + } + + y_yaw(i) = data_[i].yaw; + y_pitch(i) = data_[i].pitch; + } + + yaw_coeff_ = A.colPivHouseholderQr().solve(y_yaw); + pitch_coeff_ = A.colPivHouseholderQr().solve(y_pitch); + } + + double getYawOffset(double distance) const { + return yaw_base_offset + eval(yaw_coeff_, distance); + } + double getPitchOffset(double distance) const { + return pitch_base_offset + eval(pitch_coeff_, distance); + } + + double eval(const Eigen::VectorXd& coeff, double x) const { + double y = 0; + double v = 1; + + for (int i = 0; i < coeff.size(); ++i) { + y += coeff[i] * v; + v *= x; + } + + return y; + } + std::vector data_; + + Eigen::VectorXd yaw_coeff_; + Eigen::VectorXd pitch_coeff_; + double yaw_base_offset = 0; + double pitch_base_offset = 0; + int order_ = 2; +}; + +} // namespace wust_vision::auto_sniper \ No newline at end of file diff --git a/wust_vision-main/tasks/auto_sniper/voxel_map.hpp b/wust_vision-main/tasks/auto_sniper/voxel_map.hpp new file mode 100644 index 0000000..94885ea --- /dev/null +++ b/wust_vision-main/tasks/auto_sniper/voxel_map.hpp @@ -0,0 +1,348 @@ +#pragma once + +#include "3rdparty/ankerl/unordered_dense.h" +#include +#include +#include +#include +namespace wust_vision::auto_sniper { + +template +struct VoxelKey { + std::array data {}; + + int& operator[](int i) noexcept { + return data[i]; + } + const int& operator[](int i) const noexcept { + return data[i]; + } + + // ----- x ----- + int& x() noexcept { + static_assert(Dim >= 1, "x requires Dim >= 1"); + return data[0]; + } + const int& x() const noexcept { + static_assert(Dim >= 1, "x requires Dim >= 1"); + return data[0]; + } + + // ----- y ----- + int& y() noexcept { + static_assert(Dim >= 2, "y requires Dim >= 2"); + return data[1]; + } + const int& y() const noexcept { + static_assert(Dim >= 2, "y requires Dim >= 2"); + return data[1]; + } + + // ----- z ----- + int& z() noexcept { + static_assert(Dim >= 3, "z requires Dim >= 3"); + return data[2]; + } + const int& z() const noexcept { + static_assert(Dim >= 3, "z requires Dim >= 3"); + return data[2]; + } + bool operator==(const VoxelKey& other) const noexcept { + for (int i = 0; i < Dim; ++i) + if (data[i] != other.data[i]) + return false; + return true; + } + + bool operator!=(const VoxelKey& other) const noexcept { + return !(*this == other); + } + + VoxelKey& operator+=(const VoxelKey& other) noexcept { + for (int i = 0; i < Dim; ++i) + data[i] += other.data[i]; + return *this; + } + + VoxelKey& operator-=(const VoxelKey& other) noexcept { + for (int i = 0; i < Dim; ++i) + data[i] -= other.data[i]; + return *this; + } + + VoxelKey& operator*=(int scalar) noexcept { + for (int i = 0; i < Dim; ++i) + data[i] *= scalar; + return *this; + } + + VoxelKey& operator/=(int scalar) noexcept { + for (int i = 0; i < Dim; ++i) + data[i] /= scalar; + return *this; + } + + friend VoxelKey operator+(VoxelKey lhs, const VoxelKey& rhs) noexcept { + lhs += rhs; + return lhs; + } + + friend VoxelKey operator-(VoxelKey lhs, const VoxelKey& rhs) noexcept { + lhs -= rhs; + return lhs; + } + + friend VoxelKey operator*(VoxelKey lhs, int scalar) noexcept { + lhs *= scalar; + return lhs; + } + + friend VoxelKey operator*(int scalar, VoxelKey rhs) noexcept { + rhs *= scalar; + return rhs; + } + + friend VoxelKey operator/(VoxelKey lhs, int scalar) noexcept { + lhs /= scalar; + return lhs; + } + constexpr VoxelKey cwiseMin(const VoxelKey& other) const noexcept { + VoxelKey out {}; + for (int i = 0; i < Dim; ++i) + out.data[i] = data[i] < other.data[i] ? data[i] : other.data[i]; + return out; + } + + constexpr VoxelKey cwiseMax(const VoxelKey& other) const noexcept { + VoxelKey out {}; + for (int i = 0; i < Dim; ++i) + out.data[i] = data[i] > other.data[i] ? data[i] : other.data[i]; + return out; + } +}; + +template +class SlidingVoxelMap { + static_assert(Dim == 2 || Dim == 3, "Dim must be 2 or 3"); + +public: + using Ptr = std::shared_ptr; + using Key = VoxelKey; + using EigenPoint = Eigen::Matrix; + + SlidingVoxelMap(double voxel_size_, const EigenPoint& size_, const EigenPoint& center_): + voxel_size(voxel_size_), + size(size_), + center(center_) { + EigenPoint half = size * 0.5f; + + min_key = worldToKey(center - half); + max_key = worldToKey(center + half); + + for (int i = 0; i < Dim; ++i) { + dims[i] = max_key[i] - min_key[i] + 1; + offset[i] = 0; + } + + center_key = worldToKey(center); + + size_t N = 1; + for (int i = 0; i < Dim; ++i) + N *= static_cast(dims[i]); + + grid.resize(N); + } + SlidingVoxelMap( + double voxel_size_, + const EigenPoint& min_pos, + const EigenPoint& max_pos, + bool /*dummy*/ + ): + voxel_size(voxel_size_) { + min_key = worldToKey(min_pos); + max_key = worldToKey(max_pos); + + for (int i = 0; i < Dim; ++i) { + if (max_key[i] < min_key[i]) + std::swap(max_key[i], min_key[i]); + + dims[i] = max_key[i] - min_key[i] + 1; + offset[i] = 0; + } + + for (int i = 0; i < Dim; ++i) + center_key[i] = (min_key[i] + max_key[i]) / 2; + + EigenPoint min_world = keyToWorld(min_key); + EigenPoint max_world = keyToWorld(max_key); + + for (int i = 0; i < Dim; ++i) { + center[i] = (min_world[i] + max_world[i]) * 0.5f; + size[i] = dims[i] * voxel_size; + } + + size_t N = 1; + for (int i = 0; i < Dim; ++i) + N *= static_cast(dims[i]); + + grid.resize(N); + } + + static Ptr create(double voxel_size, const EigenPoint& size, const EigenPoint& center) { + return std::make_shared(voxel_size, size, center); + } + + size_t gridSize() const noexcept { + return grid.size(); + } + inline int worldToIndex(const EigenPoint& p) const noexcept { + return keyToIndex(worldToKey(p)); + } + + inline Key worldToKey(const EigenPoint& p) const noexcept { + Key k; + const double inv = 1.0f / voxel_size; + for (int i = 0; i < Dim; ++i) + k.data[i] = static_cast(std::floor(p[i] * inv + 1e-6f)); + return k; + } + + inline EigenPoint keyToWorld(const Key& k) const noexcept { + EigenPoint p; + for (int i = 0; i < Dim; ++i) + p[i] = (k.data[i] + 0.5f) * voxel_size; + return p; + } + + inline EigenPoint indexToWorld(int idx) const noexcept { + return keyToWorld(indexToKey(idx)); + } + + inline int keyToIndex(const Key& k) const noexcept { + int idx = 0; + int stride = 1; + + for (int d = Dim - 1; d >= 0; --d) { + int delta = k[d] - center_key[d] + (dims[d] >> 1); + + if (delta < 0 || delta >= dims[d]) + return -1; + + int r = delta + offset[d]; + + if (r >= dims[d]) + r -= dims[d]; + else if (r < 0) + r += dims[d]; + + idx += r * stride; + stride *= dims[d]; + } + + return idx; + } + + inline Key indexToKey(int idx) const noexcept { + Key k; + + for (int d = Dim - 1; d >= 0; --d) { + int r = idx % dims[d]; + idx /= dims[d]; + + int delta = r - offset[d]; + + if (delta < 0) + delta += dims[d]; + else if (delta >= dims[d]) + delta -= dims[d]; + + k[d] = center_key[d] + delta - (dims[d] >> 1); + } + + return k; + } + template + void slideTo(const Key& new_center_key, ClearFunc clear_func) { + Key shift; + for (int d = 0; d < Dim; ++d) + shift[d] = new_center_key[d] - center_key[d]; + + for (int d = 0; d < Dim; ++d) { + if (std::abs(shift[d]) >= dims[d]) { + for (size_t i = 0; i < grid.size(); ++i) + clear_func(i); + + offset = {}; + center_key = new_center_key; + return; + } + } + for (int axis = 0; axis < Dim; ++axis) { + int s = shift[axis]; + if (s == 0) + continue; + + int steps = std::abs(s); + int dir = s > 0 ? 1 : -1; + + for (int step = 0; step < steps; ++step) { + int slice = (offset[axis] + dir * step + dims[axis]) % dims[axis]; + + clearSlice(axis, slice, clear_func); + } + + offset[axis] = (offset[axis] + s + dims[axis]) % dims[axis]; + } + + center_key = new_center_key; + EigenPoint half = size * 0.5f; + center = keyToWorld(center_key); + min_key = worldToKey(center - half); + max_key = worldToKey(center + half); + } + + template + void clearSlice(int axis, int slice, ClearFunc clear_func) { + if constexpr (Dim == 3) { + int dx = dims[0]; + int dy = dims[1]; + int dz = dims[2]; + + if (axis == 0) { + for (int y = 0; y < dy; ++y) + for (int z = 0; z < dz; ++z) { + int idx = (slice * dy + y) * dz + z; + clear_func(idx); + } + } else if (axis == 1) { + for (int x = 0; x < dx; ++x) + for (int z = 0; z < dz; ++z) { + int idx = (x * dy + slice) * dz + z; + clear_func(idx); + } + } else { + for (int x = 0; x < dx; ++x) + for (int y = 0; y < dy; ++y) { + int idx = (x * dy + y) * dz + slice; + clear_func(idx); + } + } + } + } + +public: + double voxel_size; + + Key dims; + Key offset; + + std::vector grid; + + Key center_key; + EigenPoint center; + EigenPoint size; + Key min_key; + Key max_key; +}; + +} // namespace wust_vision::auto_sniper \ No newline at end of file diff --git a/wust_vision-main/tasks/imodule.hpp b/wust_vision-main/tasks/imodule.hpp new file mode 100644 index 0000000..629aeed --- /dev/null +++ b/wust_vision-main/tasks/imodule.hpp @@ -0,0 +1,16 @@ +#pragma once +#include "tasks/type_common.hpp" +#include "wust_vl/common/concurrency/monitored_thread.hpp" +#include +namespace wust_vision { +class IModule { +public: + using Ptr = std::shared_ptr; + virtual void start() = 0; + virtual void pushInput(CommonFrame&) = 0; + virtual GimbalCmd solve(double bullet_speed) = 0; + virtual wust_vl::common::concurrency::MonitoredThread::Ptr getThread() = 0; + virtual void doDebug() = 0; + virtual ~IModule() = default; +}; +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/packet_typedef.hpp b/wust_vision-main/tasks/packet_typedef.hpp new file mode 100644 index 0000000..c934e77 --- /dev/null +++ b/wust_vision-main/tasks/packet_typedef.hpp @@ -0,0 +1,193 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace wust_vision { + +constexpr uint8_t ID_ROBOT_CMD = 0x01; +constexpr uint8_t ID_NAV_CMD = 0x02; + +constexpr uint8_t ID_AIM_INFO = 0X02; +constexpr uint8_t ID_REFEREE_INFO = 0X03; + +constexpr const char* TARGET_TOPIC = "vision_target"; +constexpr const char* NAV_STATE_TOPIC = "rose_state"; +constexpr const char* MODE_TOPIC = "sentry_mode"; +constexpr const char* ROBO_STATE_TOPIC = "robo_state"; +constexpr const char* GOAL_TOPIC = "rose_goal"; + +struct ReceiveAimINFO { + uint8_t cmd_ID; + uint32_t time_stamp; + + float yaw; + float pitch; + float roll; + + float yaw_vel; + float pitch_vel; + float roll_vel; + + float v_x; + float v_y; + float v_z; + + float bullet_speed; + uint8_t detect_color; // 0 red 1 blue +} __attribute__((packed)); + +struct ReceiveReferee { + uint8_t cmd_ID; + uint32_t time_stamp; + + float big_yaw_in_world; + int game_time; + int max_health; + int cur_health; + int cur_bullet; + + uint8_t center_state; +} __attribute__((packed)); +struct SendRobotCmdData { + uint8_t cmd_ID; + uint32_t time_stamp; + + uint8_t appear; + uint8_t shoot_rate = 3; + + float pitch; + float yaw; + + float target_yaw; + float target_pitch; + + float enable_yaw_diff; + float enable_pitch_diff; + + float v_yaw; + float v_pitch; + float a_yaw; + float a_pitch; + + uint8_t detect_color; +} __attribute__((packed)); + +constexpr uint8_t ID_NAV_CONTROL = 0; + +struct NavRobotCmdData { + uint8_t cmd_ID; + uint32_t time_stamp; + uint8_t packet_type; + + float vx; + float vy; + float wz; + +} __attribute__((packed)); + +struct SerialLogBuffer { + std::mutex mtx; + nlohmann::json j; + bool dirty = false; +}; + +inline SerialLogBuffer& getLogBuffer() { + static SerialLogBuffer buf; + return buf; +} + +inline void updateFPS(nlohmann::json& j) { + static int frame_count = 0; + static double fps = 0.0; + static auto last_time = std::chrono::steady_clock::now(); + + ++frame_count; + + auto now = std::chrono::steady_clock::now(); + double elapsed = std::chrono::duration(now - last_time).count(); + + if (elapsed >= 1.0) { + fps = frame_count / elapsed; + frame_count = 0; + last_time = now; + } + + j["fps"] = fps; +} + +inline void updateSerialLog(const ReceiveAimINFO& aim) { + auto& buf = getLogBuffer(); + std::lock_guard lock(buf.mtx); + + auto& j = buf.j["aim"]; + updateFPS(j); + j["timestamp"] = aim.time_stamp; + j["yaw"] = aim.yaw; + j["pitch"] = aim.pitch; + j["roll"] = aim.roll; + + j["yaw_vel"] = aim.yaw_vel; + j["pitch_vel"] = aim.pitch_vel; + j["roll_vel"] = aim.roll_vel; + + j["v_x"] = aim.v_x; + j["v_y"] = aim.v_y; + j["v_z"] = aim.v_z; + + j["bullet_speed"] = aim.bullet_speed; + j["detect_color"] = (aim.detect_color == 0 ? "Red" : "Blue"); + + buf.dirty = true; +} + +inline void updateSerialLog(const ReceiveReferee& ref) { + auto& buf = getLogBuffer(); + std::lock_guard lock(buf.mtx); + + auto& j = buf.j["referee"]; + updateFPS(j); + j["timestamp"] = ref.time_stamp; + j["big_yaw_in_world"] = ref.big_yaw_in_world; + j["game_time"] = ref.game_time; + j["max_health"] = ref.max_health; + j["cur_health"] = ref.cur_health; + j["cur_bullet"] = ref.cur_bullet; + j["center_state"] = ref.center_state; + + buf.dirty = true; +} + +inline void flushSerialLog() { + static auto last_flush = std::chrono::steady_clock::now(); + + auto& buf = getLogBuffer(); + + auto now = std::chrono::steady_clock::now(); + double dt = std::chrono::duration(now - last_flush).count(); + + // 控制写入频率(例如 20Hz) + if (dt < 0.05) + return; + + std::lock_guard lock(buf.mtx); + + if (!buf.dirty) + return; + + // 更新 FPS + + std::ofstream file("/dev/shm/serial_log.json"); + if (file.is_open()) { + file << buf.j.dump(2); + buf.dirty = false; + } + + last_flush = now; +} + +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/type_common.cpp b/wust_vision-main/tasks/type_common.cpp new file mode 100644 index 0000000..2e38514 --- /dev/null +++ b/wust_vision-main/tasks/type_common.cpp @@ -0,0 +1,66 @@ +#include "type_common.hpp" + +namespace wust_vision { +std::string enemyColorToString(EnemyColor color) noexcept { + switch (color) { + case EnemyColor::RED: + return "RED"; + break; + case EnemyColor::BLUE: + return "BLUE"; + break; + case EnemyColor::WHITE: + return "WHITE"; + break; + default: + return "UNKNOWN"; + } +} + +void AimTarget::tf(Eigen::Matrix4d T_camera_to_odom) noexcept { + const Eigen::Vector4d pos_camera(pos.x(), pos.y(), pos.z(), 1.0); + const Eigen::Vector4d pos_odom = T_camera_to_odom * pos_camera; + + pos.x() = pos_odom.x(); + pos.y() = pos_odom.y(); + pos.z() = pos_odom.z(); + const Eigen::Matrix3d R_camera_to_odom = T_camera_to_odom.block<3, 3>(0, 0); + const Eigen::Quaterniond q_camera(ori.w(), ori.x(), ori.y(), ori.z()); + const Eigen::Matrix3d R_ori_camera = q_camera.normalized().toRotationMatrix(); + + const Eigen::Matrix3d R_ori_odom = R_camera_to_odom * R_ori_camera; + const Eigen::Quaterniond q_odom(R_ori_odom); + + ori.w() = q_odom.w(); + ori.x() = q_odom.x(); + ori.y() = q_odom.y(); + ori.z() = q_odom.z(); +} +std::vector +AimTarget::toPts(const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) noexcept { + std::vector pts; + if (pos.norm() < 0.5) { + return pts; + } + + const cv::Mat tvec = (cv::Mat_(3, 1) << pos.x(), pos.y(), pos.z()); + const Eigen::Matrix3d tf_rot = ori.toRotationMatrix(); + const cv::Mat rot_mat = + (cv::Mat_(3, 3) << tf_rot(0, 0), + tf_rot(0, 1), + tf_rot(0, 2), + tf_rot(1, 0), + tf_rot(1, 1), + tf_rot(1, 2), + tf_rot(2, 0), + tf_rot(2, 1), + tf_rot(2, 2)); + + cv::Mat rvec; + cv::Rodrigues(rot_mat, rvec); + + cv::projectPoints(AIM_TARGET_BLOCK, rvec, tvec, camera_intrinsic, camera_distortion, pts); + + return pts; +} +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/type_common.hpp b/wust_vision-main/tasks/type_common.hpp new file mode 100644 index 0000000..ef6a75a --- /dev/null +++ b/wust_vision-main/tasks/type_common.hpp @@ -0,0 +1,310 @@ +#pragma once +#include "3rdparty/angles.h" +#include "tasks/packet_typedef.hpp" +#include "wust_vl/common/utils/logger.hpp" +#include "wust_vl/common/utils/motion_buffer.hpp" +#include "wust_vl/common/utils/parameter.hpp" +#include "wust_vl/common/utils/trajectory_compensator.hpp" +#include "wust_vl/video/icamera.hpp" +#include +#include +#include +namespace wust_vision { +struct CommonFrame { + wust_vl::video::ImageFrame img_frame; + int id; + int detect_color; + cv::Rect expanded; + cv::Point2f offset = cv::Point2f(0, 0); + std::any any_ctx; +}; + +enum class EnemyColor { + RED = 0, + BLUE = 1, + WHITE = 2, +}; +std::string enemyColorToString(EnemyColor color) noexcept; +class InfantryMode { +public: + enum class AttackMode { ARMOR = 0, SMALL_RUNE, BIG_RUNE, UNKNOWN }; + static AttackMode toAttackMode(int value) noexcept { + switch (value) { + case 0: + return AttackMode::ARMOR; + case 1: + return AttackMode::SMALL_RUNE; + case 2: + return AttackMode::BIG_RUNE; + default: + return AttackMode::UNKNOWN; + } + } +}; +class HeroMode { +public: + enum class AttackMode { ARMOR = 0, SNIPER, UNKNOWN }; + static AttackMode toAttackMode(int value) noexcept { + switch (value) { + case 0: + return AttackMode::ARMOR; + case 1: + return AttackMode::SNIPER; + default: + return AttackMode::UNKNOWN; + } + } +}; + +struct CarMotion { + double yaw, pitch, roll; // 欧拉角 (rad) + double vyaw, vpitch, vroll; // 角速度 + double vx, vy, vz; // 线速度 + static double unwrap_angle(double prev, double curr) noexcept { + double d = curr - prev; + while (d > M_PI) { + curr -= 2.0 * M_PI; + d -= 2.0 * M_PI; + } + while (d < -M_PI) { + curr += 2.0 * M_PI; + d += 2.0 * M_PI; + } + return curr; + } + + // 角度插值(wrap-safe) + static double interp_angle(double a, double b, double t) noexcept { + double diff = b - a; + while (diff > M_PI) + diff -= 2.0 * M_PI; + while (diff < -M_PI) + diff += 2.0 * M_PI; + return a + diff * t; + } +}; +struct BigYaw { + double big_yaw; + static double unwrap_angle(double prev, double curr) noexcept { + double d = curr - prev; + while (d > M_PI) { + curr -= 2.0 * M_PI; + d -= 2.0 * M_PI; + } + while (d < -M_PI) { + curr += 2.0 * M_PI; + d += 2.0 * M_PI; + } + return curr; + } + + // 角度插值(wrap-safe) + static double interp_angle(double a, double b, double t) noexcept { + double diff = b - a; + while (diff > M_PI) + diff -= 2.0 * M_PI; + while (diff < -M_PI) + diff += 2.0 * M_PI; + return a + diff * t; + } +}; +struct VisionCtx { + std::shared_ptr> motion_buffer; + std::shared_ptr camera; + double communication_delay_μs; + int mode; +}; +static std::vector AIM_TARGET_BLOCK = { + { -0.025f, -0.025f, -0.025f }, // 0: 左下前 + { 0.025f, -0.025f, -0.025f }, // 1: 右下前 + { 0.025f, 0.025f, -0.025f }, // 2: 右上前 + { -0.025f, 0.025f, -0.025f }, // 3: 左上前 + { -0.025f, -0.025f, 0.025f }, // 4: 左下后 + { 0.025f, -0.025f, 0.025f }, // 5: 右下后 + { 0.025f, 0.025f, 0.025f }, // 6: 右上后 + { -0.025f, 0.025f, 0.025f } // 7: 左上后 +}; + +struct AimTarget { + bool valid; + Eigen::Vector3d pos = Eigen::Vector3d(0, 0, 0); + Eigen::Vector3d vel = Eigen::Vector3d(0, 0, 0); + + Eigen::Quaterniond ori; + std::vector armor_posandyaw; + + void tf(Eigen::Matrix4d T_camera_to_odom) noexcept; + std::vector + toPts(const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) noexcept; +}; +struct GimbalCmd { + std::chrono::steady_clock::time_point timestamp; + double pitch = 0; + double yaw = 0; + double target_yaw = 0; + double target_pitch = 0; + double v_yaw = 0; + double v_pitch = 0; + double a_yaw = 0; + double a_pitch = 0; + double distance = -1; + bool fire_advice = false; + double enable_yaw_diff = 0; + double enable_pitch_diff = 0; + double fly_time = 0; + bool appear = false; + AimTarget aim_target; + inline bool isValid() const noexcept { + auto bad = [](double x) { return std::isnan(x) || std::isinf(x); }; + + if (bad(pitch) || bad(yaw) || bad(target_yaw) || bad(target_pitch) || bad(target_pitch) + || bad(v_yaw) || bad(v_pitch) || bad(distance) || bad(enable_yaw_diff) + || bad(enable_pitch_diff)) + return false; + + return true; + } + inline void noShoot() { + fire_advice = false; + enable_pitch_diff = 0; + enable_pitch_diff = 0; + } +}; + +struct AutoExposureCfg: wust_vl::common::utils::ParamGroup { + static constexpr const char* Logger = "Config: auto_exposure"; + static constexpr const char* kKey = "auto_exposure"; + const char* key() const override { + return kKey; + } + + using Ptr = std::shared_ptr; + AutoExposureCfg() {} + static Ptr create() { + return std::make_shared(); + } + GEN_PARAM(bool, enable); + GEN_PARAM(double, target_brightness); + GEN_PARAM(double, step_gain); + GEN_PARAM(double, decay_step); + GEN_PARAM(double, tolerance); + GEN_PARAM(double, exposure_min); + GEN_PARAM(double, exposure_max); + GEN_PARAM(double, control_interval_ms); + void loadSelf(const YAML::Node& node) override { + enable_param.load(node); + target_brightness_param.load(node); + step_gain_param.load(node); + decay_step_param.load(node); + tolerance_param.load(node); + exposure_min_param.load(node); + exposure_max_param.load(node); + control_interval_ms_param.load(node); + } +}; +struct TFConfig: wust_vl::common::utils::ParamGroup { +public: + static constexpr const char* kKey = "tf"; + static constexpr const char* Logger = "Config: common::tf"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + TFConfig() {} + static Ptr create() { + return std::make_shared(); + } + + bool first_load = false; + + Eigen::Matrix3d R_camera2gimbal; + Eigen::Vector3d t_camera2gimbal; + void loadSelf(const YAML::Node& node) override { + if (!first_load) { + auto t_vec = node["t_camera2gimbal"].as>(); + if (t_vec.size() != 3) { + throw std::runtime_error("YAML tf.t_camera2gimbal must have 3 elements"); + } + t_camera2gimbal = Eigen::Vector3d(t_vec[0], t_vec[1], t_vec[2]); + + auto R_vec = node["R_camera2gimbal"].as>(); + if (R_vec.size() != 9) { + throw std::runtime_error("YAML tf.R_camera2gimbal must have 9 elements"); + } + R_camera2gimbal = + Eigen::Map>(R_vec.data()); + first_load = true; + } else { + } + } +}; +struct TrajectoryCompensatorConfig: public wust_vl::common::utils::ParamGroup { + static constexpr const char* Logger = "Config: auto_aim::trajectory_compensator"; + static constexpr const char* kKey = "trajectory_compensator"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + TrajectoryCompensatorConfig() {} + static Ptr create() { + return std::make_shared(); + } + std::shared_ptr trajectory_compensator; + bool first_load = false; + void loadSelf(const YAML::Node& node) override + + { + if (!first_load) { + std::string comp_type = node["compenstator_type"].as("ideal"); + trajectory_compensator = + wust_vl::common::utils::CompensatorFactory::createCompensator(comp_type); + trajectory_compensator->load(node); + first_load = true; + } else { + } + } +}; +} // namespace wust_vision +template<> +struct wust_vl::common::utils::MotionTraits { + static void unwrap(const wust_vision::CarMotion& prev, wust_vision::CarMotion& curr) noexcept { + curr.yaw = wust_vision::CarMotion::unwrap_angle(prev.yaw, curr.yaw); + curr.pitch = wust_vision::CarMotion::unwrap_angle(prev.pitch, curr.pitch); + curr.roll = wust_vision::CarMotion::unwrap_angle(prev.roll, curr.roll); + // 速度部分不需要 unwrap + } + + static wust_vision::CarMotion interpolate( + const wust_vision::CarMotion& a, + const wust_vision::CarMotion& b, + double t + ) noexcept { + wust_vision::CarMotion out; + // 欧拉角 wrap-safe 插值 + out.yaw = wust_vision::CarMotion::interp_angle(a.yaw, b.yaw, t); + out.pitch = wust_vision::CarMotion::interp_angle(a.pitch, b.pitch, t); + out.roll = wust_vision::CarMotion::interp_angle(a.roll, b.roll, t); + + // 角速度和线速度线性插值 + out.vyaw = a.vyaw + (b.vyaw - a.vyaw) * t; + out.vpitch = a.vpitch + (b.vpitch - a.vpitch) * t; + out.vroll = a.vroll + (b.vroll - a.vroll) * t; + out.vx = a.vx + (b.vx - a.vx) * t; + out.vy = a.vy + (b.vy - a.vy) * t; + out.vz = a.vz + (b.vz - a.vz) * t; + return out; + } +}; +template<> +struct wust_vl::common::utils::MotionTraits { + static void unwrap(const wust_vision::BigYaw& prev, wust_vision::BigYaw& curr) noexcept { + curr.big_yaw = wust_vision::BigYaw::unwrap_angle(prev.big_yaw, curr.big_yaw); + } + + static wust_vision::BigYaw + interpolate(const wust_vision::BigYaw& a, const wust_vision::BigYaw& b, double t) noexcept { + wust_vision::BigYaw out; + out.big_yaw = wust_vision::BigYaw::interp_angle(a.big_yaw, b.big_yaw, t); + return out; + } +}; \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/ascii_banner.hpp b/wust_vision-main/tasks/utils/ascii_banner.hpp new file mode 100644 index 0000000..2b1157e --- /dev/null +++ b/wust_vision-main/tasks/utils/ascii_banner.hpp @@ -0,0 +1,64 @@ +#pragma once +#include +#include +#include +namespace wust_vision { +constexpr std::array ascii_banner = { + R"( _ ____ _____________ _ ___________ ________ _ __ )", + R"(| | / / / / / ___/_ __/ | | / / _/ ___// _/ __ \/ | / /)", + R"(| | /| / / / / /\__ \ / / | | / // / \__ \ / // / / / |/ / )", + R"(| |/ |/ / /_/ /___/ // / | |/ // / ___/ // // /_/ / /| / )", + R"(|__/|__/\____//____//_/ |___/___//____/___/\____/_/ |_/ )", +}; +namespace { + struct RGB { + int r, g, b; + }; + + inline RGB hsv2rgb(float h, float s, float v) { + float c = v * s; + float x = c * (1 - std::fabs(std::fmod(h / 60.0f, 2) - 1)); + float m = v - c; + + float r = 0, g = 0, b = 0; + if (h < 60) { + r = c; + g = x; + } else if (h < 120) { + r = x; + g = c; + } else if (h < 180) { + g = c; + b = x; + } else if (h < 240) { + g = x; + b = c; + } else if (h < 300) { + r = x; + b = c; + } else { + r = c; + b = x; + } + + return { int((r + m) * 255), int((g + m) * 255), int((b + m) * 255) }; + } +} // namespace +inline void printBanner() { + constexpr const char* reset = "\033[0m"; + + for (const auto& line: ascii_banner) { + const int n = static_cast(std::string_view(line).size()); + + for (int i = 0; i < n; ++i) { + // hue 从 0° → 360° + float hue = 360.0f * i / std::max(1, n - 1); + auto rgb = hsv2rgb(hue, 1.0f, 1.0f); + + std::cout << "\033[38;2;" << rgb.r << ";" << rgb.g << ";" << rgb.b << "m" << line[i]; + } + std::cout << reset << '\n'; + } +} + +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/config.hpp b/wust_vision-main/tasks/utils/config.hpp new file mode 100644 index 0000000..ab8a8b4 --- /dev/null +++ b/wust_vision-main/tasks/utils/config.hpp @@ -0,0 +1,11 @@ +#pragma once +namespace wust_vision { +constexpr const char* ML_CONFIG = "config/detect_ml.yaml"; +constexpr const char* OPENCV_CONFIG = "config/detect_opencv.yaml"; +constexpr const char* COMMON_CONFIG = "config/common.yaml"; +constexpr const char* CAMERA_CONFIG = "config/camera.yaml"; +constexpr const char* AUTO_AIM_CONFIG = "config/auto_aim.yaml"; +constexpr const char* AUTO_BUFF_CONFIG = "config/auto_buff.yaml"; +constexpr const char* AUTO_GUIDANCE_CONFIG = "config/auto_guidance.yaml"; +constexpr const char* AUTO_SNIPER_CONFIG = "config/auto_sniper.yaml"; +} // namespace wust_vision diff --git a/wust_vision-main/tasks/utils/debug_utils.hpp b/wust_vision-main/tasks/utils/debug_utils.hpp new file mode 100644 index 0000000..9d33c97 --- /dev/null +++ b/wust_vision-main/tasks/utils/debug_utils.hpp @@ -0,0 +1,154 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +namespace wust_vision { + +template +class LogsStream { +public: + LogsStream(const std::string& n) { + name = n; + } + void handleOnce(const T& t, nlohmann::json& j) { + log_data.push_back(t); + trim(); + insertData(j); + } + void push_back(const T& t) { + log_data.push_back(t); + } + void trim() { + while (log_data.size() > MAX_N) { + log_data.erase(log_data.begin()); + } + } + void insertData(nlohmann::json& j) { + j[name] = log_data; + } + void clear() { + log_data.clear(); + } + +private: + std::string name; + std::vector log_data; +}; +template +void drawDebugOverlayImpl( + const DebugT& dbg, + std::pair camera_info, + bool auto_fps, + DrawFn&& draw_fn, + OutputFn&& output_fn +) { + static auto last_show_time = std::chrono::steady_clock::now(); + + if (dbg.img_frame.src_img.empty()) + return; + + constexpr double min_interval_ms = 1000.0 / 30.0; + const auto now = std::chrono::steady_clock::now(); + + if (auto_fps + && std::chrono::duration(now - last_show_time).count() + < min_interval_ms) + return; + + last_show_time = now; + + cv::Mat debug_img; + if (dbg.img_frame.pixel_format == wust_vl::video::PixelFormat::GRAY) { + cv::cvtColor(dbg.img_frame.src_img, debug_img, cv::COLOR_GRAY2RGB); + } else if (dbg.img_frame.pixel_format == wust_vl::video::PixelFormat::BGR) { + cv::cvtColor(dbg.img_frame.src_img, debug_img, cv::COLOR_BGR2RGB); + } else { + debug_img = dbg.img_frame.src_img; + } + + if (debug_img.empty()) + return; + draw_fn(debug_img, dbg, camera_info); + + output_fn(debug_img); +} +inline auto writeToFile = [](const cv::Mat& img) { + cv::Mat bgr; + cv::cvtColor(img, bgr, cv::COLOR_RGB2BGR); + + std::vector buf; + cv::imencode(".jpg", bgr, buf); + + std::ofstream ofs("/dev/shm/debug_frame.jpg.tmp", std::ios::binary); + ofs.write(reinterpret_cast(buf.data()), buf.size()); + ofs.close(); + + std::rename("/dev/shm/debug_frame.jpg.tmp", "/dev/shm/debug_frame.jpg"); +}; +class ShmWriter { +public: + static constexpr size_t shm_max_size = 2 * 1024 * 1024; + + explicit ShmWriter(const char* name, mode_t mode = 0666) { + fd_ = shm_open(name, O_CREAT | O_RDWR, mode); + if (fd_ == -1) { + std::cerr << "[SHM] shm_open failed\n"; + return; + } + + if (ftruncate(fd_, shm_max_size) == -1) { + std::cerr << "[SHM] ftruncate failed\n"; + close(fd_); + fd_ = -1; + return; + } + + ptr_ = mmap(nullptr, shm_max_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0); + + if (ptr_ == MAP_FAILED) { + std::cerr << "[SHM] mmap failed\n"; + close(fd_); + fd_ = -1; + ptr_ = nullptr; + } + } + + ~ShmWriter() { + if (ptr_) + munmap(ptr_, shm_max_size); + if (fd_ != -1) + close(fd_); + } + + void operator()(const cv::Mat& img) const { + if (!ptr_) + return; + + static const std::vector jpeg_params = { cv::IMWRITE_JPEG_QUALITY, 75 }; + + std::vector buf; + cv::imencode(".jpg", img, buf, jpeg_params); + + if (buf.size() + 4 > shm_max_size) + return; + + uint32_t size = static_cast(buf.size()); + std::memcpy(ptr_, &size, 4); + std::memcpy(static_cast(ptr_) + 4, buf.data(), size); + } + +private: + int fd_ { -1 }; + void* ptr_ { nullptr }; +}; +inline auto showWindow(const char* win_name) { + return [win_name](const cv::Mat& img) { + cv::imshow(win_name, img); + cv::waitKey(1); + }; +} +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/main_base.hpp b/wust_vision-main/tasks/utils/main_base.hpp new file mode 100644 index 0000000..b4e741e --- /dev/null +++ b/wust_vision-main/tasks/utils/main_base.hpp @@ -0,0 +1,99 @@ +#pragma once +#include "3rdparty/backward-cpp/backward.hpp" +#include "tasks/utils/ascii_banner.hpp" +#include "wust_vl/common/concurrency/monitored_thread.hpp" +#include "wust_vl/common/utils/signal.hpp" + +namespace wust_vision { +template +concept VisionLike = std::default_initializable && requires(T v, bool b) { + { + v.init(b) + } -> std::same_as; + { + v.start() + } -> std::same_as; + { + v.checkStateMatchMode() + } -> std::same_as; +}; + +template +inline int runVisionMain(int argc, char** argv) { + printBanner(); + bool debug = false; + if (argc > 1) { + const std::string firstArg = argv[1]; + debug = (firstArg == "true" || firstArg == "1"); + std::cout << "debug: " << firstArg << std::endl; + } + std::set_terminate([]() { + std::cerr << "Uncaught exception, terminating program.\n"; + if (auto e = std::current_exception()) { + try { + std::rethrow_exception(e); + } catch (const std::exception& ex) { + std::cerr << "Exception: " << ex.what() << std::endl; + } catch (...) { + std::cerr << "Unknown exception" << std::endl; + } + } + std::abort(); + }); + + try { + int exit_code = 0; + + { + T v; + v.init(debug); + std::cout << "Starting program..." << std::endl; + v.start(); + + wust_vl::common::utils::SignalHandler sig; + sig.start([&] {}); + + bool exit_flag = false; + + while (!sig.shouldExit() && !exit_flag) { + wust_vl::common::concurrency::ThreadManager::instance().printStatus(); + const auto all_status = + wust_vl::common::concurrency::ThreadManager::instance().getAllThreadStatuses(); + v.checkStateMatchMode(); + // for (auto& status: all_status) { + // if (status.second + // == wust_vl::common::concurrency::MonitoredThread::Status::Hung) { + // std::cerr << status.first << " is Hunging! Exiting program..." << std::endl; + // exit_code = -1; + // std::exit(exit_code); + + // break; + // } + // } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + } + + std::cout << "Exiting program..." << std::endl; + return exit_code; + + } catch (const std::exception& e) { + std::cerr << "Caught exception in main: " << e.what() << "\n"; + throw; + return -1; + } catch (...) { + std::cerr << "Unknown exception caught in main!\n"; + return -1; + } +} +} // namespace wust_vision + +#define VISION_MAIN(VISION_TYPE) \ + int main(int argc, char** argv) { \ + return wust_vision::runVisionMain(argc, argv); \ + } + +#define ENABLE_BACKWARD() \ + namespace backward { \ + static backward::SignalHandling sh; \ + } \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/sinple_img_rotate_saver.hpp b/wust_vision-main/tasks/utils/sinple_img_rotate_saver.hpp new file mode 100644 index 0000000..4ed505f --- /dev/null +++ b/wust_vision-main/tasks/utils/sinple_img_rotate_saver.hpp @@ -0,0 +1,160 @@ +#pragma once +#include "Eigen/Dense" +#include "opencv2/opencv.hpp" +#include "wust_vl/common/utils/recorder.hpp" +#include "wust_vl/common/utils/timer.hpp" +namespace wust_vision { +class ImgWriter: public wust_vl::common::utils::Writer { +public: + explicit ImgWriter( + const std::filesystem::path& video_path, + int fps = 30, + int fourcc = cv::VideoWriter::fourcc('a', 'v', 'c', '1') + ): + path_(video_path), + fps_(fps), + fourcc_(fourcc) {} + + ~ImgWriter() override { + release(); + } + + void write(std::ostream&, const cv::Mat& frame) override { + if (frame.empty()) + return; + + if (!initialized_) { + initVideoWriter(frame.size()); + } + + cv::Mat bgr; + if (frame.channels() == 3) + cv::cvtColor(frame, bgr, cv::COLOR_RGB2BGR); + else + cv::cvtColor(frame, bgr, cv::COLOR_GRAY2BGR); + + writer_.write(bgr); + } + +private: + void initVideoWriter(const cv::Size& frame_size) { + if (std::filesystem::exists(path_.parent_path()) == false) + std::filesystem::create_directories(path_.parent_path()); + + writer_.open(path_.string(), fourcc_, fps_, frame_size, true); + if (!writer_.isOpened()) { + throw std::runtime_error("Failed to open video writer for " + path_.string()); + } + initialized_ = true; + + std::cout << "[ImgWriter] Started writing video: " << path_ << std::endl; + } + + void release() { + if (initialized_) { + writer_.release(); + std::cout << "[ImgWriter] Video saved: " << path_ << std::endl; + initialized_ = false; + } + } + +private: + std::filesystem::path path_; + cv::VideoWriter writer_; + int fps_; + int fourcc_; + bool initialized_ = false; +}; + +class RotateWriterCSV: public wust_vl::common::utils::Writer { +public: + using RotateWriterCSVPtr = std::shared_ptr; + inline RotateWriterCSVPtr makeShared(bool write_header = true) { + return std::make_shared(write_header); + } + explicit RotateWriterCSV(bool write_header = true): first_write_(write_header) {} + + void write(std::ostream& os, const Eigen::Vector3d& data) override { + double now = wust_vl::common::utils::time_utils::sinceProgramStartSec(); + if (first_write_) { + os << "time,yaw,pitch,roll\n"; + first_write_ = false; + } + + os << std::fixed << std::setprecision(6) << now << "," << data[0] << "," << data[1] << "," + << data[2] << "\n"; + os.flush(); + } + +private: + bool first_write_; +}; +class RotateReaderCSV { +public: + using RotateReaderCSVPtr = std::shared_ptr; + struct Record { + double time; + Eigen::Vector3d ypr; + }; + + explicit RotateReaderCSV(const std::string& csv_path, double speed = 1.0): + csv_path_(csv_path), + speed_(speed) { + loadCSV(); + } + + /// 播放记录:每条数据 sleep 对齐真实时间间隔 + void replay(std::function callback) const { + if (records_.empty()) + return; + + for (size_t i = 0; i < records_.size(); ++i) { + const auto& rec = records_[i]; + callback(rec.ypr); + + if (i + 1 < records_.size()) { + double dt = (records_[i + 1].time - rec.time) / speed_; + if (dt > 0.0) + std::this_thread::sleep_for(std::chrono::duration(dt)); + } + } + } + +private: + void loadCSV() { + std::ifstream file(csv_path_); + if (!file.is_open()) { + throw std::runtime_error("Failed to open CSV: " + csv_path_); + } + + std::string line; + bool header_skipped = false; + while (std::getline(file, line)) { + if (!header_skipped) { // 跳过表头 + header_skipped = true; + continue; + } + std::istringstream ss(line); + std::string token; + Record rec; + + std::getline(ss, token, ','); + rec.time = std::stod(token); + + for (int i = 0; i < 3; ++i) { + std::getline(ss, token, ','); + rec.ypr[i] = std::stod(token); + } + + records_.push_back(rec); + } + std::cout << "[RotateReaderCSV] Loaded " << records_.size() << " records from " << csv_path_ + << std::endl; + } + +private: + std::string csv_path_; + std::vector records_; + double speed_ = 1.0; // 1.0=原速, 2.0=2倍速, 0.5=半速 +}; +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/utils.cpp b/wust_vision-main/tasks/utils/utils.cpp new file mode 100644 index 0000000..50ebddc --- /dev/null +++ b/wust_vision-main/tasks/utils/utils.cpp @@ -0,0 +1,536 @@ +#include "utils.hpp" +#include +#include +#include +#include +// util functions +namespace wust_vision { +namespace utils { + double limit_rad(double angle) noexcept { + while (angle > M_PI) + angle -= 2.0 * M_PI; + while (angle < -M_PI) + angle += 2.0 * M_PI; + return angle; + } + + Eigen::Vector3d + quatToEuler(const Eigen::Quaterniond& q, int axis0, int axis1, int axis2, bool extrinsic) { + if (!extrinsic) + std::swap(axis0, axis2); + + auto i = axis0, j = axis1, k = axis2; + bool is_proper = (i == k); + if (is_proper) + k = 3 - i - j; + int sign = (i - j) * (j - k) * (k - i) / 2; + + double a, b, c, d; + const Eigen::Vector4d xyzw = q.coeffs(); // [x,y,z,w] + if (is_proper) { + a = xyzw[3]; + b = xyzw[i]; + c = xyzw[j]; + d = xyzw[k] * sign; + } else { + a = xyzw[3] - xyzw[j]; + b = xyzw[i] + xyzw[k] * sign; + c = xyzw[j] + xyzw[3]; + d = xyzw[k] * sign - xyzw[i]; + } + + Eigen::Vector3d eulers; + const double n2 = a * a + b * b + c * c + d * d; + eulers[1] = std::acos(2 * (a * a + b * b) / n2 - 1); + + const double half_sum = std::atan2(b, a); + const double half_diff = std::atan2(-d, c); + + const double eps = 1e-7; + const bool safe1 = std::abs(eulers[1]) >= eps; + const bool safe2 = std::abs(eulers[1] - M_PI) >= eps; + const bool safe = safe1 && safe2; + + if (safe) { + eulers[0] = half_sum + half_diff; + eulers[2] = half_sum - half_diff; + } else { + if (!extrinsic) { + eulers[0] = 0; + eulers[2] = !safe1 ? 2 * half_sum : -2 * half_diff; + } else { + eulers[2] = 0; + eulers[0] = !safe1 ? 2 * half_sum : 2 * half_diff; + } + } + + for (int idx = 0; idx < 3; idx++) + eulers[idx] = limit_rad(eulers[idx]); + + if (!is_proper) { + eulers[2] *= sign; + eulers[1] -= M_PI / 2; + } + + if (!extrinsic) + std::swap(eulers[0], eulers[2]); + + return eulers; + } + + Eigen::Quaterniond + eulerToQuat(const Eigen::Vector3d& euler, int axis0, int axis1, int axis2, bool extrinsic) { + const double rz = euler[0], ry = euler[1], rx = euler[2]; + const Eigen::Quaterniond qx(Eigen::AngleAxisd(rx, Eigen::Vector3d::UnitX())); + const Eigen::Quaterniond qy(Eigen::AngleAxisd(ry, Eigen::Vector3d::UnitY())); + const Eigen::Quaterniond qz(Eigen::AngleAxisd(rz, Eigen::Vector3d::UnitZ())); + + if (!extrinsic) + std::swap(axis0, axis2); + Eigen::Quaterniond q; + + // 生成四元数 + if (axis0 == 0 && axis1 == 1 && axis2 == 2) + q = qx * qy * qz; + else if (axis0 == 0 && axis1 == 2 && axis2 == 1) + q = qx * qz * qy; + else if (axis0 == 1 && axis1 == 0 && axis2 == 2) + q = qy * qx * qz; + else if (axis0 == 1 && axis1 == 2 && axis2 == 0) + q = qy * qz * qx; + else if (axis0 == 2 && axis1 == 0 && axis2 == 1) + q = qz * qx * qy; + else if (axis0 == 2 && axis1 == 1 && axis2 == 0) + q = qz * qy * qx; + else + throw std::invalid_argument("Unsupported axis order"); + + return q; + } + + Eigen::Matrix3d + eulerToMatrix(const Eigen::Vector3d& euler, int axis0, int axis1, int axis2, bool extrinsic) { + return eulerToQuat(euler, axis0, axis1, axis2, extrinsic).toRotationMatrix(); + } + + Eigen::Vector3d + matrixToEuler(const Eigen::Matrix3d& R, int axis0, int axis1, int axis2, bool extrinsic) { + Eigen::Quaterniond q(R); + return quatToEuler(q, axis0, axis1, axis2, extrinsic); + } + + Eigen::Vector3d quatToEuler(const Eigen::Quaterniond& q, EulerOrder order, bool extrinsic) { + switch (order) { + case EulerOrder::XYZ: + return quatToEuler(q, 0, 1, 2, extrinsic); + case EulerOrder::XZY: + return quatToEuler(q, 0, 2, 1, extrinsic); + case EulerOrder::YXZ: + return quatToEuler(q, 1, 0, 2, extrinsic); + case EulerOrder::YZX: + return quatToEuler(q, 1, 2, 0, extrinsic); + case EulerOrder::ZXY: + return quatToEuler(q, 2, 0, 1, extrinsic); + case EulerOrder::ZYX: + return quatToEuler(q, 2, 1, 0, extrinsic); + default: + throw std::invalid_argument("Unsupported EulerOrder"); + } + } + + Eigen::Quaterniond eulerToQuat(const Eigen::Vector3d& euler, EulerOrder order, bool extrinsic) { + switch (order) { + case EulerOrder::XYZ: + return eulerToQuat(euler, 0, 1, 2, extrinsic); + case EulerOrder::XZY: + return eulerToQuat(euler, 0, 2, 1, extrinsic); + case EulerOrder::YXZ: + return eulerToQuat(euler, 1, 0, 2, extrinsic); + case EulerOrder::YZX: + return eulerToQuat(euler, 1, 2, 0, extrinsic); + case EulerOrder::ZXY: + return eulerToQuat(euler, 2, 0, 1, extrinsic); + case EulerOrder::ZYX: + return eulerToQuat(euler, 2, 1, 0, extrinsic); + default: + throw std::invalid_argument("Unsupported EulerOrder"); + } + } + + Eigen::Matrix3d eulerToMatrix(const Eigen::Vector3d& euler, EulerOrder order, bool extrinsic) { + return eulerToQuat(euler, order, extrinsic).toRotationMatrix(); + } + + Eigen::Vector3d matrixToEuler(const Eigen::Matrix3d& R, EulerOrder order, bool extrinsic) { + return quatToEuler(Eigen::Quaterniond(R), order, extrinsic); + } + + Eigen::MatrixXd cvToEigen(const cv::Mat& cv_mat) noexcept { + Eigen::MatrixXd eigen_mat = Eigen::MatrixXd::Zero(cv_mat.rows, cv_mat.cols); + cv::cv2eigen(cv_mat, eigen_mat); + return eigen_mat; + } + Eigen::Vector3d transformPosition( + const Eigen::Vector3d& pos_camera, + const Eigen::Matrix4d& T_camera_to_odom + ) noexcept { + const Eigen::Vector4d pos_homo(pos_camera.x(), pos_camera.y(), pos_camera.z(), 1.0); + const Eigen::Vector4d pos_odom = T_camera_to_odom * pos_homo; + return pos_odom.head<3>(); + } + + Eigen::Quaterniond transformOrientation( + const Eigen::Quaterniond& q_camera, + const Eigen::Matrix4d& T_camera_to_odom + ) noexcept { + const Eigen::Matrix3d R_camera_to_odom = T_camera_to_odom.block<3, 3>(0, 0); + const Eigen::Matrix3d R_ori_camera = q_camera.normalized().toRotationMatrix(); + const Eigen::Matrix3d R_ori_odom = R_camera_to_odom * R_ori_camera; + return Eigen::Quaterniond(R_ori_odom).normalized(); + } + void pnpToEigen( + const cv::Mat& rvec, + const cv::Mat& tvec, + Eigen::Vector3d& t_out, + Eigen::Quaterniond& q_out + ) noexcept { + // 平移 + cv::cv2eigen(tvec, t_out); + + // 旋转 + cv::Mat R_cv; + cv::Rodrigues(rvec, R_cv); + Eigen::Matrix3d R; + cv::cv2eigen(R_cv, R); + + q_out = Eigen::Quaterniond(R).normalized(); + } + void pnpToEigen( + const cv::Vec3d& rvec, + const cv::Vec3d& tvec, + Eigen::Vector3d& t_out, + Eigen::Quaterniond& q_out + ) noexcept { + // 平移 + t_out = Eigen::Vector3d(tvec[0], tvec[1], tvec[2]); + + // Rodrigues 旋转向量转矩阵 + cv::Mat rvec_mat(rvec); // cv::Vec3d -> 3x1 Mat + cv::Mat R_cv; + cv::Rodrigues(rvec_mat, R_cv); + + Eigen::Matrix3d R; + cv::cv2eigen(R_cv, R); + + q_out = Eigen::Quaterniond(R).normalized(); + } + + cv::Point2f computeCenter(const std::vector& points) noexcept { + if (points.empty()) { + return cv::Point2f(0.f, 0.f); + } + + float sum_x = 0.f; + float sum_y = 0.f; + for (const auto& pt: points) { + sum_x += pt.x; + sum_y += pt.y; + } + return cv::Point2f(sum_x / points.size(), sum_y / points.size()); + } + bool isStateValid(const Eigen::VectorXd& state) noexcept { + return state.allFinite(); // 所有元素都不是 NaN 或 Inf + } + Eigen::Matrix4d computeCameraToOdomTransform( + const Eigen::Matrix3d& R_gimbal2odom, + const Eigen::Matrix3d& R_camera_to_gimbal, + const Eigen::Vector3d& t_camera_to_gimbal + ) noexcept { + Eigen::Matrix4d T_gimbal_to_odom = Eigen::Matrix4d::Identity(); + T_gimbal_to_odom.block<3, 3>(0, 0) = R_gimbal2odom; + + Eigen::Matrix4d T_camera_to_gimbal = Eigen::Matrix4d::Identity(); + T_camera_to_gimbal.block<3, 3>(0, 0) = R_camera_to_gimbal; + T_camera_to_gimbal.block<3, 1>(0, 3) = t_camera_to_gimbal; + + const Eigen::Matrix4d T_camera_to_odom = T_gimbal_to_odom * T_camera_to_gimbal; + + return T_camera_to_odom; + } + + void addVelFromAccDt(Eigen::Vector3d& vel, const Eigen::Vector3d& acc, double dt) noexcept { + vel.x() += acc.x() * dt; + vel.y() += acc.y() * dt; + vel.z() += acc.z() * dt; + } + void addPosFromVelDt(Eigen::Vector3d& pos, const Eigen::Vector3d& vel, double dt) noexcept { + pos.x() += vel.x() * dt; + pos.y() += vel.y() * dt; + pos.z() += vel.z() * dt; + } + template + bool tryGetValue(const YAML::Node& node, const char* key, T& out_val) { + if (!node[key]) { + return false; + } + try { + out_val = node[key].template as(); + return true; + } catch (...) { + return false; + } + } + void changeFileOwner(const std::string& filepath, const std::string& username) { + struct passwd* pwd = getpwnam(username.c_str()); + if (pwd == nullptr) { + perror("getpwnam failed"); + return; + } + const uid_t uid = pwd->pw_uid; + const gid_t gid = pwd->pw_gid; + + if (chown(filepath.c_str(), uid, gid) != 0) { + perror("chown failed"); + } + } + std::string getOriginalUsername() { + const char* sudo_user = std::getenv("SUDO_USER"); + if (sudo_user) { + return std::string(sudo_user); + } + const uid_t uid = getuid(); + struct passwd* pw = getpwuid(uid); + if (pw) { + return std::string(pw->pw_name); + } + return ""; + } + bool setThreadAffinityAndPriority( + std::thread& thread, + int cpu_id, + int priority, + bool use_sched_fifo + ) { +#ifdef __linux__ + pthread_t native = thread.native_handle(); + + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpu_id, &cpuset); + if (pthread_setaffinity_np(native, sizeof(cpu_set_t), &cpuset) != 0) { + perror("pthread_setaffinity_np failed"); + return false; + } + + sched_param sch_params; + sch_params.sched_priority = priority; + int policy = use_sched_fifo ? SCHED_FIFO : SCHED_RR; + if (pthread_setschedparam(native, policy, &sch_params) != 0) { + perror("pthread_setschedparam failed"); + return false; + } + + return true; + +#elif defined(_WIN32) || defined(_WIN64) + HANDLE native = (HANDLE)thread.native_handle(); + + DWORD_PTR affinityMask = 1ULL << cpu_id; + if (SetThreadAffinityMask(native, affinityMask) == 0) { + return false; + } + + int win_priority = THREAD_PRIORITY_HIGHEST; // you can map `priority` if needed + if (!SetThreadPriority(native, win_priority)) { + return false; + } + + return true; + +#else + // Unsupported platform + (void)thread; + (void)cpu_id; + (void)priority; + (void)use_sched_fifo; + return false; +#endif + } + double rad2deg(double rad) noexcept { + return rad * 180.0 / M_PI; + } + double deg2rad(double deg) noexcept { + return deg * M_PI / 180.0; + } + + std::tuple xyz2ypd_rad(double x, double y, double z) noexcept { + const double distance = std::sqrt(x * x + y * y + z * z); + const double yaw = std::atan2(y, x); + const double pitch = std::atan2(z, std::sqrt(x * x + y * y)); + return std::make_tuple(yaw, pitch, distance); + } + + std::tuple + ypd2xyz_rad(double yaw, double pitch, double distance) noexcept { + const double x = distance * std::cos(pitch) * std::cos(yaw); + const double y = distance * std::cos(pitch) * std::sin(yaw); + const double z = distance * std::sin(pitch); + return std::make_tuple(x, y, z); + } + + std::tuple xyz2ypd_deg(double x, double y, double z) noexcept { + const double distance = std::sqrt(x * x + y * y + z * z); + const double yaw = std::atan2(y, x); + const double pitch = std::atan2(z, std::sqrt(x * x + y * y)); + return std::make_tuple(rad2deg(yaw), rad2deg(pitch), distance); + } + + std::tuple + ypd2xyz_deg(double yaw_deg, double pitch_deg, double distance) noexcept { + const double yaw = deg2rad(yaw_deg); + const double pitch = deg2rad(pitch_deg); + const double x = distance * std::cos(pitch) * std::cos(yaw); + const double y = distance * std::cos(pitch) * std::sin(yaw); + const double z = distance * std::sin(pitch); + return std::make_tuple(x, y, z); + } + Eigen::Vector3d xyz2ypd(const Eigen::Vector3d& xyz) noexcept { + const auto x = xyz[0], y = xyz[1], z = xyz[2]; + const auto yaw = std::atan2(y, x); + const auto pitch = std::atan2(z, std::sqrt(x * x + y * y)); + const auto distance = std::sqrt(x * x + y * y + z * z); + return { yaw, pitch, distance }; + } + template + Point getCenter(const std::vector& points) noexcept { + if (points.empty()) + return Point(); + return std::accumulate(points.begin(), points.end(), Point()) + / static_cast(points.size()); + } + bool segmentIntersection( + const cv::Point2f& a1, + const cv::Point2f& a2, + const cv::Point2f& b1, + const cv::Point2f& b2, + cv::Point2f& intersection + ) { + const cv::Point2f r = a2 - a1; + const cv::Point2f s = b2 - b1; + const float rxs = r.x * s.y - r.y * s.x; + const float qpxr = (b1 - a1).x * r.y - (b1 - a1).y * r.x; + + if (fabs(rxs) < 1e-6) + return false; // 平行或重叠,无唯一交点 + + const float t = ((b1 - a1).x * s.y - (b1 - a1).y * s.x) / rxs; + const float u = qpxr / rxs; + + if (t >= 0 && t <= 1 && u >= 0 && u <= 1) { + intersection = a1 + t * r; + return true; + } + return false; + } + + + std::vector intersectLineRotatedRect( + const cv::RotatedRect& rect, + const cv::Point2f& line_p1, + const cv::Point2f& line_p2 + ) { + std::vector intersections; + + cv::Point2f vertices[4]; + rect.points(vertices); + + for (int i = 0; i < 4; i++) { + const cv::Point2f p1 = vertices[i]; + const cv::Point2f p2 = vertices[(i + 1) % 4]; + cv::Point2f inter; + + if (segmentIntersection(line_p1, line_p2, p1, p2, inter)) { + intersections.push_back(inter); + } + } + + return intersections; + } + std::string makeTimestampedFileName() { + using namespace std::chrono; + const auto now = system_clock::now(); + const auto time_t_now = system_clock::to_time_t(now); + const auto ms = duration_cast(now.time_since_epoch()) % 1000; + + std::tm tm_now {}; +#if defined(_MSC_VER) + localtime_s(&tm_now, &time_t_now); +#else + localtime_r(&time_t_now, &tm_now); +#endif + + std::ostringstream oss; + oss << std::put_time(&tm_now, "%Y%m%d_%H%M%S") << "_" << std::setfill('0') << std::setw(3) + << ms.count(); + return oss.str(); + } + + std::string expandEnv(const std::string& s) { + std::regex env_re(R"(\$\{([^}]+)\})"); + std::smatch match; + std::string result = s; + while (std::regex_search(result, match, env_re)) { + const char* env = std::getenv(match[1].str().c_str()); + std::string val = env ? env : ""; + result.replace(match.position(0), match.length(0), val); + } + return result; + } + double computeBrightness(const cv::Mat& frame) { + cv::Mat gray; + cv::cvtColor(frame, gray, cv::COLOR_BGR2GRAY); + return cv::mean(gray)[0]; + } + cv::Mat letterbox( + const cv::Mat& img, + Eigen::Matrix3f& transform_matrix, + const int new_shape_w, + const int new_shape_h + ) noexcept { + const int img_h = img.rows; + const int img_w = img.cols; + + const float scale = std::min((float)new_shape_h / img_h, (float)new_shape_w / img_w); + const int resize_h = int(img_h * scale + 0.5f); + const int resize_w = int(img_w * scale + 0.5f); + + const int pad_h = new_shape_h - resize_h; + const int pad_w = new_shape_w - resize_w; + const int top = pad_h / 2; + const int left = pad_w / 2; + + cv::Mat resized; + cv::resize(img, resized, cv::Size(resize_w, resize_h), 0, 0, cv::INTER_LINEAR); + + cv::Mat out; + cv::copyMakeBorder( + resized, + out, + top, + pad_h - top, + left, + pad_w - left, + cv::BORDER_CONSTANT, + cv::Scalar(114, 114, 114) + ); + + const float inv_scale = 1.0f / scale; + + transform_matrix << inv_scale, 0, -left * inv_scale, 0, inv_scale, -top * inv_scale, 0, 0, + 1; + + return out; + } + +} // namespace utils +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/utils/utils.hpp b/wust_vision-main/tasks/utils/utils.hpp new file mode 100644 index 0000000..871ffcd --- /dev/null +++ b/wust_vision-main/tasks/utils/utils.hpp @@ -0,0 +1,213 @@ +// Created by Chengfu Zou on 2024.1.19 +// Copyright(C) FYT Vision Group. All rights resevred. +// Copyright 2025 XiaoJian Wu +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "3rdparty/angles.h" +#include +#include +#include +#include +#include +namespace wust_vision { +// util functions +namespace utils { + // Convert euler angles to rotation matrix + enum class EulerOrder { XYZ, XZY, YXZ, YZX, ZXY, ZYX }; + + double limit_rad(double angle) noexcept; + + Eigen::Vector3d quatToEuler( + const Eigen::Quaterniond& q, + int axis0, + int axis1, + int axis2, + bool extrinsic = true + ); + Eigen::Quaterniond eulerToQuat( + const Eigen::Vector3d& euler, + int axis0, + int axis1, + int axis2, + bool extrinsic = true + ); + + Eigen::Matrix3d eulerToMatrix( + const Eigen::Vector3d& euler, + int axis0, + int axis1, + int axis2, + bool extrinsic = true + ); + + Eigen::Vector3d + matrixToEuler(const Eigen::Matrix3d& R, int axis0, int axis1, int axis2, bool extrinsic = true); + + Eigen::Vector3d + quatToEuler(const Eigen::Quaterniond& q, EulerOrder order, bool extrinsic = true); + + Eigen::Quaterniond + eulerToQuat(const Eigen::Vector3d& euler, EulerOrder order, bool extrinsic = true); + Eigen::Matrix3d + eulerToMatrix(const Eigen::Vector3d& euler, EulerOrder order, bool extrinsic = true); + Eigen::Vector3d + matrixToEuler(const Eigen::Matrix3d& R, EulerOrder order, bool extrinsic = true); + + Eigen::MatrixXd cvToEigen(const cv::Mat& cv_mat) noexcept; + Eigen::Vector3d transformPosition( + const Eigen::Vector3d& pos_camera, + const Eigen::Matrix4d& T_camera_to_odom + ) noexcept; + + Eigen::Quaterniond transformOrientation( + const Eigen::Quaterniond& q_camera, + const Eigen::Matrix4d& T_camera_to_odom + ) noexcept; + void pnpToEigen( + const cv::Mat& rvec, + const cv::Mat& tvec, + Eigen::Vector3d& t_out, + Eigen::Quaterniond& q_out + ) noexcept; + void pnpToEigen( + const cv::Vec3d& rvec, + const cv::Vec3d& tvec, + Eigen::Vector3d& t_out, + Eigen::Quaterniond& q_out + ) noexcept; + + cv::Point2f computeCenter(const std::vector& points) noexcept; + bool isStateValid(const Eigen::VectorXd& state) noexcept; + Eigen::Matrix4d computeCameraToOdomTransform( + const Eigen::Matrix3d& R_gimbal2odom, + const Eigen::Matrix3d& R_camera_to_gimbal, + const Eigen::Vector3d& t_camera_to_gimbal + ) noexcept; + + void addVelFromAccDt(Eigen::Vector3d& vel, const Eigen::Vector3d& acc, double dt) noexcept; + void addPosFromVelDt(Eigen::Vector3d& pos, const Eigen::Vector3d& vel, double dt) noexcept; + template + bool tryGetValue(const YAML::Node& node, const char* key, T& out_val); + void changeFileOwner(const std::string& filepath, const std::string& username); + std::string getOriginalUsername(); + bool setThreadAffinityAndPriority( + std::thread& thread, + int cpu_id, + int priority, + bool use_sched_fifo + ); + double rad2deg(double rad) noexcept; + double deg2rad(double deg) noexcept; + std::tuple xyz2ypd_rad(double x, double y, double z) noexcept; + + std::tuple + ypd2xyz_rad(double yaw, double pitch, double distance) noexcept; + + std::tuple xyz2ypd_deg(double x, double y, double z) noexcept; + std::tuple + ypd2xyz_deg(double yaw_deg, double pitch_deg, double distance) noexcept; + Eigen::Vector3d xyz2ypd(const Eigen::Vector3d& xyz) noexcept; + template + Point getCenter(const std::vector& points) noexcept; + bool segmentIntersection( + const cv::Point2f& a1, + const cv::Point2f& a2, + const cv::Point2f& b1, + const cv::Point2f& b2, + cv::Point2f& intersection + ); + std::vector intersectLineRotatedRect( + const cv::RotatedRect& rect, + const cv::Point2f& line_p1, + const cv::Point2f& line_p2 + ); + std::string makeTimestampedFileName(); + std::string expandEnv(const std::string& s); + double computeBrightness(const cv::Mat& frame); + cv::Mat letterbox( + const cv::Mat& img, + Eigen::Matrix3f& transform_matrix, + const int new_shape_w, + const int new_shape_h + ) noexcept; + template + void XSecOnce(Func&& func, double dt) noexcept { + static auto last_call = std::chrono::steady_clock::now(); + + const auto now = std::chrono::steady_clock::now(); + const double elapsed = std::chrono::duration(now - last_call).count(); + + if (elapsed >= dt) { + last_call = now; + func(); + } + } + template + concept Point2DLike = requires(T p) { + { + p.x + } -> std::convertible_to; + { + p.y + } -> std::convertible_to; + T { 0.f, 0.f }; + }; + template + [[nodiscard]] inline T transformPoint2D(const Eigen::Matrix3f& H, const T& p) noexcept { + const Eigen::Vector3f hp { p.x, p.y, 1.f }; + const Eigen::Vector3f tp = H * hp; + return { tp.x(), tp.y() }; + } + inline cv::Rect2f transformRect(const Eigen::Matrix3f& H, const cv::Rect2f& rect) { + cv::Point2f p1(rect.x, rect.y); + cv::Point2f p2(rect.x + rect.width, rect.y); + cv::Point2f p3(rect.x, rect.y + rect.height); + cv::Point2f p4(rect.x + rect.width, rect.y + rect.height); + + auto tp1 = utils::transformPoint2D(H, p1); + auto tp2 = utils::transformPoint2D(H, p2); + auto tp3 = utils::transformPoint2D(H, p3); + auto tp4 = utils::transformPoint2D(H, p4); + + float min_x = std::min({ tp1.x, tp2.x, tp3.x, tp4.x }); + float min_y = std::min({ tp1.y, tp2.y, tp3.y, tp4.y }); + float max_x = std::max({ tp1.x, tp2.x, tp3.x, tp4.x }); + float max_y = std::max({ tp1.y, tp2.y, tp3.y, tp4.y }); + + return cv::Rect2f(min_x, min_y, max_x - min_x, max_y - min_y); + } + template + [[nodiscard]] double orientationToYaw(const Eigen::Quaterniond& q) noexcept { + static double last_yaw = 0; + double roll, pitch, yaw; + const Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false); + yaw = euler[0]; + yaw = last_yaw + angles::shortest_angular_distance(last_yaw, yaw); + last_yaw = yaw; + return yaw; + } + template + [[nodiscard]] double orientationToRoll(const Eigen::Quaterniond& q) noexcept { + static double last_roll = 0; + double roll, pitch, yaw; + Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false); + roll = euler[2]; + roll = last_roll + angles::shortest_angular_distance(last_roll, roll); + last_roll = roll; + return roll; + } +} // namespace utils +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/tasks/vision_base.hpp b/wust_vision-main/tasks/vision_base.hpp new file mode 100644 index 0000000..e09e4ac --- /dev/null +++ b/wust_vision-main/tasks/vision_base.hpp @@ -0,0 +1,622 @@ +#pragma once +#include "tasks/auto_aim/auto_aim.hpp" +#include "tasks/auto_buff/auto_buff.hpp" +#include "tasks/imodule.hpp" +#include "tasks/type_common.hpp" +#include "tasks/utils/sinple_img_rotate_saver.hpp" +#include "tasks/utils/utils.hpp" +#include "wust_vl/common/concurrency/ThreadPool.h" +#include "wust_vl/common/drivers/serial_driver.hpp" +#include +#include +#include +namespace wust_vision { +struct LoggerConfig: wust_vl::common::utils::ParamGroup { +public: + static constexpr const char* kKey = "logger"; + static constexpr const char* Logger = "Config: common::logger"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + LoggerConfig() {} + static Ptr create() { + return std::make_shared(); + } + + bool first_load = false; + + void loadSelf(const YAML::Node& node) override { + if (!first_load) { + std::string log_level = node["log_level"].as(); + std::string log_path = node["log_path"].as(); + bool use_logcli = node["use_logcli"].as(); + bool use_logfile = node["use_logfile"].as(); + bool use_simplelog = node["use_simplelog"].as(); + wust_vl::initLogger(log_level, log_path, use_logcli, use_logfile, use_simplelog); + first_load = true; + } else { + } + } +}; + +struct MaxInferRunningConfig: + public wust_vl::common::utils::SimpleConfigBase { + static constexpr const char* kKey = "max_infer_running"; + static constexpr const char* Logger = "Config: common::max_infer_running"; + + int max_infer_running = 0; + + void loadSelf(const YAML::Node& node) override { + loadOnceOrUpdate( + node, + max_infer_running, + [](const YAML::Node& n, int& v) { v = n.as(); }, + [](const YAML::Node& n, int& v) { + int nv = n.as(); + if (nv != v) { + v = nv; + WUST_DEBUG(Logger) << "max_infer_running change to " << nv; + } + } + ); + } +}; +template +class VisionBase { +public: + VisionBase( + std::string common_config, + std::string camera_config, + std::string auto_aim_config, + std::string auto_buff_config + ): + common_config_(common_config), + camera_config_(camera_config), + auto_aim_config_(auto_aim_config), + auto_buff_config_(auto_buff_config) {} + ~VisionBase() { + run_flag_ = false; + if (debug_thread_.joinable()) { + debug_thread_.join(); + } + + WUST_MAIN("main") << "vision stop already!"; + } + bool init(bool debug_mode) { + try { + const char* v = std::getenv("VISION_ROOT"); + if (v) + std::cout << "[env] VISION_ROOT = " << v << "\n"; + else + std::cout << "[env] VISION_ROOT not set in this process\n"; + debug_mode_ = debug_mode; + common_config_parameter_.loadFromFile(common_config_); + control_config_ = ControlConfig::create(this); + shoot_config_ = ShootConfig::create(this); + record_config_ = RecordConfig::create(this); + logger_config_ = LoggerConfig::create(); + tf_config_ = TFConfig::create(); + max_infer_running_config_ = MaxInferRunningConfig::create(); + common_config_parameter_.registerGroup(*control_config_); + common_config_parameter_.registerGroup(*shoot_config_); + common_config_parameter_.registerGroup(*record_config_); + common_config_parameter_.registerGroup(*logger_config_); + common_config_parameter_.registerGroup(*tf_config_); + common_config_parameter_.registerGroup(*max_infer_running_config_); + common_config_parameter_.reloadFromOldPath(); + auto config = common_config_parameter_.getConfig(); + debug_fps_ = config["debug_fps"].as(); + attack_mode_ = config["attack_mode"].as(); + detect_color_ = config["detect_color"].as(); + + wust_vl::common::utils::ParameterManager::instance().registerParameter( + common_config_parameter_ + ); + YAML::Node camera_config = YAML::LoadFile(camera_config_); + camera_ = std::make_shared(); + camera_->init(camera_config); + camera_->setFrameCallback( + std::bind(&VisionBase::frameCallback, this, std::placeholders::_1) + ); + std::string camera_info_path = + utils::expandEnv(camera_config["camera_info_path"].as()); + YAML::Node config_camera_info = YAML::LoadFile(camera_info_path); + std::vector camera_k = + config_camera_info["camera_matrix"]["data"].as>(); + std::vector camera_d = + config_camera_info["distortion_coefficients"]["data"].as>(); + + assert(camera_k.size() == 9); + assert(camera_d.size() == 5); + + cv::Mat K(3, 3, CV_64F); + std::memcpy(K.data, camera_k.data(), 9 * sizeof(double)); + + cv::Mat D(1, 5, CV_64F); + std::memcpy(D.data, camera_d.data(), 5 * sizeof(double)); + + auto camera_info = std::make_pair(K.clone(), D.clone()); + camera_info_ = camera_info; + + thread_pool_ = std::make_unique( + max_infer_running_config_->max_infer_running + ); + motion_buffer_ = + std::make_shared>(); + + timer_ = std::make_unique("solve"); + + WUST_MAIN("main") << "vision init already!"; + } catch (std::exception& e) { + std::cerr << "init exception: " << e.what() << std::endl; + } + return true; + } + void start() { + run_flag_ = true; + camera_->start(); + for (auto& module: modules_) { + if (module.second) { + module.second->start(); + } + } + if (timer_) { + const auto timercallback = + std::bind(&VisionBase::timerCallback, this, std::placeholders::_1); + const double rate_hz = control_config_->control_rate_param.get(); + timer_->start(rate_hz, timercallback); + } + if (serial_) { + serial_->start(); + } else if (rotate_reader_) { + rotate_reader_->replay([this](const Eigen::Vector3d& ypr) { + ReceiveAimINFO aim_data; + aim_data.yaw = ypr(0); + aim_data.pitch = ypr(1); + aim_data.roll = ypr(2); + this->processAimData(aim_data); + }); + } + if (debug_mode_) { + debug_thread_ = std::thread([this]() { this->debugThread(); }); + } + if (rotate_writer_) { + rotate_writer_->start(); + } + if (img_writer_) { + img_writer_->start(); + } + } + void serialCallback(const uint8_t* data, std::size_t len) { + if (len != sizeof(ReceiveAimINFO)) { + return; + } + try { + const std::vector buf(data, data + len); + const ReceiveAimINFO aim_data = + wust_vl::common::drivers::fromVector(buf); + processAimData(aim_data); + + } catch (const std::exception& e) { + std::cerr << "serialCallback exception: " << e.what() << std::endl; + } catch (...) { + std::cerr << "serialCallback unknown exception" << std::endl; + } + } + void frameCallback(wust_vl::video::ImageFrame& img_frame) { + if (!run_flag_ || infer_running_count_ >= max_infer_running_config_->max_infer_running) { + return; + } + if (img_frame.src_img.empty()) { + return; + } + + thread_pool_->enqueue([this, img_frame = std::move(img_frame)]() mutable { + infer_running_count_++; + CommonFrame frame; + frame.detect_color = detect_color_; + frame.img_frame = std::move(img_frame); + frame.expanded = + cv::Rect(0, 0, frame.img_frame.src_img.cols, frame.img_frame.src_img.rows); + frame.offset = cv::Point2f(0, 0); + frame.any_ctx = VisionCtx { .motion_buffer = motion_buffer_, + .camera = camera_, + .communication_delay_μs = + control_config_->communication_delay_us_param.get(), + .mode = attack_mode_ }; + if (frame.img_frame.src_img.empty()) { + infer_running_count_--; + return; + } + + if (img_writer_) { + img_writer_->push(frame.img_frame.src_img); + } + typename Mode::AttackMode mode = Mode::toAttackMode(attack_mode_); + auto module = modules_.at(mode); + if (module) { + module->pushInput(frame); + } + + infer_running_count_--; + }); + } + + void checkStateMatchMode() const { + const auto mode = Mode::toAttackMode(attack_mode_); + + auto this_module = modules_.at(mode); + if (!this_module) { + return; + } + auto this_thread = this_module->getThread(); + + auto pause_others = [&]() { + auto self = this_module; + + for (auto& [_, module]: modules_) { + if (!module) { + continue; + } + if (module != self) { + if (auto t = module->getThread()) { + t->pause(); + } + } + } + }; + + if (!this_thread) { + pause_others(); + return; + } + + if (!this_thread->isAlive()) { + this_thread->resume(); + } + + pause_others(); + } + void timerCallback(double dt_ms) { + if (!run_flag_) { + return; + } + + GimbalCmd cmd; + try { + typename Mode::AttackMode mode = Mode::toAttackMode(attack_mode_); + auto module = modules_.at(mode); + if (!module) { + return; + } + cmd = module->solve(bullet_speed_); + } catch (const std::exception& e) { + std::cout << "solve error: " << e.what() << std::endl; + } + if (!cmd.isValid()) { + return; + } + last_cmd_ = cmd; + + double cmd_pitch = cmd.pitch; + double cmd_yaw = cmd.yaw; + SendRobotCmdData send_data; + send_data.cmd_ID = ID_ROBOT_CMD; + send_data.time_stamp = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch() + ) + .count(); + if (cmd.distance > 0.5) { + send_data.appear = cmd.appear; + } else { + send_data.appear = false; + } + + send_data.detect_color = detect_color_; + send_data.pitch = cmd_pitch + cmd.v_pitch * control_config_->pitch_ramp_param.get(); + send_data.yaw = cmd_yaw + cmd.v_yaw * control_config_->yaw_ramp_param.get(); + send_data.v_pitch = cmd.v_pitch; + send_data.v_yaw = cmd.v_yaw; + send_data.a_pitch = cmd.a_pitch; + send_data.a_yaw = cmd.a_yaw; + send_data.target_yaw = cmd.target_yaw; + send_data.target_pitch = cmd.target_pitch; + send_data.enable_pitch_diff = cmd.enable_pitch_diff; + send_data.enable_yaw_diff = cmd.enable_yaw_diff; + send_data.shoot_rate = shoot_config_->rate_param.get(); + if (serial_) { + serial_->write(std::move(wust_vl::common::drivers::toVector(send_data))); + } + } + bool isWebRunning() { + static std::atomic cached { true }; + static std::atomic last_check_ms { 0 }; + + const long long now_ms = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch() + ) + .count(); + if (last_check_ms.load() == 0) { + last_check_ms = now_ms; + return true; + } + + if (now_ms - last_check_ms.load() >= 1000) { + const int ret = std::system("pgrep -x wust_vision_web > /dev/null 2>&1"); + cached = (ret == 0); + last_check_ms = now_ms; + } + + return cached.load(); + } + void debugThread() { + const double us_interval = 1e6 / static_cast(debug_fps_); + const auto kInterval = std::chrono::microseconds(static_cast(us_interval)); + while (run_flag_) { + const auto start_time = std::chrono::steady_clock::now(); + do { + try { + if (!isWebRunning()) { + break; + } + typename Mode::AttackMode mode = Mode::toAttackMode(attack_mode_); + auto module = modules_.at(mode); + if (module) { + module->doDebug(); + } + + // utils::XSecOnce( + // [this]() { + // wust_vl::common::utils::ParameterManager::instance() + // .allReloadFromOldPath(); + // }, + // 1.0 + // ); + + } catch (std::exception& e) { + std::cout << "debug thread error: " << e.what() << std::endl; + } + } while (0); + + const auto elapsed = std::chrono::steady_clock::now() - start_time; + if (elapsed < kInterval) { + std::this_thread::sleep_for(kInterval - elapsed); + } + } + } + + void processAimData(const ReceiveAimINFO& aim_data) { + static wust_vl::common::concurrency::Averager vyaw_avg(100); + if (!this->run_flag_) { + return; + } + detect_color_ = aim_data.detect_color; + bullet_speed_ = aim_data.bullet_speed; + const double roll = -(aim_data.roll) * M_PI / 180.0; + const double pitch = (aim_data.pitch) * M_PI / 180.0; + const double yaw = (aim_data.yaw) * M_PI / 180.0; + const double v_roll = aim_data.roll_vel * M_PI / 180.0; + const double v_pitch = aim_data.pitch_vel * M_PI / 180.0; + const double v_yaw = aim_data.yaw_vel * M_PI / 180.0; + vyaw_avg.add(v_yaw); + + const double v_x = 0.0; + const double v_y = 0.0; + const double v_z = 0.0; + + const auto now = std::chrono::steady_clock::now(); + if (motion_buffer_) { + CarMotion motion { yaw, pitch, roll, 0.0, v_pitch, v_roll, v_x, v_y, v_z }; + motion_buffer_->push(motion, now); + } + if (debug_mode_) { + updateSerialLog(aim_data); + flushSerialLog(); + } + + static auto last_push_time = std::chrono::steady_clock::now(); + const auto elapsed = + std::chrono::duration_cast(now - last_push_time).count(); + if (elapsed >= 10) { // 至少间隔 10ms(100Hz) + if (rotate_writer_) { + rotate_writer_->push(Eigen::Vector3d(aim_data.yaw, aim_data.pitch, aim_data.roll)); + } + last_push_time = now; + } + } + struct ControlConfig: wust_vl::common::utils::ParamGroup { + public: + static constexpr const char* kKey = "control"; + static constexpr const char* Logger = "Config: common::control"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + ControlConfig(VisionBase* b) { + base = b; + communication_delay_us_param.onChange([this](int o, int n) { + if (isBaseACtive()) { + WUST_DEBUG(Logger) + << "communication_delay_μs from: " << o << " to: " << n << " us"; + } + }); + } + static Ptr create(VisionBase* b) { + return std::make_shared(b); + } + GEN_PARAM(double, communication_delay_us); + GEN_PARAM(double, yaw_ramp); + GEN_PARAM(double, pitch_ramp); + GEN_PARAM(double, control_rate); + VisionBase* base; + bool first_load = false; + bool isBaseACtive() { + return base != nullptr; + } + void loadSelf(const YAML::Node& node) override { + if (!isBaseACtive()) + return; + if (!first_load) { + communication_delay_us_param.set(node["communication_delay_us"].as()); + yaw_ramp_param.set(node["yaw_ramp"].as()); + pitch_ramp_param.set(node["pitch_ramp"].as()); + control_rate_param.set(node["control_rate"].as()); + std::string device_name = node["device_name"].as(); + base->serial_ = std::make_shared(); + bool use_serial = node["use_serial"].as(); + if (use_serial) { + wust_vl::common::drivers::SerialDriver::SerialPortConfig cfg { + /*baud*/ 115200, + /*csize*/ 8, + boost::asio::serial_port_base::parity::none, + boost::asio::serial_port_base::stop_bits::one, + boost::asio::serial_port_base::flow_control::none + }; + base->serial_->init_port(device_name, cfg); + base->serial_->set_receive_callback(std::bind( + &VisionBase::serialCallback, + base, + std::placeholders::_1, + std::placeholders::_2 + + )); + base->serial_->set_error_callback([&](const boost::system::error_code& ec) { + WUST_ERROR("serial") << "serial error: " << ec.message(); + }); + } + first_load = true; + } else { + communication_delay_us_param.load(node); + yaw_ramp_param.load(node); + pitch_ramp_param.load(node); + control_rate_param.load(node); + } + } + }; + ControlConfig::Ptr control_config_; + struct ShootConfig: wust_vl::common::utils::ParamGroup { + public: + static constexpr const char* kKey = "shoot"; + static constexpr const char* Logger = "Config: common::shoot"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + ShootConfig(VisionBase* b) { + base = b; + rate_param.onChange([this](int o, int n) { + WUST_DEBUG(Logger) << "shoot_rate from: " << o << " to: " << n << " HZ"; + }); + } + + static Ptr create(VisionBase* b) { + return std::make_shared(b); + } + GEN_PARAM(int, rate); + GEN_PARAM(double, bullet_speed); + VisionBase* base; + bool first_load = false; + bool isBaseACtive() { + return base != nullptr; + } + void loadSelf(const YAML::Node& node) override { + if (!isBaseACtive()) + return; + if (!first_load) { + rate_param.set(node["rate"].as()); + bullet_speed_param.set(node["bullet_speed"].as()); + base->bullet_speed_ = bullet_speed_param.get(); + first_load = true; + } else { + rate_param.load(node); + } + } + }; + ShootConfig::Ptr shoot_config_; + struct RecordConfig: wust_vl::common::utils::ParamGroup { + public: + static constexpr const char* kKey = "record"; + static constexpr const char* Logger = "Config: common::record"; + const char* key() const override { + return kKey; + } + using Ptr = std::shared_ptr; + RecordConfig(VisionBase* b) { + base = b; + } + static Ptr create(VisionBase* b) { + return std::make_shared(b); + } + + VisionBase* base; + bool first_load = false; + bool isBaseACtive() { + return base != nullptr; + } + void loadSelf(const YAML::Node& node) override { + if (!isBaseACtive()) + return; + if (!first_load) { + bool use_record = node["use_record"].as(false); + if (use_record) { + std::string folder_path = node["folder_path"].as(); + auto file_name = utils::makeTimestampedFileName(); + std::string text_path = fmt::format("{}/{}.csv", folder_path, file_name); + std::string video_path = fmt::format("{}/{}.avi", folder_path, file_name); + + std::filesystem::create_directory(folder_path); + auto rw = std::make_shared(true); + base->rotate_writer_ = + std::make_shared>( + text_path, + rw + ); + auto imgw = std::make_shared( + video_path, + 30, + cv::VideoWriter::fourcc('M', 'J', 'P', 'G') + ); + base->img_writer_ = + std::make_shared>("", imgw); + } + bool use_rotate_reader = node["use_rotate_reader"].as(); + if (use_rotate_reader) { + std::string csv_path = node["read_csv_path"].as(); + base->rotate_reader_ = std::make_shared(csv_path); + } + first_load = true; + } else { + } + } + }; + RecordConfig::Ptr record_config_; + LoggerConfig::Ptr logger_config_; + TFConfig::Ptr tf_config_; + MaxInferRunningConfig::Ptr max_infer_running_config_; + int attack_mode_; + int debug_fps_; + bool detect_color_; + double bullet_speed_; + wust_vl::common::utils::Parameter common_config_parameter_; + std::unique_ptr thread_pool_; + std::map modules_; + std::shared_ptr camera_; + std::shared_ptr serial_; + std::unique_ptr timer_; + std::shared_ptr> motion_buffer_; + std::shared_ptr> rotate_writer_; + RotateReaderCSV::RotateReaderCSVPtr rotate_reader_; + std::shared_ptr> img_writer_; + std::thread debug_thread_; + GimbalCmd last_cmd_; + std::pair camera_info_; + bool run_flag_ = false; + bool debug_mode_ = false; + std::atomic infer_running_count_ { 0 }; + std::string common_config_; + std::string camera_config_; + std::string auto_aim_config_; + std::string auto_buff_config_; +}; +} // namespace wust_vision \ No newline at end of file diff --git a/wust_vision-main/templates/index.html b/wust_vision-main/templates/index.html new file mode 100644 index 0000000..dfc60b8 --- /dev/null +++ b/wust_vision-main/templates/index.html @@ -0,0 +1,302 @@ + + + + + + 崇实战队自瞄系统网页调试器 + + + + + + +

+ +
+
+ +
+
+ 视频流 + +
+
+ + +
+
+

Target_info

+
+
+
+

Serial_info

+
+
+
+ + +
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ 主图范围设置 +
+ +
+
+ +
+ + +
+
+
+
+
+ + + + + + diff --git a/wust_vision-main/test/nav.cpp b/wust_vision-main/test/nav.cpp new file mode 100644 index 0000000..36f5d69 --- /dev/null +++ b/wust_vision-main/test/nav.cpp @@ -0,0 +1,67 @@ +#include "geometry_msgs/msg/twist.hpp" +#include "ros2/ros2.hpp" +#include "tasks/packet_typedef.hpp" +#include +#include +#include +#include +using namespace wust_vision; +class Nav: public rclcpp::Node { +public: + Nav(): Node("wust_vision_global_node") { + cmd_sub_ = this->create_subscription( + "cmd_vel", + 10, + std::bind(&Nav::twistCb, this, std::placeholders::_1) + ); + serial_ = std::make_shared(); + wust_vl::common::drivers::SerialDriver::SerialPortConfig cfg { + /*baud*/ 115200, + /*csize*/ 8, + boost::asio::serial_port_base::parity::none, + boost::asio::serial_port_base::stop_bits::one, + boost::asio::serial_port_base::flow_control::none + }; + serial_->init_port("/dev/ttyACM0", cfg); + serial_->set_receive_callback(std::bind( + &Nav::serialCallback, + this, + std::placeholders::_1, + std::placeholders::_2 + + )); + serial_->set_error_callback([&](const boost::system::error_code& ec) { + WUST_ERROR("serial") << "serial error: " << ec.message(); + }); + } + void serialCallback(const uint8_t* data, std::size_t len) {} + void twistCb(const geometry_msgs::msg::Twist::SharedPtr msg) { + NavRobotCmdData send_data; + send_data.cmd_ID = ID_NAV_CMD; + send_data.packet_type = ID_NAV_CONTROL; + send_data.time_stamp = + static_cast(std::chrono::duration_cast( + wust_vl::common::utils::time_utils::now().time_since_epoch() + ) + .count()); + send_data.vx = -msg->linear.y; + send_data.vy = msg->linear.x; + send_data.wz = msg->angular.z; + if (serial_) { + serial_->write(std::move(wust_vl::common::drivers::toVector(send_data))); + } + } + rclcpp::Subscription::SharedPtr cmd_sub_; + std::shared_ptr serial_; +}; +int main(int argc, char** argv) { + rclcpp::init(argc, argv); + + { + auto node = rclcpp::Node::make_shared("wust_vision_global_node"); + rclcpp::spin(node); + // rclcpp::shutdown(); + } + + return 0; +} diff --git a/wust_vision-main/test/test_usbcamera.cpp b/wust_vision-main/test/test_usbcamera.cpp new file mode 100644 index 0000000..fc80aa0 --- /dev/null +++ b/wust_vision-main/test/test_usbcamera.cpp @@ -0,0 +1,19 @@ + +#include "wust_vl/video/uvc.hpp" +#include +#include +int main() { + wust_vl::video::UVC uvc; + auto config = YAML::LoadFile("/home/hy/wust_vision/config/camera.yaml"); + uvc.loadConfig(config["uvc"]); + uvc.start(); + while (true) { + auto frame = uvc.readImage(); + if (!frame.src_img.empty()) { + cv::imshow("frame", frame.src_img); + cv::waitKey(1); + } + } + uvc.stop(); + return 0; +} diff --git a/wust_vision-main/user_bashrc_copy.bash b/wust_vision-main/user_bashrc_copy.bash new file mode 100644 index 0000000..0a9fbda --- /dev/null +++ b/wust_vision-main/user_bashrc_copy.bash @@ -0,0 +1,259 @@ +# don't put duplicate lines or lines starting with space in the history. +# See bash(1) for more options +HISTCONTROL=ignoreboth + +# append to the history file, don't overwrite it +shopt -s histappend + +# for setting history length see HISTSIZE and HISTFILESIZE in bash(1) +HISTSIZE=1000 +HISTFILESIZE=2000 + +# check the window size after each command and, if necessary, +# update the values of LINES and COLUMNS. +shopt -s checkwinsize + +# If set, the pattern "**" used in a pathname expansion context will +# match all files and zero or more directories and subdirectories. +#shopt -s globstar + +# make less more friendly for non-text input files, see lesspipe(1) +[ -x /usr/bin/lesspipe ] && eval "$(SHELL=/bin/sh lesspipe)" + +# set variable identifying the chroot you work in (used in the prompt below) +if [ -z "${debian_chroot:-}" ] && [ -r /etc/debian_chroot ]; then + debian_chroot=$(cat /etc/debian_chroot) +fi + +# set a fancy prompt (non-color, unless we know we "want" color) +case "$TERM" in + xterm-color|*-256color) color_prompt=yes;; +esac + +# uncomment for a colored prompt, if the terminal has the capability; turned +# off by default to not distract the user: the focus in a terminal window +# should be on the output of commands, not on the prompt +#force_color_prompt=yes + +if [ -n "$force_color_prompt" ]; then + if [ -x /usr/bin/tput ] && tput setaf 1 >&/dev/null; then + # We have color support; assume it's compliant with Ecma-48 + # (ISO/IEC-6429). (Lack of such support is extremely rare, and such + # a case would tend to support setf rather than setaf.) + color_prompt=yes + else + color_prompt= + fi +fi + +if [ "$color_prompt" = yes ]; then + PS1='${debian_chroot:+($debian_chroot)}\[\033[01;32m\]\u@\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ ' +else + PS1='${debian_chroot:+($debian_chroot)}\u@\h:\w\$ ' +fi +unset color_prompt force_color_prompt + +# If this is an xterm set the title to user@host:dir +case "$TERM" in +xterm*|rxvt*) + PS1="\[\e]0;${debian_chroot:+($debian_chroot)}\u@\h: \w\a\]$PS1" + ;; +*) + ;; +esac + +# enable color support of ls and also add handy aliases +if [ -x /usr/bin/dircolors ]; then + test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)" + alias ls='ls --color=auto' + #alias dir='dir --color=auto' + #alias vdir='vdir --color=auto' + + alias grep='grep --color=auto' + alias fgrep='fgrep --color=auto' + alias egrep='egrep --color=auto' +fi + +# colored GCC warnings and errors +#export GCC_COLORS='error=01;31:warning=01;35:note=01;36:caret=01;32:locus=01:quote=01' + +# some more ls aliases +alias ll='ls -alF' +alias la='ls -A' +alias l='ls -CF' + +# Add an "alert" alias for long running commands. Use like so: +# sleep 10; alert +alias alert='notify-send --urgency=low -i "$([ $? = 0 ] && echo terminal || echo error)" "$(history|tail -n1|sed -e '\''s/^\s*[0-9]\+\s*//;s/[;&|]\s*alert$//'\'')"' + +# Alias definitions. +# You may want to put all your additions into a separate file like +# ~/.bash_aliases, instead of adding them here directly. +# See /usr/share/doc/bash-doc/examples in the bash-doc package. + +if [ -f ~/.bash_aliases ]; then + . ~/.bash_aliases +fi + +# enable programmable completion features (you don't need to enable +# this, if it's already enabled in /etc/bash.bashrc and /etc/profile +# sources /etc/bash.bashrc). +if ! shopt -oq posix; then + if [ -f /usr/share/bash-completion/bash_completion ]; then + . /usr/share/bash-completion/bash_completion + elif [ -f /etc/bash_completion ]; then + . /etc/bash_completion + fi +fi + + + + + + + + +export PATH="/usr/local/opt/llvm/bin:$PATH" + +export LD_LIBRARY_PATH=/usr/lib +#export RMW_IMPLEMENTATION=rmw_cyclonedds_cpp + + + + + + + + + + + +#export MVCAM_SDK_PATH=/opt/MVS + +#export MVCAM_COMMON_RUNENV=/opt/MVS/lib + +#export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol + +#export ALLUSERSPROFILE=/opt/MVS/MVFG +#export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH + + + + + + + +export PATH=/usr/local/cuda-12.6/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.6/lib64:$LD_LIBRARY_PATH + + +export CUDA_HOME=/usr/local/cuda-12.6 +#export CUDNN_HOME=/usr/local/cuda-11.8 +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH + +#export ROS_DOMAIN_ID=12 + + +#export PATH=$PATH:/usr/local/cuda-11.8/bin +#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.8/lib64 + + + + + + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/hy/TensorRT-10.6.0.26/lib +export PATH=/home/hy/TensorRT-10.6.0.26/bin:$PATH +export RMW_IMPLEMENTATION=rmw_cyclonedds_cpp +export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + + + +#source /home/hy/acado/build/acado_env.sh +export PATH=/home/hy/arm-gnu-toolchain-14.2.rel1-x86_64-arm-none-eabi/bin:$PATH + +# >>> xmake >>> +test -f "/home/hy/.xmake/profile" && source "/home/hy/.xmake/profile" +# <<< xmake <<< + +# >>> fishros initialize >>> +source /opt/ros/humble/setup.bash +# <<< fishros initialize <<< + +#source /home/hy/moveit2_ws/install/setup.bash +export PATH="$PATH:$HOME/.local/bin" + +export NVM_DIR="$HOME/.nvm" +[ -s "$NVM_DIR/nvm.sh" ] && \. "$NVM_DIR/nvm.sh" # This loads nvm +[ -s "$NVM_DIR/bash_completion" ] && \. "$NVM_DIR/bash_completion" # This loads nvm bash_completion + + + + + + + + + +export ROS_DOMAIN_ID=42 +#alias code='/usr/share/code/code --no-sandbox' + + +# >>> wust_vision dev >>> +py() { + source ~/anaconda3/etc/profile.d/conda.sh + echo "已激活 conda, conda activate激活环境,conda deactivate退出环境, conda info --envs查看环境" +} + +buildme() { + colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=Release +} + +builddebug() { + colcon build --packages-select "$1" --cmake-args -DCMAKE_BUILD_TYPE=Debug --symlink-install +} + +killros() { + pkill -f ros && ros2 daemon stop && ros2 daemon start +} + +format() { + find . -path ./build -prune -o \ + -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.c' -o -name '*.cu' -o -name '*.cpp' \) \ + -exec clang-format -i {} + +} + +s() { + source install/setup.bash +} + +hik() { + export MVCAM_SDK_PATH=/opt/MVS + export MVCAM_COMMON_RUNENV=/opt/MVS/lib + export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol + export ALLUSERSPROFILE=/opt/MVS/MVFG + export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH +} +# <<< wust_vision dev <<< +. "$HOME/.cargo/env" +. "$HOME/.cargo/env" +export PATH="${HOME}/.bin:${PATH}" + + + +export PATH=/usr/local/go/bin:$PATH +export RMW_IMPLEMENTATION=rmw_cyclonedds_cpp +export LD_LIBRARY_PATH=/home/hy/TensorRT-10.6.0.26/lib:$LD_LIBRARY_PATH +export CUDA_HOME=/usr/local/cuda-12.6 +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/home/hy/onnxruntime-linux-x64-gpu-1.22.0/lib:$LD_LIBRARY_PATH + +#export LUCKFOX_SDK_PATH=/home/hy/luckfox-pico +#export GLIBC_COMPILER=/usr/bin/arm-linux-gnueabihf- +export JAVA_HOME=/usr/lib/jvm/java-21-openjdk-amd64 +export PATH=$JAVA_HOME/bin:$PATH +source ~/acado/build/acado_env.sh + +source /usr/local/share/sentry_interfase/local_setup.bash diff --git a/wust_vision-main/web.py b/wust_vision-main/web.py new file mode 100644 index 0000000..213643d --- /dev/null +++ b/wust_vision-main/web.py @@ -0,0 +1,263 @@ +from flask import Flask, render_template, Response, jsonify +import time, json, socket, os, logging, struct +import mmap +import threading +import subprocess +import atexit +import fcntl +import setproctitle +setproctitle.setproctitle("wust_vision_web") + +app = Flask(__name__) + +# =============================== +# 参数设置:选择模式 +# =============================== +# True -> 强制共享内存模式 +# False -> 文件模式 +USE_SHARED_MEMORY_MODE = True +STREAM_FPS = 60 +FRAME_INTERVAL = 1.0 / STREAM_FPS +# 通信参数 +shared_memory_path = "/dev/shm/debug_frame" +shared_size = 2 * 1024 * 1024 # 2MB +shared_frame_path = "/dev/shm/debug_frame.jpg" + +# 初始化通信模式 +use_shared_memory = False +mapfile = None +fd = None + +# 权限修复锁 +permission_lock = threading.Lock() + +port = 8000 +def ensure_shared_memory_permissions(): + """确保共享内存文件存在且权限正确""" + with permission_lock: + try: + if not os.path.exists(shared_memory_path): + print(f"创建共享内存文件: {shared_memory_path}") + with open(shared_memory_path, "wb") as f: + f.write(b"\0" * shared_size) + + current_mode = oct(os.stat(shared_memory_path).st_mode & 0o777) + if current_mode != "0o777": + print(f"修复权限 (当前: {current_mode} -> 目标: 777)") + result = subprocess.run( + ["sudo", "chmod", "777", shared_memory_path], + capture_output=True, + text=True, + ) + if result.returncode == 0: + print("权限修复成功") + return True + else: + print(f"权限修复失败: {result.stderr.strip()}") + return False + return True + except Exception as e: + print(f"权限修复异常: {str(e)}") + return False + + +def init_shared_memory(): + """初始化共享内存连接""" + global use_shared_memory, mapfile, fd + + if not ensure_shared_memory_permissions(): + print("[WARN] 权限修复失败") + use_shared_memory = False + return False + + try: + fd = os.open(shared_memory_path, os.O_RDONLY) + mapfile = mmap.mmap(fd, shared_size, mmap.MAP_SHARED, mmap.PROT_READ) + fcntl.flock(fd, fcntl.LOCK_SH | fcntl.LOCK_NB) + use_shared_memory = True + print("[INFO] 共享内存初始化成功") + return True + except Exception as e: + print(f"[WARN] 共享内存初始化失败: {e}") + if mapfile: + try: + mapfile.close() + except: + pass + mapfile = None + if fd: + try: + os.close(fd) + except: + pass + fd = None + use_shared_memory = False + return False + + +# =============================== +# 初始化模式 +# =============================== +if USE_SHARED_MEMORY_MODE: + if init_shared_memory(): + print("✅ 使用共享内存模式") + else: + print("⚠️ 强制共享内存模式失败,回退到文件模式") + use_shared_memory = False +else: + use_shared_memory = False + print("ℹ️ 使用文件模式") + + +# =============================== +# 清理函数 +# =============================== +@atexit.register +def cleanup(): + if mapfile: + try: + mapfile.close() + except: + pass + if fd: + try: + os.close(fd) + except: + pass + + +# =============================== +# MJPEG 流生成器 +# =============================== +def mjpeg_stream(): + global use_shared_memory, mapfile + last_fix_attempt = 0 + + while True: + try: + if use_shared_memory and mapfile: + try: + mapfile.seek(0) + size_bytes = mapfile.read(4) + if len(size_bytes) < 4: + time.sleep(FRAME_INTERVAL) + continue + jpg_size = struct.unpack("I", size_bytes)[0] + if jpg_size <= 0 or jpg_size > shared_size - 4: + time.sleep(FRAME_INTERVAL) + continue + jpg_bytes = mapfile.read(jpg_size) + if len(jpg_bytes) != jpg_size: + time.sleep(FRAME_INTERVAL) + continue + if jpg_bytes[0:3] != b"\xff\xd8\xff": + time.sleep(FRAME_INTERVAL) + continue + except (OSError, ValueError) as e: + current_time = time.time() + if current_time - last_fix_attempt > 60: + print("尝试重新初始化共享内存...") + if init_shared_memory(): + continue + last_fix_attempt = current_time + use_shared_memory = False + continue + + if not use_shared_memory or not mapfile: + try: + with open(shared_frame_path, "rb") as f: + jpg_bytes = f.read() + if jpg_bytes[0:3] != b"\xff\xd8\xff": + time.sleep(FRAME_INTERVAL) + continue + except FileNotFoundError: + time.sleep(0.1) + continue + except Exception: + time.sleep(0.1) + continue + + yield b"--frame\r\n" b"Content-Type: image/jpeg\r\n\r\n" + jpg_bytes + b"\r\n" + time.sleep(FRAME_INTERVAL) + except Exception: + time.sleep(0.5) + + +# =============================== +# Flask 路由 +# =============================== +@app.route("/") +def index(): + def get_local_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("10.255.255.255", 1)) + IP = s.getsockname()[0] + except: + IP = "127.0.0.1" + finally: + s.close() + return IP + + url = f"http://{get_local_ip()}:{port}" + return render_template("index.html", server_url=url) + + +@app.route("/video") +def video_feed(): + return Response( + mjpeg_stream(), mimetype="multipart/x-mixed-replace; boundary=frame" + ) + + +@app.route("/data") +def get_data(): + try: + with open("/dev/shm/cmd_log.json", "r") as f: + return jsonify(json.load(f)) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/serial_log") +def serial_log(): + try: + with open("/dev/shm/serial_log.json", "r") as f: + return jsonify(json.load(f)) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/target_log") +def target_log(): + try: + with open("/dev/shm/target_log.json", "r") as f: + return jsonify(json.load(f)) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +# =============================== +# 主函数 +# =============================== +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + logging.getLogger("werkzeug").setLevel(logging.ERROR) + + def get_local_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("10.255.255.255", 1)) + IP = s.getsockname()[0] + except: + IP = "127.0.0.1" + finally: + s.close() + return IP + + url = f"http://{get_local_ip()}:{port}" + print(f"✅ Web 调试器已启动: {url}") + print(f" - 共享内存模式: {'是' if use_shared_memory else '否'}") + print(f" - 访问地址: {url}") + + app.run(host="0.0.0.0", port=port, threaded=True)