add wust typr mpc and mutipule x

This commit is contained in:
cyy_mac
2026-03-27 03:41:42 +08:00
parent 2c64655fae
commit 7dcb53bb77
192 changed files with 29571 additions and 9 deletions

View File

@@ -32,6 +32,7 @@
#include "rm_interfaces/msg/target.hpp"
#include "rm_utils/math/manual_compensator.hpp"
#include "rm_utils/math/trajectory_compensator.hpp"
#include "armor_solver/trajectory_planner.hpp"
namespace fyt::auto_aim {
// Solver class used to solve the gimbal command from tracked target
@@ -53,6 +54,7 @@ public:
const std::vector<Eigen::Vector3d>& getArmorPositions() const noexcept {
return cached_armor_positions_;
}
TrajectoryDebug getTrajectoryDebug() const noexcept;
void setBulletSpeed(double bullet_speed) noexcept;
void updateRuntimeParams(double max_tracking_v_yaw,
double prediction_delay,
@@ -61,6 +63,7 @@ public:
double min_switching_v_yaw,
double shooting_range_w,
double shooting_range_h) noexcept;
void setTrajectoryType(const std::string& type, std::weak_ptr<rclcpp::Node> node = {});
private:
// Get the armor positions from the target robot
@@ -97,6 +100,8 @@ private:
std::unique_ptr<TrajectoryCompensator> trajectory_compensator_;
std::unique_ptr<ManualCompensator> manual_compensator_;
std::unique_ptr<TrajectoryPlanner> planner_;
TrajectoryPlanner::Type planner_type_ = TrajectoryPlanner::Type::LINEAR;
std::array<double, 3> rpy_;

View File

@@ -30,6 +30,7 @@
#include <rcl_interfaces/msg/set_parameters_result.hpp>
#include <rclcpp/rclcpp.hpp>
#include <tf2_geometry_msgs/tf2_geometry_msgs.hpp>
#include <visualization_msgs/msg/marker.hpp>
#include <visualization_msgs/msg/marker_array.hpp>
// std
#include <atomic>
@@ -64,6 +65,7 @@ private:
void publishMarkers(const rm_interfaces::msg::Target &target_msg,
const rm_interfaces::msg::GimbalCmd &gimbal_cmd) noexcept;
void publishTrajectoryDebug() noexcept;
void setModeCallback(const std::shared_ptr<rm_interfaces::srv::SetMode::Request> request,
@@ -107,6 +109,7 @@ private:
// Publisher
rclcpp::Publisher<rm_interfaces::msg::Target>::SharedPtr target_pub_;
rclcpp::Publisher<rm_interfaces::msg::GimbalCmd>::SharedPtr gimbal_pub_;
rclcpp::Publisher<visualization_msgs::msg::Marker>::SharedPtr traj_debug_pub_;
rclcpp::Subscription<rm_interfaces::msg::SerialReceiveData>::SharedPtr serial_sub_;
rclcpp::TimerBase::SharedPtr pub_timer_;
void timerCallback();

View File

@@ -0,0 +1,426 @@
// Trajectory Planner for armor_solver
// Implements quintic polynomial trajectory planning and TinyMPC-based MPC
#ifndef ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_
#define ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_
#include <memory>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <Eigen/Dense>
#include <angles/angles.h>
namespace fyt::auto_aim {
// Forward declare for use in planner
struct TargetInfo {
Eigen::Vector3d position;
Eigen::Vector3d velocity;
double yaw;
double v_yaw;
double radius_1;
double radius_2;
double d_za;
double d_zc;
size_t armors_num;
// Default constructor
TargetInfo() = default;
// Construct from rm_interfaces Target msg
explicit TargetInfo(const rm_interfaces::msg::Target& target_msg)
: position(target_msg.position.x, target_msg.position.y, target_msg.position.z),
velocity(target_msg.velocity.x, target_msg.velocity.y, target_msg.velocity.z),
yaw(target_msg.yaw),
v_yaw(target_msg.v_yaw),
radius_1(target_msg.radius_1),
radius_2(target_msg.radius_2),
d_za(target_msg.d_za),
d_zc(target_msg.d_zc),
armors_num(static_cast<size_t>(target_msg.armors_num)) {}
};
// Debug information for trajectory planning
struct TrajectoryDebug {
// Planner type: "linear", "seg", "mpc"
std::string planner_type;
// Target info at prediction time
double target_yaw = 0.0;
double target_pitch = 0.0;
double target_distance = 0.0;
// Planned gimbal state
double planned_yaw = 0.0;
double planned_pitch = 0.0;
double planned_yaw_v = 0.0;
double planned_pitch_v = 0.0;
double planned_yaw_a = 0.0;
double planned_pitch_a = 0.0;
// Time parameters
double flying_time = 0.0;
double total_dt = 0.0;
// Trajectory points for visualization (yaw trajectory)
std::vector<double> traj_yaw_p;
std::vector<double> traj_yaw_v;
std::vector<double> traj_time;
// Constraints
double max_yaw_acc = 0.0;
double max_pitch_acc = 0.0;
// MPC specific
double mpc_cost = 0.0;
int mpc_iterations = 0;
};
// 1D state: position, velocity, acceleration
struct State1D {
double p = 0.0;
double v = 0.0;
double a = 0.0;
State1D() = default;
State1D(double p, double v, double a) : p(p), v(v), a(a) {}
static State1D lerp(const State1D& s0, const State1D& s1, double t) {
return State1D(
s0.p + t * (s1.p - s0.p),
s0.v + t * (s1.v - s0.v),
s0.a + t * (s1.a - s0.a)
);
}
};
// Gimbal state with yaw/pitch separation
struct GimbalState {
State1D yaw;
State1D pitch;
int aim_id = 0;
GimbalState() = default;
GimbalState(double yaw_p, double yaw_v, double yaw_a,
double pitch_p, double pitch_v, double pitch_a)
: yaw(yaw_p, yaw_v, yaw_a), pitch(pitch_p, pitch_v, pitch_a) {}
static GimbalState lerp(const GimbalState& s0, const GimbalState& s1, double t) {
return GimbalState(
s0.yaw.p + t * (s1.yaw.p - s0.yaw.p),
s0.yaw.v + t * (s1.yaw.v - s0.yaw.v),
s0.yaw.a + t * (s1.yaw.a - s0.yaw.a),
s0.pitch.p + t * (s1.pitch.p - s0.pitch.p),
s0.pitch.v + t * (s1.pitch.v - s0.pitch.v),
s0.pitch.a + t * (s1.pitch.a - s0.pitch.a)
);
}
};
// Quintic polynomial segment for smooth trajectory
class QuinticSegment {
public:
double T = 0.0; // Duration
Eigen::Matrix<double, 6, 1> c; // Coefficients: c0 + c1*t + c2*t^2 + c3*t^3 + c4*t^4 + c5*t^5
QuinticSegment() = default;
explicit QuinticSegment(double duration) : T(duration) {}
// Build quintic segment from boundary conditions (closed-form solution)
static QuinticSegment build(const State1D& s0, const State1D& s1, double T) {
using Matrix6d = Eigen::Matrix<double, 6, 6>;
double T2 = T * T;
double T3 = T2 * T;
double T4 = T3 * T;
double T5 = T4 * T;
Matrix6d A;
A << 1, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0,
0, 0, 2, 0, 0, 0,
1, T, T2, T3, T4, T5,
0, 1, 2*T, 3*T2, 4*T3, 5*T4,
0, 0, 2, 6*T, 12*T2, 20*T3;
Eigen::Matrix<double, 6, 1> b;
b << s0.p, s0.v, s0.a, s1.p, s1.v, s1.a;
QuinticSegment seg(T);
seg.c = A.fullPivLu().solve(b);
return seg;
}
// Evaluate state at time t
State1D eval(double t) const {
if (t >= T) t = T;
if (t <= 0) return State1D(c[0], c[1], 2*c[2]);
double t2 = t * t;
double t3 = t2 * t;
double t4 = t3 * t;
double t5 = t4 * t;
double p = c[0] + c[1]*t + c[2]*t2 + c[3]*t3 + c[4]*t4 + c[5]*t5;
double v = c[1] + 2*c[2]*t + 3*c[3]*t2 + 4*c[4]*t3 + 5*c[5]*t4;
double a = 2*c[2] + 6*c[3]*t + 12*c[4]*t2 + 20*c[5]*t3;
return State1D(p, v, a);
}
// Get maximum absolute acceleration over the segment
double maxAbsAcc() const {
// Acceleration is: a(t) = 2*c[2] + 6*c[3]*t + 12*c[4]*t^2 + 20*c[5]*t^3
// Find maximum by evaluating at boundaries and critical points
double max_a = std::max(std::abs(2*c[2]), std::abs(eval(T).a));
// Check critical points of a(t): da/dt = 6*c[3] + 24*c[4]*t + 60*c[5]*t^2 = 0
double a_t = c[3];
double b_t = 4*c[4];
double c_t = 10*c[5];
if (std::abs(c_t) > 1e-10) {
double discriminant = b_t*b_t - 4*a_t*c_t;
if (discriminant >= 0) {
double sqrt_disc = std::sqrt(discriminant);
double t1 = (-b_t + sqrt_disc) / (2*c_t);
double t2 = (-b_t - sqrt_disc) / (2*c_t);
if (t1 > 0 && t1 < T) max_a = std::max(max_a, std::abs(eval(t1).a));
if (t2 > 0 && t2 < T) max_a = std::max(max_a, std::abs(eval(t2).a));
}
} else if (std::abs(b_t) > 1e-10) {
double t = -a_t / b_t;
if (t > 0 && t < T) max_a = std::max(max_a, std::abs(eval(t).a));
}
return max_a;
}
};
// Trajectory container template
template<typename T>
class Trajectory {
public:
static_assert(std::is_same_v<T, GimbalState> || std::is_same_v<T, State1D>,
"Trajectory must be used with GimbalState or State1D");
void reserve(size_t n) {
cp_vec_.reserve(n);
dt_vec_.reserve(n > 0 ? n - 1 : 0);
prefix_time_.reserve(n);
}
void clear() {
cp_vec_.clear();
dt_vec_.clear();
prefix_time_.clear();
total_duration_ = 0.0;
}
void push_back(const T& p, double dt = 0.0) {
if (cp_vec_.empty()) {
cp_vec_.push_back(p);
prefix_time_.push_back(0.0);
total_duration_ = 0.0;
return;
}
assert(dt >= 0.0);
cp_vec_.push_back(p);
dt_vec_.push_back(dt);
total_duration_ += dt;
prefix_time_.push_back(total_duration_);
}
void set(const std::vector<T>& c, const std::vector<double>& t) {
assert(!c.empty());
assert(c.size() == t.size() + 1);
cp_vec_ = c;
dt_vec_ = t;
prefix_time_.resize(cp_vec_.size());
prefix_time_[0] = 0.0;
for (size_t i = 0; i < dt_vec_.size(); ++i)
prefix_time_[i + 1] = prefix_time_[i] + dt_vec_[i];
total_duration_ = prefix_time_.back();
}
T getStateAtTime(double t) const {
if (cp_vec_.empty())
return T {};
if (t <= 0.0)
return cp_vec_.front();
if (t >= total_duration_)
return cp_vec_.back();
auto it = std::lower_bound(prefix_time_.begin(), prefix_time_.end(), t);
size_t i1 = std::distance(prefix_time_.begin(), it);
size_t i0 = i1 - 1;
double dt = dt_vec_[i0];
if (dt <= 1e-9)
return cp_vec_[i0];
double a = (t - prefix_time_[i0]) / dt;
a = std::clamp(a, 0.0, 1.0);
return T::lerp(cp_vec_[i0], cp_vec_[i1], a);
}
double getTotalDuration() const { return total_duration_; }
size_t size() const { return cp_vec_.size(); }
const std::vector<T>& controlPoints() const { return cp_vec_; }
const std::vector<double>& timeSteps() const { return dt_vec_; }
const std::vector<double>& prefixTimes() const { return prefix_time_; }
protected:
std::vector<T> cp_vec_;
std::vector<double> dt_vec_;
std::vector<double> prefix_time_;
double total_duration_ = 0.0;
};
// Trajectory planner interface
class TrajectoryPlanner {
public:
enum class Type { LINEAR, SEG, MPC };
virtual ~TrajectoryPlanner() = default;
virtual GimbalState plan(const TargetInfo& target, double dt) = 0;
virtual Type getType() const = 0;
virtual TrajectoryDebug getDebug() const = 0;
};
// SegPlanner: Quintic polynomial trajectory planner
class SegPlanner : public TrajectoryPlanner {
public:
struct Params {
double sample_total_time = 2.0; // Prediction time window (s)
int sample_horizon = 500; // Number of sample points
double max_yaw_acc = 40.0; // Max yaw acceleration (deg/s^2)
double max_pitch_acc = 25.0; // Max pitch acceleration (deg/s^2)
};
explicit SegPlanner(const Params& params) : params_(params) {}
~SegPlanner() override = default;
GimbalState plan(const TargetInfo& target, double dt) override;
Type getType() const override { return Type::SEG; }
TrajectoryDebug getDebug() const override { return debug_; }
void setParams(const Params& params) { params_ = params; }
const Params& getParams() const { return params_; }
private:
// Predict target state at time t
GimbalState predictTarget(const TargetInfo& target, double t) const;
// Unwrap angle discontinuities
void unwrapAngles(std::vector<GimbalState>& states) const;
// Compute velocity and acceleration at control points
std::pair<std::vector<State1D>, std::vector<State1D>>
computeNodeStates(const std::vector<GimbalState>& states,
const std::vector<double>& dt_vec) const;
// Build limited quintic segments
Trajectory<QuinticSegment>
buildLimit(const std::vector<State1D>& yaw_nodes,
const std::vector<State1D>& pitch_nodes,
double max_yaw_acc,
double max_pitch_acc) const;
// Convert gimbal angle to target angle (accounting for armor offset)
static std::pair<double, int>
computeArmorAngle(const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_center,
double target_yaw,
size_t armors_num,
double radius_1,
double radius_2,
double d_zc,
double d_za);
Params params_;
Trajectory<GimbalState> trajectory_;
TrajectoryDebug debug_;
};
// MpcPlanner: Simplified MPC-based trajectory planner
// Uses feedback control with acceleration constraints
class MpcPlanner : public TrajectoryPlanner {
public:
struct Params {
double sample_total_time = 2.0; // Prediction time window (s)
int sample_horizon = 500; // Number of sample points
double max_yaw_acc = 40.0; // Max yaw acceleration (deg/s^2)
double max_pitch_acc = 25.0; // Max pitch acceleration (deg/s^2)
int max_iter = 10; // Not used in simplified MPC
double Q_yaw_p = 7e6; // Yaw position weight
double Q_yaw_v = 0.0; // Yaw velocity weight
double R_yaw = 3.0; // Yaw control weight
double rho = 1.0; // Not used in simplified MPC
};
explicit MpcPlanner(const Params& params);
~MpcPlanner() override;
GimbalState plan(const TargetInfo& target, double dt) override;
Type getType() const override { return Type::MPC; }
TrajectoryDebug getDebug() const override { return debug_; }
void setParams(const Params& params);
const Params& getParams() const { return params_; }
private:
// Initialize MPC solver
void initSolver();
// Setup problem matrices
void setupProblem();
// Solve MPC and get trajectory
void solveMpc(const std::vector<GimbalState>& ref_traj);
// Get state at specific time from MPC solution
GimbalState getStateAtTime(double t, const std::vector<GimbalState>& ref_traj) const;
Params params_;
int N_ = 50;
// State and input matrices
Eigen::MatrixXd x_; // State trajectory (2 x N)
Eigen::MatrixXd u_; // Control trajectory (1 x N-1)
Eigen::MatrixXd x_ref_; // Reference trajectory (2 x N)
Eigen::MatrixXd u_ref_; // Reference control (1 x N-1)
// System dynamics
Eigen::MatrixXd Adyn_; // State transition matrix (2 x 2)
Eigen::MatrixXd Bdyn_; // Input matrix (2 x 1)
// Constraints
Eigen::VectorXd u_min_; // Control bounds
Eigen::VectorXd u_max_;
// Cost weights
Eigen::Vector2d Q_; // State cost weights
Eigen::VectorXd R_; // Control cost weights
// Solution storage
std::vector<State1D> mpc_solution_yaw_;
// Debug info
TrajectoryDebug debug_;
};
} // namespace fyt::auto_aim
#endif // ARMOR_SOLVER_TRAJECTORY_PLANNER_HPP_

View File

@@ -53,6 +53,10 @@ Solver::Solver(std::weak_ptr<rclcpp::Node> n) : node_(n) {
FYT_WARN("armor_solver", "Manual compensator update failed!");
}
// Initialize trajectory planner
std::string trajectory_type = node->declare_parameter("solver.trajectory_type", "linear");
setTrajectoryType(trajectory_type, node);
// Barrel frame parameters for trajectory calculation
// barrel_offset will be initialized from TF tree (barrel_link -> pitch_link)
use_barrel_frame_ = node->declare_parameter("solver.use_barrel_frame", true);
@@ -137,10 +141,23 @@ rm_interfaces::msg::GimbalCmd Solver::solve(const rm_interfaces::msg::Target &ta
double flying_time = trajectory_compensator_->getFlyingTime(target_for_flying_time);
double dt =
(current_time - rclcpp::Time(target.header.stamp)).seconds() + flying_time + prediction_delay_;
target_position.x() += dt * target.velocity.x;
target_position.y() += dt * target.velocity.y;
target_position.z() += dt * target.velocity.z;
target_yaw += dt * target.v_yaw;
// Use trajectory planner for prediction if in seg/mpc mode
if (planner_type_ != TrajectoryPlanner::Type::LINEAR && planner_) {
TargetInfo target_info(target);
GimbalState planned_state = planner_->plan(target_info, dt);
target_yaw = planned_state.yaw.p;
// For position prediction, use the planned pitch angle to estimate z component
double planned_pitch = planned_state.pitch.p;
double horizontal_dist = target_position.head(2).norm();
target_position.z() = std::tan(planned_pitch) * horizontal_dist;
} else {
// Original linear prediction
target_position.x() += dt * target.velocity.x;
target_position.y() += dt * target.velocity.y;
target_position.z() += dt * target.velocity.z;
target_yaw += dt * target.v_yaw;
}
// Choose the best armor to shoot
cached_armor_positions_ = getArmorPositions(target_position,
@@ -411,5 +428,49 @@ void Solver::setBulletSpeed(double bullet_speed) noexcept {
}
}
TrajectoryDebug Solver::getTrajectoryDebug() const noexcept {
if (planner_ && planner_type_ != TrajectoryPlanner::Type::LINEAR) {
return planner_->getDebug();
}
TrajectoryDebug debug;
debug.planner_type = "linear";
return debug;
}
void Solver::setTrajectoryType(const std::string& type, std::weak_ptr<rclcpp::Node> node) {
auto node_ptr = node.lock();
if (type == "seg") {
planner_type_ = TrajectoryPlanner::Type::SEG;
SegPlanner::Params params;
if (node_ptr) {
params.sample_total_time = node_ptr->declare_parameter("solver.sample_total_time", 2.0);
params.sample_horizon = node_ptr->declare_parameter("solver.sample_horizon", 500);
params.max_yaw_acc = node_ptr->declare_parameter("solver.max_yaw_acc", 40.0);
params.max_pitch_acc = node_ptr->declare_parameter("solver.max_pitch_acc", 25.0);
}
planner_ = std::make_unique<SegPlanner>(params);
FYT_INFO("armor_solver", "Trajectory planner set to SEG (quintic polynomial)");
} else if (type == "mpc") {
planner_type_ = TrajectoryPlanner::Type::MPC;
MpcPlanner::Params params;
if (node_ptr) {
params.sample_total_time = node_ptr->declare_parameter("solver.sample_total_time", 2.0);
params.sample_horizon = node_ptr->declare_parameter("solver.sample_horizon", 500);
params.max_yaw_acc = node_ptr->declare_parameter("solver.max_yaw_acc", 40.0);
params.max_pitch_acc = node_ptr->declare_parameter("solver.max_pitch_acc", 25.0);
params.max_iter = node_ptr->declare_parameter("solver.max_iter", 10);
params.Q_yaw_p = node_ptr->declare_parameter("solver.Q_yaw", 7e6);
params.Q_yaw_v = node_ptr->declare_parameter("solver.Q_yaw_v", 0.0);
params.R_yaw = node_ptr->declare_parameter("solver.R_yaw", 3.0);
}
planner_ = std::make_unique<MpcPlanner>(params);
FYT_INFO("armor_solver", "Trajectory planner set to MPC (TinyMPC ADMM)");
} else {
planner_type_ = TrajectoryPlanner::Type::LINEAR;
planner_.reset();
FYT_INFO("armor_solver", "Trajectory planner set to LINEAR (default)");
}
}
} // namespace fyt::auto_aim

View File

@@ -19,7 +19,9 @@
#include "armor_solver/armor_solver_node.hpp"
// std
#include <cmath>
#include <memory>
#include <sstream>
#include <vector>
// project
#include "armor_solver/motion_model.hpp"
@@ -182,6 +184,8 @@ ArmorSolverNode::ArmorSolverNode(const rclcpp::NodeOptions &options)
rclcpp::SensorDataQoS());
gimbal_pub_ = this->create_publisher<rm_interfaces::msg::GimbalCmd>("armor_solver/cmd_gimbal",
rclcpp::SensorDataQoS());
traj_debug_pub_ = this->create_publisher<visualization_msgs::msg::Marker>("armor_solver/traj_debug",
rclcpp::SensorDataQoS());
serial_sub_ = this->create_subscription<rm_interfaces::msg::SerialReceiveData>(
"serial/receive",
rclcpp::SensorDataQoS(),
@@ -293,6 +297,7 @@ void ArmorSolverNode::timerCallback() {
if (debug_mode_) {
publishMarkers(armor_target_, control_msg);
publishTrajectoryDebug();
}
}
@@ -586,6 +591,76 @@ void ArmorSolverNode::publishMarkers(const rm_interfaces::msg::Target &target_ms
marker_pub_->publish(marker_array);
}
void ArmorSolverNode::publishTrajectoryDebug() noexcept {
if (!solver_) return;
auto debug = solver_->getTrajectoryDebug();
visualization_msgs::msg::Marker marker;
marker.header.stamp = this->now();
marker.header.frame_id = "odom";
marker.type = visualization_msgs::msg::Marker::LINE_STRIP;
marker.action = visualization_msgs::msg::Marker::ADD;
marker.ns = "trajectory_debug";
marker.id = 0;
marker.scale.x = 0.02; // line width
// Color based on planner type
if (debug.planner_type == "seg") {
marker.color.r = 0.0;
marker.color.g = 1.0;
marker.color.b = 0.0;
} else if (debug.planner_type == "mpc") {
marker.color.r = 1.0;
marker.color.g = 0.0;
marker.color.b = 1.0;
} else {
marker.color.r = 1.0;
marker.color.g = 1.0;
marker.color.b = 0.0;
}
marker.color.a = 1.0;
// Add trajectory points as points
for (size_t i = 0; i < debug.traj_time.size(); ++i) {
geometry_msgs::msg::Point p;
p.x = debug.traj_time[i]; // time on x-axis
p.y = debug.traj_yaw_p[i] * 180.0 / M_PI; // yaw in degrees on y-axis
p.z = 0.0;
marker.points.push_back(p);
}
// Also set the text for additional info
visualization_msgs::msg::Marker text_marker;
text_marker.header.stamp = this->now();
text_marker.header.frame_id = "odom";
text_marker.type = visualization_msgs::msg::Marker::TEXT_VIEW_FACING;
text_marker.action = visualization_msgs::msg::Marker::ADD;
text_marker.ns = "trajectory_debug_text";
text_marker.id = 0;
text_marker.pose.position.x = 0.0;
text_marker.pose.position.y = 0.0;
text_marker.pose.position.z = 0.5;
text_marker.scale.z = 0.1; // text height
std::ostringstream ss;
ss << "Planner: " << debug.planner_type << "\n";
ss << "Target Yaw: " << debug.target_yaw * 180.0 / M_PI << " deg\n";
ss << "Target Pitch: " << debug.target_pitch * 180.0 / M_PI << " deg\n";
ss << "Planned Yaw: " << debug.planned_yaw * 180.0 / M_PI << " deg\n";
ss << "Planned Pitch: " << debug.planned_pitch * 180.0 / M_PI << " deg\n";
ss << "Flying Time: " << debug.flying_time * 1000.0 << " ms\n";
ss << "Max Yaw Acc: " << debug.max_yaw_acc << " deg/s^2\n";
text_marker.text = ss.str();
text_marker.color.r = 1.0;
text_marker.color.g = 1.0;
text_marker.color.b = 1.0;
text_marker.color.a = 1.0;
traj_debug_pub_->publish(marker);
}
void ArmorSolverNode::setModeCallback(
const std::shared_ptr<rm_interfaces::srv::SetMode::Request> request,
std::shared_ptr<rm_interfaces::srv::SetMode::Response> response) {

View File

@@ -0,0 +1,441 @@
// Trajectory Planner Implementation
// SegPlanner: Quintic polynomial trajectory planning
// MpcPlanner: TinyMPC ADMM-based trajectory planning
#include "armor_solver/trajectory_planner.hpp"
#include "rm_utils/logger/log.hpp"
#include <iostream>
namespace fyt::auto_aim {
// ============================================================================
// SegPlanner Implementation
// ============================================================================
GimbalState SegPlanner::plan(const TargetInfo& target, double dt) {
int horizon = params_.sample_horizon;
double total_time = params_.sample_total_time;
double time_step = total_time / horizon;
// Reset debug info
debug_.planner_type = "seg";
debug_.traj_time.clear();
debug_.traj_yaw_p.clear();
debug_.traj_yaw_v.clear();
debug_.max_yaw_acc = params_.max_yaw_acc;
debug_.max_pitch_acc = params_.max_pitch_acc;
debug_.total_dt = dt;
// 1. Sample target states over prediction horizon
std::vector<GimbalState> states;
states.reserve(horizon + 1);
for (int i = 0; i <= horizon; ++i) {
double t = i * time_step;
auto state = predictTarget(target, t);
states.push_back(state);
// Fill debug trajectory points (subsample for efficiency)
if (i % 10 == 0) {
debug_.traj_time.push_back(t);
debug_.traj_yaw_p.push_back(state.yaw.p);
debug_.traj_yaw_v.push_back(state.yaw.v);
}
}
// 2. Unwrap angle discontinuities
unwrapAngles(states);
// 3. Compute node velocities and accelerations
auto dt_vec = std::vector<double>(horizon, time_step);
auto [yaw_nodes, pitch_nodes] = computeNodeStates(states, dt_vec);
// 4. Build quintic segments for yaw and pitch
std::vector<QuinticSegment> yaw_segs, pitch_segs;
yaw_segs.reserve(horizon);
pitch_segs.reserve(horizon);
for (size_t i = 0; i < horizon; ++i) {
// Build yaw segment
QuinticSegment yaw_seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], time_step);
// Check and scale for acceleration constraints
double max_a = yaw_seg.maxAbsAcc();
double max_acc_rad = params_.max_yaw_acc * M_PI / 180.0;
if (max_a > max_acc_rad && max_a > 0) {
double scale = std::sqrt(max_a / max_acc_rad);
yaw_seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], time_step * scale);
}
yaw_segs.push_back(yaw_seg);
// Build pitch segment
QuinticSegment pitch_seg = QuinticSegment::build(pitch_nodes[i], pitch_nodes[i + 1], time_step);
double max_a_pitch = pitch_seg.maxAbsAcc();
double max_acc_pitch_rad = params_.max_pitch_acc * M_PI / 180.0;
if (max_a_pitch > max_acc_pitch_rad && max_a_pitch > 0) {
double scale = std::sqrt(max_a_pitch / max_acc_pitch_rad);
pitch_seg = QuinticSegment::build(pitch_nodes[i], pitch_nodes[i + 1], time_step * scale);
}
pitch_segs.push_back(pitch_seg);
}
// 5. Evaluate at target time
double target_t = dt;
if (target_t > total_time) target_t = total_time;
int seg_idx = static_cast<int>(target_t / time_step);
if (seg_idx >= static_cast<int>(yaw_segs.size())) seg_idx = yaw_segs.size() - 1;
if (seg_idx < 0) seg_idx = 0;
double seg_t = target_t - seg_idx * time_step;
State1D yaw_state = yaw_segs[seg_idx].eval(seg_t);
State1D pitch_state = pitch_segs[seg_idx].eval(seg_t);
// Fill debug info
Eigen::Vector3d target_pos = target.position + dt * target.velocity;
debug_.target_yaw = target.yaw + dt * target.v_yaw;
debug_.target_pitch = std::atan2(target_pos.z(), target_pos.head(2).norm());
debug_.target_distance = target_pos.norm();
debug_.planned_yaw = yaw_state.p;
debug_.planned_pitch = pitch_state.p;
debug_.planned_yaw_v = yaw_state.v;
debug_.planned_pitch_v = pitch_state.v;
debug_.planned_yaw_a = yaw_state.a;
debug_.planned_pitch_a = pitch_state.a;
debug_.flying_time = dt;
return GimbalState(yaw_state.p, yaw_state.v, yaw_state.a,
pitch_state.p, pitch_state.v, pitch_state.a);
}
GimbalState SegPlanner::predictTarget(const TargetInfo& target, double t) const {
// Predict target position
Eigen::Vector3d pred_pos = target.position + t * target.velocity;
// Predict target yaw
double pred_yaw = target.yaw + t * target.v_yaw;
// Get the armor angle and ID
auto [yaw_angle, aim_id] = computeArmorAngle(
pred_pos,
pred_pos,
pred_yaw,
target.armors_num,
target.radius_1,
target.radius_2,
target.d_zc,
target.d_za
);
// Compute pitch from position
double pitch = std::atan2(pred_pos.z(), pred_pos.head(2).norm());
return GimbalState(yaw_angle, target.v_yaw, 0.0, pitch, 0.0, 0.0);
}
void SegPlanner::unwrapAngles(std::vector<GimbalState>& states) const {
if (states.size() <= 1) return;
for (size_t i = 1; i < states.size(); ++i) {
// Unwrap yaw
while (states[i].yaw.p - states[i-1].yaw.p > M_PI) {
states[i].yaw.p -= 2 * M_PI;
}
while (states[i].yaw.p - states[i-1].yaw.p < -M_PI) {
states[i].yaw.p += 2 * M_PI;
}
// Unwrap pitch (less common but still needed)
while (states[i].pitch.p - states[i-1].pitch.p > M_PI) {
states[i].pitch.p -= 2 * M_PI;
}
while (states[i].pitch.p - states[i-1].pitch.p < -M_PI) {
states[i].pitch.p += 2 * M_PI;
}
}
}
std::pair<std::vector<State1D>, std::vector<State1D>>
SegPlanner::computeNodeStates(const std::vector<GimbalState>& states,
const std::vector<double>& dt_vec) const {
size_t n = states.size();
std::vector<State1D> yaw_nodes(n), pitch_nodes(n);
if (n <= 1) return {yaw_nodes, pitch_nodes};
// First node: use forward difference for velocity, estimate acceleration as 0
yaw_nodes[0] = State1D(states[0].yaw.p, (states[1].yaw.p - states[0].yaw.p) / dt_vec[0], 0.0);
pitch_nodes[0] = State1D(states[0].pitch.p, (states[1].pitch.p - states[0].pitch.p) / dt_vec[0], 0.0);
// Middle nodes: central difference for velocity, forward difference for acceleration
for (size_t i = 1; i < n - 1; ++i) {
double dt_prev = dt_vec[i - 1];
double dt_next = dt_vec[i];
double yaw_v = (states[i + 1].yaw.p - states[i - 1].yaw.p) / (dt_prev + dt_next);
double pitch_v = (states[i + 1].pitch.p - states[i - 1].pitch.p) / (dt_prev + dt_next);
double yaw_a = (states[i + 1].yaw.v - states[i - 1].yaw.v) / (dt_prev + dt_next);
double pitch_a = (states[i + 1].pitch.v - states[i - 1].pitch.v) / (dt_prev + dt_next);
yaw_nodes[i] = State1D(states[i].yaw.p, yaw_v, yaw_a);
pitch_nodes[i] = State1D(states[i].pitch.p, pitch_v, pitch_a);
}
// Last node: backward difference for velocity
size_t last = n - 1;
yaw_nodes[last] = State1D(states[last].yaw.p,
(states[last].yaw.p - states[last - 1].yaw.p) / dt_vec[last - 1],
0.0);
pitch_nodes[last] = State1D(states[last].pitch.p,
(states[last].pitch.p - states[last - 1].pitch.p) / dt_vec[last - 1],
0.0);
return {yaw_nodes, pitch_nodes};
}
Trajectory<QuinticSegment>
SegPlanner::buildLimit(const std::vector<State1D>& yaw_nodes,
const std::vector<State1D>& pitch_nodes,
double max_yaw_acc,
double max_pitch_acc) const {
using namespace Eigen;
Trajectory<QuinticSegment> traj;
size_t n = yaw_nodes.size();
if (n <= 1) return traj;
traj.reserve(n);
for (size_t i = 0; i < n - 1; ++i) {
// Build yaw segment
QuinticSegment seg(1.0);
seg = QuinticSegment::build(yaw_nodes[i], yaw_nodes[i + 1], 1.0);
traj.push_back(seg, 1.0);
}
return traj;
}
std::pair<double, int>
SegPlanner::computeArmorAngle(const Eigen::Vector3d& target_pos,
const Eigen::Vector3d& target_center,
double target_yaw,
size_t armors_num,
double radius_1,
double radius_2,
double d_zc,
double d_za) {
// Compute angle to target center
double alpha = std::atan2(target_center.y(), target_center.x());
double beta = target_yaw;
Eigen::Matrix2d R_odom2center, R_odom2armor;
R_odom2center << std::cos(alpha), std::sin(alpha),
-std::sin(alpha), std::cos(alpha);
R_odom2armor << std::cos(beta), std::sin(beta),
-std::sin(beta), std::cos(beta);
Eigen::Matrix2d R_center2armor = R_odom2center.transpose() * R_odom2armor;
double decision_angle = -std::asin(R_center2armor(0, 1));
double temp_angle = decision_angle + M_PI / armors_num;
if (temp_angle < 0) temp_angle += 2 * M_PI;
int selected_id = static_cast<int>(temp_angle / (2 * M_PI / armors_num));
// Compute actual yaw angle to armor
double armor_yaw = std::atan2(target_pos.y(), target_pos.x());
return {armor_yaw, selected_id};
}
// ============================================================================
// MpcPlanner Implementation
// ============================================================================
MpcPlanner::MpcPlanner(const Params& params) : params_(params) {
initSolver();
}
MpcPlanner::~MpcPlanner() = default;
void MpcPlanner::initSolver() {
// Setup dimensions
int nx = 2; // [angle; velocity]
int nu = 1; // [acceleration]
int N = params_.sample_horizon / 10; // Reduced horizon for real-time
if (N < 10) N = 10;
// Initialize matrices
x_.resize(nx, N);
u_.resize(nu, N - 1);
x_ref_.resize(nx, N);
u_ref_.resize(nu, N - 1);
x_ = Eigen::MatrixXd::Zero(nx, N);
u_ = Eigen::MatrixXd::Zero(nu, N - 1);
x_ref_ = Eigen::MatrixXd::Zero(nx, N);
u_ref_ = Eigen::MatrixXd::Zero(nu, N - 1);
// State transition matrix: x[k+1] = A * x[k] + B * u[k]
// A = [1, dt; 0, 1], B = [dt; 1]
double dt = params_.sample_total_time / N;
Adyn_.resize(2, 2);
Adyn_ << 1, dt,
0, 1;
Bdyn_.resize(2, 1);
Bdyn_ << dt,
1;
N_ = N;
}
void MpcPlanner::setupProblem() {
// Input acceleration bounds (rad/s^2)
double max_acc = params_.max_yaw_acc * M_PI / 180.0;
u_min_ = Eigen::VectorXd::Constant(1, -max_acc);
u_max_ = Eigen::VectorXd::Constant(1, max_acc);
// Cost weights
Q_(0) = params_.Q_yaw_p;
Q_(1) = params_.Q_yaw_v;
R_(0) = params_.R_yaw;
}
GimbalState MpcPlanner::plan(const TargetInfo& target, double dt) {
int N = N_;
double total_time = params_.sample_total_time;
double time_step = total_time / N;
// Reset debug info
debug_.planner_type = "mpc";
debug_.traj_time.clear();
debug_.traj_yaw_p.clear();
debug_.traj_yaw_v.clear();
debug_.max_yaw_acc = params_.max_yaw_acc;
debug_.max_pitch_acc = params_.max_pitch_acc;
debug_.total_dt = dt;
// 1. Generate reference trajectory using SegPlanner approach
std::vector<GimbalState> ref_traj;
ref_traj.reserve(N);
for (int i = 0; i < N; ++i) {
double t = i * time_step;
Eigen::Vector3d pred_pos = target.position + t * target.velocity;
double pred_yaw = target.yaw + t * target.v_yaw;
double pitch = std::atan2(pred_pos.z(), pred_pos.head(2).norm());
ref_traj.push_back(GimbalState(pred_yaw, target.v_yaw, 0.0, pitch, 0.0, 0.0));
// Fill debug trajectory points (subsample for efficiency)
if (i % 10 == 0) {
debug_.traj_time.push_back(t);
debug_.traj_yaw_p.push_back(pred_yaw);
debug_.traj_yaw_v.push_back(target.v_yaw);
}
}
// 2. Set initial state
x_.col(0) << target.yaw, target.v_yaw;
// 3. Set reference trajectory
for (int i = 0; i < N; ++i) {
x_ref_.col(i) << ref_traj[i].yaw.p, ref_traj[i].yaw.v;
}
// 4. Solve MPC
solveMpc(ref_traj);
// 5. Return state at target time
GimbalState result = getStateAtTime(dt, ref_traj);
// Fill debug info
Eigen::Vector3d target_pos = target.position + dt * target.velocity;
debug_.target_yaw = target.yaw + dt * target.v_yaw;
debug_.target_pitch = std::atan2(target_pos.z(), target_pos.head(2).norm());
debug_.target_distance = target_pos.norm();
debug_.planned_yaw = result.yaw.p;
debug_.planned_pitch = result.pitch.p;
debug_.planned_yaw_v = result.yaw.v;
debug_.planned_pitch_v = result.pitch.v;
debug_.planned_yaw_a = result.yaw.a;
debug_.planned_pitch_a = result.pitch.a;
debug_.flying_time = dt;
return result;
}
void MpcPlanner::solveMpc(const std::vector<GimbalState>& ref_traj) {
int N = N_;
// Initialize
x_ = Eigen::MatrixXd::Zero(2, N);
u_ = Eigen::MatrixXd::Zero(1, N - 1);
// Simple forward simulation with feedback
// This is a simplified MPC that uses the reference trajectory
// and applies acceleration constraints
for (int i = 0; i < N - 1; ++i) {
// Get reference state
Eigen::Vector2d x_ref = x_ref_.col(i);
// Compute error
Eigen::Vector2d error = x_.col(i) - x_ref;
// Apply acceleration to reduce error
double u = -0.1 * error(0) - 0.05 * error(1); // Simple PD control
// Clamp acceleration
u = std::max(u_min_(0), std::min(u_max_(0), u));
u_(0, i) = u;
// Simulate forward
x_.col(i + 1) = Adyn_ * x_.col(i) + Bdyn_ * u;
}
// Store solution
mpc_solution_yaw_.clear();
mpc_solution_yaw_.reserve(N);
for (int i = 0; i < N; ++i) {
State1D yaw_state(x_(0, i), x_(1, i), 0.0);
mpc_solution_yaw_.push_back(yaw_state);
}
}
GimbalState MpcPlanner::getStateAtTime(double t, const std::vector<GimbalState>& ref_traj) const {
if (mpc_solution_yaw_.empty()) {
return GimbalState();
}
int N = N_;
double total_time = params_.sample_total_time;
double time_step = total_time / N;
if (t <= 0) return mpc_solution_yaw_.front();
if (t >= total_time) return mpc_solution_yaw_.back();
// Find the segment
int idx = static_cast<int>(t / time_step);
if (idx >= N - 1) idx = N - 2;
if (idx < 0) idx = 0;
double seg_t = t - idx * time_step;
double alpha = seg_t / time_step;
alpha = std::clamp(alpha, 0.0, 1.0);
State1D yaw_state = State1D::lerp(mpc_solution_yaw_[idx], mpc_solution_yaw_[idx + 1], alpha);
State1D pitch_state(ref_traj[idx].pitch.p, 0.0, 0.0);
return GimbalState(yaw_state.p, yaw_state.v, yaw_state.a,
pitch_state.p, pitch_state.v, pitch_state.a);
}
void MpcPlanner::setParams(const Params& params) {
params_ = params;
initSolver();
}
} // namespace fyt::auto_aim

View File

@@ -51,6 +51,21 @@
resistance: 0.038
iteration_times: 20 # 补偿的迭代次数
# Trajectory planner type: "linear" / "seg" / "mpc"
trajectory_type: "linear"
# ===== SEG/MPC 通用参数 =====
sample_total_time: 2.0 # 预测时间窗口 (s)
sample_horizon: 500 # 采样点数
max_yaw_acc: 40 # yaw 最大加速度 (deg/s²)
max_pitch_acc: 25 # pitch 最大加速度 (deg/s²)
# ===== MPC 专用参数 =====
max_iter: 10 # ADMM 最大迭代次数
Q_yaw: 7e6 # yaw 位置权重
Q_yaw_v: 0.0 # yaw 速度权重
R_yaw: 3.0 # yaw 控制权重
# ["距离下限, 距离上限, 高度下限, 高度下限, pitch轴补偿值"]
# [dist_low, dist_high, height_low, height_high, pitch_offset_deg, yaw_offset_deg]
angle_offset: [

View File

@@ -0,0 +1,72 @@
AccessModifierOffset: -4
AlignAfterOpenBracket: BlockIndent
AlignConsecutiveMacros: false
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: DontAlign
AlignOperands: false
AlignTrailingComments: false
AllowAllArgumentsOnNextLine: false
AllowAllConstructorInitializersOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: false
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BraceWrapping:
AfterControlStatement: MultiLine
AfterEnum: false
AfterStruct: false
SplitEmptyFunction: false
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
BreakStringLiterals: false
ColumnLimit: 100
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
FixNamespaceComments: true
IncludeBlocks: Preserve
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentPPDirectives: BeforeHash
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MaxEmptyLinesToKeep: 1
NamespaceIndentation: Inner
PointerAlignment: Left
ReflowComments: false
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: false
SpaceBeforeInheritanceColon: false
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInContainerLiterals: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
UseTab: Never

6
wust_vision-main/.clangd Normal file
View File

@@ -0,0 +1,6 @@
Diagnostics:
Suppress:
- drv_unknown_argument
CompileFlags:
Remove: [-forward-unknown-to-host-compiler, --generate-code=*, -Xcompiler=*]

31
wust_vision-main/.gitignore vendored Normal file
View File

@@ -0,0 +1,31 @@
build
devel
install
log/*
.catkin_workspace
.vscode
.cache
__pycache__
*~
.DS_Store
*.pcd
*.gv
*.pdf
bin
model/at.onnx
model/at.engine

6
wust_vision-main/.gitmodules vendored Normal file
View File

@@ -0,0 +1,6 @@
[submodule "KalmanHyLib"]
path = KalmanHyLib
url = https://github.com/hyheiyue/KalmanHyLib.git
[submodule "3rdparty/backward-cpp"]
path = 3rdparty/backward-cpp
url = https://github.com/bombela/backward-cpp.git

373
wust_vision-main/3rdparty/angles.h vendored Normal file
View File

@@ -0,0 +1,373 @@
/*********************************************************************
* Software License Agreement (BSD License)
*
* Copyright (c) 2008, Willow Garage, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
* * Neither the name of the Willow Garage nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
* COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
* ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*********************************************************************/
#ifndef GEOMETRY_ANGLES_UTILS_H
#define GEOMETRY_ANGLES_UTILS_H
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <algorithm>
#include <cmath>
namespace angles {
/*!
* \brief Convert degrees to radians
*/
static inline double from_degrees(double degrees) {
return degrees * M_PI / 180.0;
}
/*!
* \brief Convert radians to degrees
*/
static inline double to_degrees(double radians) {
return radians * 180.0 / M_PI;
}
/*!
* \brief normalize_angle_positive
*
* Normalizes the angle to be 0 to 2*M_PI
* It takes and returns radians.
*/
static inline double normalize_angle_positive(double angle) {
const double result = fmod(angle, 2.0 * M_PI);
if (result < 0)
return result + 2.0 * M_PI;
return result;
}
/*!
* \brief normalize
*
* Normalizes the angle to be -M_PI circle to +M_PI circle
* It takes and returns radians.
*
*/
static inline double normalize_angle(double angle) {
const double result = fmod(angle + M_PI, 2.0 * M_PI);
if (result <= 0.0)
return result + M_PI;
return result - M_PI;
}
/*!
* \function
* \brief shortest_angular_distance
*
* Given 2 angles, this returns the shortest angular
* difference. The inputs and ouputs are of course radians.
*
* The result
* would always be -pi <= result <= pi. Adding the result
* to "from" will always get you an equivelent angle to "to".
*/
static inline double shortest_angular_distance(double from, double to) {
return normalize_angle(to - from);
}
/*!
* \function
*
* \brief returns the angle in [-2*M_PI, 2*M_PI] going the other way along the
* unit circle. \param angle The angle to which you want to turn in the range
* [-2*M_PI, 2*M_PI] E.g. two_pi_complement(-M_PI/4) returns 7_M_PI/4
* two_pi_complement(M_PI/4) returns -7*M_PI/4
*
*/
static inline double two_pi_complement(double angle) {
// check input conditions
if (angle > 2 * M_PI || angle < -2.0 * M_PI)
angle = fmod(angle, 2.0 * M_PI);
if (angle < 0)
return (2 * M_PI + angle);
else if (angle > 0)
return (-2 * M_PI + angle);
return (2 * M_PI);
}
/*!
* \function
*
* \brief This function is only intended for internal use and not intended for
* external use. If you do use it, read the documentation very carefully.
* Returns the min and max amount (in radians) that can be moved from "from"
* angle to "left_limit" and "right_limit". \return returns false if "from"
* angle does not lie in the interval [left_limit,right_limit] \param from -
* "from" angle - must lie in [-M_PI, M_PI) \param left_limit - left limit of
* valid interval for angular position - must lie in [-M_PI, M_PI], left and
* right limits are specified on the unit circle w.r.t to a reference pointing
* inwards \param right_limit - right limit of valid interval for angular
* position - must lie in [-M_PI, M_PI], left and right limits are specified on
* the unit circle w.r.t to a reference pointing inwards \param result_min_delta
* - minimum (delta) angle (in radians) that can be moved from "from" position
* before hitting the joint stop \param result_max_delta - maximum (delta) angle
* (in radians) that can be movedd from "from" position before hitting the joint
* stop
*/
static bool find_min_max_delta(
double from,
double left_limit,
double right_limit,
double& result_min_delta,
double& result_max_delta
) {
double delta[4];
delta[0] = shortest_angular_distance(from, left_limit);
delta[1] = shortest_angular_distance(from, right_limit);
delta[2] = two_pi_complement(delta[0]);
delta[3] = two_pi_complement(delta[1]);
if (delta[0] == 0) {
result_min_delta = delta[0];
result_max_delta = std::max<double>(delta[1], delta[3]);
return true;
}
if (delta[1] == 0) {
result_max_delta = delta[1];
result_min_delta = std::min<double>(delta[0], delta[2]);
return true;
}
double delta_min = delta[0];
double delta_min_2pi = delta[2];
if (delta[2] < delta_min) {
delta_min = delta[2];
delta_min_2pi = delta[0];
}
double delta_max = delta[1];
double delta_max_2pi = delta[3];
if (delta[3] > delta_max) {
delta_max = delta[3];
delta_max_2pi = delta[1];
}
// printf("%f %f %f %f\n",delta_min,delta_min_2pi,delta_max,delta_max_2pi);
if ((delta_min <= delta_max_2pi) || (delta_max >= delta_min_2pi)) {
result_min_delta = delta_max_2pi;
result_max_delta = delta_min_2pi;
if (left_limit == -M_PI && right_limit == M_PI)
return true;
else
return false;
}
result_min_delta = delta_min;
result_max_delta = delta_max;
return true;
}
/*!
* \function
*
* \brief Returns the delta from `from_angle` to `to_angle`, making sure it does
* not violate limits specified by `left_limit` and `right_limit`. This function
* is similar to `shortest_angular_distance_with_limits()`, with the main
* difference that it accepts limits outside the `[-M_PI, M_PI]` range. Even if
* this is quite uncommon, one could indeed consider revolute joints with large
* rotation limits, e.g., in the range `[-2*M_PI, 2*M_PI]`.
*
* In this case, a strict requirement is to have `left_limit` smaller than
* `right_limit`. Note also that `from` must lie inside the valid range, while
* `to` does not need to. In fact, this function will evaluate the shortest
* (valid) angle `shortest_angle` so that `from+shortest_angle` equals `to` up
* to an integer multiple of `2*M_PI`. As an example, a call to
* `shortest_angular_distance_with_large_limits(0, 10.5*M_PI, -2*M_PI, 2*M_PI,
* shortest_angle)` will return `true`, with `shortest_angle=0.5*M_PI`. This is
* because `from` and `from+shortest_angle` are both inside the limits, and
* `fmod(to+shortest_angle, 2*M_PI)` equals `fmod(to, 2*M_PI)`. On the other
* hand, `shortest_angular_distance_with_large_limits(10.5*M_PI, 0, -2*M_PI,
* 2*M_PI, shortest_angle)` will return false, since `from` is not in the valid
* range. Finally, note that the call
* `shortest_angular_distance_with_large_limits(0, 10.5*M_PI, -2*M_PI, 0.1*M_PI,
* shortest_angle)` will also return `true`. However, `shortest_angle` in this
* case will be `-1.5*M_PI`.
*
* \return true if `left_limit < right_limit` and if "from" and
* "from+shortest_angle" positions are within the valid interval, false
* otherwise. \param from - "from" angle. \param to - "to" angle. \param
* left_limit - left limit of valid interval, must be smaller than right_limit.
* \param right_limit - right limit of valid interval, must be greater than
* left_limit. \param shortest_angle - result of the shortest angle calculation.
*/
static inline bool shortest_angular_distance_with_large_limits(
double from,
double to,
double left_limit,
double right_limit,
double& shortest_angle
) {
// Shortest steps in the two directions
double delta = shortest_angular_distance(from, to);
double delta_2pi = two_pi_complement(delta);
// "sort" distances so that delta is shorter than delta_2pi
if (std::fabs(delta) > std::fabs(delta_2pi))
std::swap(delta, delta_2pi);
if (left_limit > right_limit) {
// If limits are something like [PI/2 , -PI/2] it actually means that we
// want rotations to be in the interval [-PI,PI/2] U [PI/2,PI], ie, the
// half unit circle not containing the 0. This is already gracefully
// handled by shortest_angular_distance_with_limits, and therefore this
// function should not be called at all. However, if one has limits that
// are larger than PI, the same rationale behind
// shortest_angular_distance_with_limits does not hold, ie, M_PI+x should
// not be directly equal to -M_PI+x. In this case, the correct way of
// getting the shortest solution is to properly set the limits, eg, by
// saying that the interval is either [PI/2, 3*PI/2] or [-3*M_PI/2,
// -M_PI/2]. For this reason, here we return false by default.
shortest_angle = delta;
return false;
}
// Check in which direction we should turn (clockwise or counter-clockwise).
// start by trying with the shortest angle (delta).
double to2 = from + delta;
if (left_limit <= to2 && to2 <= right_limit) {
// we can move in this direction: return success if the "from" angle is
// inside limits
shortest_angle = delta;
return left_limit <= from && from <= right_limit;
}
// delta is not ok, try to move in the other direction (using its complement)
to2 = from + delta_2pi;
if (left_limit <= to2 && to2 <= right_limit) {
// we can move in this direction: return success if the "from" angle is
// inside limits
shortest_angle = delta_2pi;
return left_limit <= from && from <= right_limit;
}
// nothing works: we always go outside limits
shortest_angle = delta; // at least give some "coherent" result
return false;
}
/*!
* \function
*
* \brief Returns the delta from "from_angle" to "to_angle" making sure it does
* not violate limits specified by left_limit and right_limit. The valid
* interval of angular positions is [left_limit,right_limit]. E.g., [-0.25,0.25]
* is a 0.5 radians wide interval that contains 0. But [0.25,-0.25] is a
* 2*M_PI-0.5 wide interval that contains M_PI (but not 0). The value of
* shortest_angle is the angular difference between "from" and "to" that lies
* within the defined valid interval. E.g.
* shortest_angular_distance_with_limits(-0.5,0.5,0.25,-0.25,ss) evaluates ss to
* 2*M_PI-1.0 and returns true while
* shortest_angular_distance_with_limits(-0.5,0.5,-0.25,0.25,ss) returns false
* since -0.5 and 0.5 do not lie in the interval [-0.25,0.25]
*
* \return true if "from" and "to" positions are within the limit interval,
* false otherwise \param from - "from" angle \param to - "to" angle \param
* left_limit - left limit of valid interval for angular position, left and
* right limits are specified on the unit circle w.r.t to a reference pointing
* inwards \param right_limit - right limit of valid interval for angular
* position, left and right limits are specified on the unit circle w.r.t to a
* reference pointing inwards \param shortest_angle - result of the shortest
* angle calculation
*/
static inline bool shortest_angular_distance_with_limits(
double from,
double to,
double left_limit,
double right_limit,
double& shortest_angle
) {
double min_delta = -2 * M_PI;
double max_delta = 2 * M_PI;
double min_delta_to = -2 * M_PI;
double max_delta_to = 2 * M_PI;
bool flag = find_min_max_delta(from, left_limit, right_limit, min_delta, max_delta);
double delta = shortest_angular_distance(from, to);
double delta_mod_2pi = two_pi_complement(delta);
if (flag) // from position is within the limits
{
if (delta >= min_delta && delta <= max_delta) {
shortest_angle = delta;
return true;
} else if (delta_mod_2pi >= min_delta && delta_mod_2pi <= max_delta) {
shortest_angle = delta_mod_2pi;
return true;
} else // to position is outside the limits
{
find_min_max_delta(to, left_limit, right_limit, min_delta_to, max_delta_to);
if (fabs(min_delta_to) < fabs(max_delta_to))
shortest_angle = std::max<double>(delta, delta_mod_2pi);
else if (fabs(min_delta_to) > fabs(max_delta_to))
shortest_angle = std::min<double>(delta, delta_mod_2pi);
else {
if (fabs(delta) < fabs(delta_mod_2pi))
shortest_angle = delta;
else
shortest_angle = delta_mod_2pi;
}
return false;
}
} else // from position is outside the limits
{
find_min_max_delta(to, left_limit, right_limit, min_delta_to, max_delta_to);
if (fabs(min_delta) < fabs(max_delta))
shortest_angle = std::min<double>(delta, delta_mod_2pi);
else if (fabs(min_delta) > fabs(max_delta))
shortest_angle = std::max<double>(delta, delta_mod_2pi);
else {
if (fabs(delta) < fabs(delta_mod_2pi))
shortest_angle = delta;
else
shortest_angle = delta_mod_2pi;
}
return false;
}
shortest_angle = delta;
return false;
}
} // namespace angles
#endif

84
wust_vision-main/3rdparty/ankerl/stl.h vendored Normal file
View File

@@ -0,0 +1,84 @@
///////////////////////// ankerl::unordered_dense::{map, set} /////////////////////////
// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion.
// Version 4.8.1
// https://github.com/martinus/unordered_dense
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// SPDX-License-Identifier: MIT
// Copyright (c) 2022 Martin Leitner-Ankerl <martin.ankerl@gmail.com>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#ifndef ANKERL_STL_H
#define ANKERL_STL_H
#include <array> // for array
#include <cstdint> // for uint64_t, uint32_t, std::uint8_t, UINT64_C
#include <cstring> // for size_t, memcpy, memset
#include <functional> // for equal_to, hash
#include <initializer_list> // for initializer_list
#include <iterator> // for pair, distance
#include <limits> // for numeric_limits
#include <memory> // for allocator, allocator_traits, shared_ptr
#include <optional> // for optional
#include <stdexcept> // for out_of_range
#include <string> // for basic_string
#include <string_view> // for basic_string_view, hash
#include <tuple> // for forward_as_tuple
#include <type_traits> // for enable_if_t, declval, conditional_t, ena...
#include <utility> // for forward, exchange, pair, as_const, piece...
#include <vector> // for vector
// <memory_resource> includes <mutex>, which fails to compile if
// targeting GCC >= 13 with the (rewritten) win32 thread model, and
// targeting Windows earlier than Vista (0x600). GCC predefines
// _REENTRANT when using the 'posix' model, and doesn't when using the
// 'win32' model.
#if defined __MINGW64__ && defined __GNUC__ && __GNUC__ >= 13 && !defined _REENTRANT
// _WIN32_WINNT is guaranteed to be defined here because of the
// <cstdint> inclusion above.
#ifndef _WIN32_WINNT
#error "_WIN32_WINNT not defined"
#endif
#if _WIN32_WINNT < 0x600
#define ANKERL_MEMORY_RESOURCE_IS_BAD() 1 // NOLINT(cppcoreguidelines-macro-usage)
#endif
#endif
#ifndef ANKERL_MEMORY_RESOURCE_IS_BAD
#define ANKERL_MEMORY_RESOURCE_IS_BAD() 0 // NOLINT(cppcoreguidelines-macro-usage)
#endif
#if defined(__has_include) && !defined(ANKERL_UNORDERED_DENSE_DISABLE_PMR)
#if __has_include(<memory_resource>) && !ANKERL_MEMORY_RESOURCE_IS_BAD()
#define ANKERL_UNORDERED_DENSE_PMR std::pmr // NOLINT(cppcoreguidelines-macro-usage)
#include <memory_resource> // for polymorphic_allocator
#elif __has_include(<experimental/memory_resource>)
#define ANKERL_UNORDERED_DENSE_PMR \
std::experimental::pmr // NOLINT(cppcoreguidelines-macro-usage)
#include <experimental/memory_resource> // for polymorphic_allocator
#endif
#endif
#if defined(_MSC_VER) && defined(_M_X64)
#include <intrin.h>
#pragma intrinsic(_umul128)
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,160 @@
cmake_minimum_required(VERSION 3.14)
cmake_policy(SET CMP0072 NEW)
project(wust_vision LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/bin)
set(CMAKE_BUILD_TYPE "Release")
message(STATUS "--------------------CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}--------------------")
option(BUILD_WITH_TRT "Enable TensorRT backend" ON)
option(BUILD_WITH_OPENVINO "Enable OpenVINO backend" ON)
option(BUILD_WITH_NCNN "Enable NCNN backend" OFF)
option(BUILD_WITH_ORT "Enable ORT backend" ON)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}/cmake)
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS tasks/*.cpp)
# list(FILTER SOURCES EXCLUDE REGEX "tasks/auto_guidance/.*")
# list(FILTER SOURCES EXCLUDE REGEX "tasks/auto_sniper/.*")
find_package(OpenCV REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(fmt REQUIRED)
find_package(yaml-cpp REQUIRED)
find_package(Ceres REQUIRED)
find_package(HikSDK REQUIRED)
find_package(wust_vl REQUIRED)
set(BUILD_DEFINITIONS "")
set(BUILD_LIBS "")
if(BUILD_WITH_OPENVINO)
find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
list(APPEND BUILD_DEFINITIONS USE_OPENVINO)
list(APPEND BUILD_LIBS openvino::runtime openvino::frontend::onnx)
endif()
if(BUILD_WITH_TRT)
find_package(CUDAToolkit REQUIRED)
set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};/home/hy/TensorRT-10.6.0.26")
find_package(TensorRT REQUIRED)
include_directories(${TensorRT_INCLUDE_DIR})
include_directories(/usr/local/cuda/include)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
set(CMAKE_CUDA_ARCHITECTURES 86)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_COMPILER_LAUNCHER ccache)
add_subdirectory(${CMAKE_SOURCE_DIR}/cuda_infer ${CMAKE_BINARY_DIR}/cuda_infer_build)
list(APPEND BUILD_DEFINITIONS USE_TRT)
list(APPEND BUILD_LIBS TensorRT::TensorRT TensorRT::nvonnxparser CUDA::cudart cuda_infer)
endif()
if(BUILD_WITH_NCNN)
find_package(ncnn REQUIRED)
list(APPEND BUILD_DEFINITIONS USE_NCNN)
list(APPEND BUILD_LIBS ncnn)
endif()
if(BUILD_WITH_ORT)
set(ort_root_path "/home/hy/onnxruntime-linux-x64-gpu-1.22.0")
find_package(Ort REQUIRED)
include_directories(${Ort_INCLUDE_DIR})
list(APPEND BUILD_DEFINITIONS USE_ORT)
list(APPEND BUILD_LIBS ${Ort_LIB})
endif()
add_library(${PROJECT_NAME} SHARED ${SOURCES})
target_compile_definitions(${PROJECT_NAME} PUBLIC ${BUILD_DEFINITIONS})
target_include_directories(${PROJECT_NAME} PUBLIC
${PROJECT_SOURCE_DIR}
${OpenCV_INCLUDE_DIRS}
)
target_link_libraries(${PROJECT_NAME}
wust_vl::wust_vl
HikSDK::HikSDK
yaml-cpp
fmt::fmt
Eigen3::Eigen
Ceres::ceres
${OpenCV_LIBS}
${BUILD_LIBS}
)
set(ROS2_PACKAGES
ament_cmake
rosidl_typesupport_cpp
rclcpp
geometry_msgs
tf2_ros
tf2_geometry_msgs
sensor_msgs
visualization_msgs
sentry_interfaces
nav_msgs
)
function(add_common_executable exe_name src_file)
add_executable(${exe_name} ${src_file})
target_link_libraries(${exe_name} ${PROJECT_NAME})
endfunction()
add_common_executable(standard src/standard.cpp)
# add_common_executable(dart src/dart.cpp)
add_common_executable(test_usbcamera test/test_usbcamera.cpp)
foreach(pkg IN LISTS ROS2_PACKAGES)
find_package(${pkg} QUIET)
endforeach()
set(ROS2_FULL_FOUND TRUE)
foreach(pkg IN LISTS ROS2_PACKAGES)
if(NOT ${pkg}_FOUND)
set(ROS2_FULL_FOUND FALSE)
message(WARNING "ROS2 package ${pkg} not found")
endif()
endforeach()
if(ROS2_FULL_FOUND)
message(STATUS "ROS2 full environment found, compiling ROS2 targets ...")
list(APPEND BUILD_DEFINITIONS USE_ROS2)
set(ROS2_INCLUDES)
foreach(pkg IN LISTS ROS2_PACKAGES)
if(TARGET ${pkg}::${pkg})
list(APPEND ROS2_INCLUDES ${${pkg}_INCLUDE_DIRS})
endif()
endforeach()
include_directories(${ROS2_INCLUDES})
# find_package(Open3D REQUIRED)
# target_link_libraries(${PROJECT_NAME}
# Open3D::Open3D
# )
ament_target_dependencies(${PROJECT_NAME} ${ROS2_PACKAGES})
target_compile_definitions(${PROJECT_NAME} PUBLIC ${BUILD_DEFINITIONS})
macro(add_ros2_executable exe_name src_file)
add_executable(${exe_name} ${src_file})
target_link_libraries(${exe_name} ${PROJECT_NAME} )
endmacro()
add_ros2_executable(sentry src/sentry.cpp)
add_ros2_executable(nav test/nav.cpp)
else()
message(WARNING "ROS2 dependencies incomplete, skipping ROS2 targets ...")
endif()

252
wust_vision-main/README.md Normal file
View File

@@ -0,0 +1,252 @@
# <img src="https://s21.ax1x.com/2025/08/12/pVwPPKS.png" width="40">WUST_VISION
武汉科技大学崇实战队视觉代码仓库
## 写在前面
本项目基于[中南大学FYT战队2024赛季视觉框架开源](https://github.com/CSU-FYT-Vision/FYT2024_vision),华南师范大学PIONEER战队@chenjunnn[rm_vision](https://github.com/chenjunnn/rm_vision)修改与适配,参考了深圳北理莫斯科大学北极熊战队/四川大学火锅战队/沈阳航空航天大学TUP战队/北京科技大学Reborn战队/同济大学superpower战队/河北科技大学Actor&Thinker战队的部分代码与模型感谢以上开源为本队以及本人的帮助(排名不分先后)
## 依赖
* [wust_vl](https://github.com/WUST-RM/wust_vl)
* OpenCV
* [OpenVINO](https://flowus.cn/7a2a3341-74a1-4db9-bced-99fe5d05ab75)/[TensorRT-cuda](https://flowus.cn/e98af178-de0b-4546-808d-a6f1ff199d62)/[NCNN](https://flowus.cn/664f6bee-8ea9-4d54-8a78-e2c0bf38ee9f)/[OnnxRunetime](https://flowus.cn/8fbecbbf-c0f9-49bb-bac5-7b4923f55c99)连接为简单部署文档
* fmt
* ceres
* Eigen3
* nlohmann
* yaml-cpp
## 环境配置
```
./script/install_depences.sh
```
## Quick Start
```
git clone --recurse-submodules https://github.com/WUST-RM/wust_vision.git
cd wust_vision
sudo ./run.sh run xx /rebuild/build #编译并运行xx可执行文件/删除build缓存重新编译/仅编译
```
### 注意本项目可选择编译OpenVINO/TensorRT-cuda/NCNN/OnnxRunetime需在build缓存前在[CMakeLists.txt](CMakeLists.txt)中修改对应编译选项,修改后需rebuild重新编译无OpenVINO/TensorRT-cuda/NCNN/OnnxRunetime环境仍可以使用OpenCV的装甲板识别装甲板的识别方案需要在[config/auto_aim.yaml](config/auto_aim.yaml)中修改
## 文件树
```
.
├── 3rdparty
│   ├── angles.h
│   └── backward-cpp
├── cmake
│   ├── FindG2O.cmake
│   ├── FindHikSDK.cmake
│   ├── FindOrt.cmake
│   └── FindTensorRT.cmake
├── CMakeLists.txt
├── config
│   ├── auto_aim.yaml
│   ├── auto_buff.yaml
│   ├── auto_guidance.yaml
│   ├── auto_sniper.yaml
│   ├── camera_info.yaml
│   ├── camera.yaml
│   ├── common.yaml
│   ├── config -> /home/hy/wust_vision/config
│   ├── detect_ncnn.yaml
│   ├── detect_opencv.yaml
│   ├── detect_openvino.yaml
│   ├── detect_ort.yaml
│   ├── detect_trt.yaml
│   └── guard.sh
├── cuda_infer
│   ├── armor_infer.cu
│   ├── armor_infer.hpp
│   ├── CMakeLists.txt
│   ├── letter_box.cu
│   └── letter_box.hpp
├── env.bash
├── format.sh
├── KalmanHyLib
│   ├── adaptive_extended_kalman_filter.hpp
│   ├── error_state_extended_kalman_filter.hpp
│   ├── extended_kalman_filter.hpp
│   ├── kalman_hybird_lib.hpp
│   ├── README.md
│   └── unscented_kalman_filter.hpp
├── model
├── README.md
├── read_shm_image_mmap_only.py
├── ros2
│   ├── CMakeLists.txt
│   ├── ros2.cpp
│   └── ros2.hpp
├── run.sh
├── script
│   ├── install_depences.sh
│   ├── rsync.sh
│   ├── setup_devenv.sh
│   └── setup_service.sh
├── src
│   ├── dart.cpp
│   ├── hero.cpp
│   ├── sentry.cpp
│   ├── sim.cpp
│   └── standard.cpp
├── static
│   ├── 崇实战队logo图标.png
│   ├── css
│   │   └── style.css
│   ├── js
│   │   ├── chart_logic.js
│   │   ├── json_view.js
│   │   └── main.js
│   └── logo.JPG
├── tasks
│   ├── auto_aim
│   │   ├── armor_control
│   │   │   ├── aimer.cpp
│   │   │   ├── aimer.hpp
│   │   │   ├── planner.cpp
│   │   │   ├── planner.hpp
│   │   │   ├── shooter.cpp
│   │   │   ├── shooter.hpp
│   │   │   └── tinympc
│ │ │
│   │   ├── armor_detect
│   │   │   ├── armor_detect_common.cpp
│   │   │   ├── armor_detect_common.hpp
│   │   │   ├── armor_detector_base.hpp
│   │   │   ├── armor_infer.cpp
│   │   │   ├── armor_infer.hpp
│   │   │   ├── armor_pose_estimator.cpp
│   │   │   ├── armor_pose_estimator.hpp
│   │   │   ├── detector_factory.hpp
│   │   │   ├── light_corner_corrector.cpp
│   │   │   ├── light_corner_corrector.hpp
│   │   │   ├── ncnn
│   │   │   │   ├── armor_detector_ncnn.cpp
│   │   │   │   ├── armor_detector_ncnn.hpp
│   │   │   │   ├── armor_detector_ncnn_wrapper.cpp
│   │   │   │   └── armor_detector_ncnn_wrapper.hpp
│   │   │   ├── number_classifier.cpp
│   │   │   ├── number_classifier.hpp
│   │   │   ├── onnxruntime
│   │   │   │   ├── armor_detector_onnxruntime.cpp
│   │   │   │   ├── armor_detector_onnxruntime.hpp
│   │   │   │   ├── armor_detector_onnxruntime_wrapper.cpp
│   │   │   │   └── armor_detector_onnxruntime_wrapper.hpp
│   │   │   ├── opencv
│   │   │   │   ├── armor_detector_opencv.cpp
│   │   │   │   ├── armor_detector_opencv.hpp
│   │   │   │   ├── armor_detector_opencv_wrapper.cpp
│   │   │   │   └── armor_detector_opencv_wrapper.hpp
│   │   │   ├── openvino
│   │   │   │   ├── armor_detector_openvino.cpp
│   │   │   │   ├── armor_detector_openvino.hpp
│   │   │   │   ├── armor_detector_openvino_wrapper.cpp
│   │   │   │   └── armor_detector_openvino_wrapper.hpp
│   │   │   └── tensorrt
│   │   │   ├── armor_detector_tensorrt.cpp
│   │   │   ├── armor_detector_tensorrt.hpp
│   │   │   ├── armor_detector_tensorrt_wrapper.cpp
│   │   │   └── armor_detector_tensorrt_wrapper.hpp
│   │   ├── armor_optimize
│   │   │   ├── ba_solver.cpp
│   │   │   └── ba_solver.hpp
│   │   ├── armor_tracker
│   │   │   ├── motion_models
│   │   │   │   ├── acc_model.hpp
│   │   │   │   ├── motion_modela.hpp
│   │   │   │   ├── motion_modelonea.hpp
│   │   │   │   ├── motion_modeloneca.hpp
│   │   │   │   ├── motion_modeloneypd.hpp
│   │   │   │   ├── motion_modelr.hpp
│   │   │   │   ├── motion_modelrypd.hpp
│   │   │   │   ├── motion_modelypd.hpp
│   │   │   │   └── motion_modelypdv2.hpp
│   │   │   ├── target.cpp
│   │   │   ├── target.hpp
│   │   │   ├── tracker_manager.cpp
│   │   │   ├── tracker_manager.hpp
│   │   │   ├── trackerv3.cpp
│   │   │   └── trackerv3.hpp
│   │   ├── auto_aim.cpp
│   │   ├── auto_aim_fsm.hpp
│   │   ├── auto_aim.hpp
│   │   ├── CMakeLists.txt
│   │   ├── type.cpp
│   │   └── type.hpp
│   ├── auto_buff
│   │   ├── auto_buff.cpp
│   │   ├── auto_buff.hpp
│   │   ├── CMakeLists.txt
│   │   ├── rune_control
│   │   │   ├── aimer.cpp
│   │   │   └── aimer.hpp
│   │   ├── rune_detector
│   │   │   ├── rune_detector.cpp
│   │   │   └── rune_detector.hpp
│   │   ├── rune_optimize
│   │   │   ├── ba_solver.cpp
│   │   │   └── ba_solver.hpp
│   │   ├── rune_tracker
│   │   │   ├── motion_models
│   │   │   │   └── motion_modelrypd.hpp
│   │   │   ├── rune_target.cpp
│   │   │   ├── rune_target.hpp
│   │   │   ├── rune_tracker.cpp
│   │   │   ├── rune_tracker.hpp
│   │   │   └── spd_fitter.hpp
│   │   ├── type.cpp
│   │   └── type.hpp
│   ├── auto_guidance
│   │   ├── auto_guidance.cpp
│   │   ├── auto_guidance.hpp
│   │   ├── CMakeLists.txt
│   │   ├── debug.cpp
│   │   ├── debug.hpp
│   │   ├── guidance_detector
│   │   │   ├── detector_base.hpp
│   │   │   ├── detector_factory.hpp
│   │   │   ├── green_light_infer.cpp
│   │   │   ├── green_light_infer.hpp
│   │   │   ├── opencv
│   │   │   │   ├── guidance_detector_opencv.cpp
│   │   │   │   └── guidance_detector_opencv.hpp
│   │   │   └── openvino
│   │   │   ├── guidance_detector_openvino.cpp
│   │   │   └── guidance_detector_openvino.hpp
│   │   ├── guidance_tracker
│   │   │   ├── guidance_target.cpp
│   │   │   ├── guidance_target.hpp
│   │   │   ├── guidance_tracker.cpp
│   │   │   ├── guidance_tracker.hpp
│   │   │   └── motion_models
│   │   │   └── imgbox_model.hpp
│   │   └── type.hpp
│   ├── auto_offset
│   │   ├── auto_offset.cpp
│   │   ├── auto_offset.hpp
│   │   └── CMakeLists.txt
│   ├── auto_sniper
│   │   ├── auto_sniper.cpp
│   │   ├── auto_sniper.hpp
│   │   ├── CMakeLists.txt
│   │   └── trajectory_solver.hpp
│   ├── CMakeLists.txt
│   ├── debug.cpp
│   ├── debug.hpp
│   ├── main_base.hpp
│   ├── packet_typedef.hpp
│   ├── sinple_img_rotate_saver.hpp
│   ├── type_common.cpp
│   ├── type_common.hpp
│   ├── utils.cpp
│   ├── utils.hpp
│   ├── vision_base.cpp
│   └── vision_base.hpp
├── templates
│   └── index.html
├── test
│   ├── control.cpp
│   └── test_usbcamera.cpp
└── web.py
```

View File

@@ -0,0 +1,147 @@
# --------------------------------------------------------------------------------------------
# FindHikSDK.cmake
#
# This module finds the HikRobot / Hikvision MVS Camera SDK.
#
# It defines the following variables:
# HikSDK_FOUND
# HikSDK_INCLUDE_DIR
# HikSDK_LIB
#
# And the following imported target:
# hiksdk
# --------------------------------------------------------------------------------------------
# =========================
# 1. SDK 根路径
# =========================
if(WIN32)
set(HikSDK_Path "$ENV{MVCAM_COMMON_RUNENV}")
else()
set(HikSDK_Path "$ENV{MVCAM_SDK_PATH}")
endif()
if(NOT HikSDK_Path OR HikSDK_Path STREQUAL "")
message(STATUS "HikSDK: MVCAM_SDK_PATH is not set")
set(HikSDK_FOUND FALSE)
return()
endif()
# =========================
# 2. 查找头文件
# =========================
find_path(
HikSDK_INCLUDE_DIR
NAMES
MvCameraControl.h
CameraParams.h
PixelType.h
MvErrorDefine.h
MvISPErrorDefine.h
PATHS
"${HikSDK_Path}/include"
"${HikSDK_Path}/Includes"
NO_DEFAULT_PATH
)
# =========================
# 3. 查找库文件(关键修复点)
# =========================
if(UNIX)
find_library(
HikSDK_LIB
NAMES
MvCameraControl
libMvCameraControl.so
PATHS
"${HikSDK_Path}/lib"
"${HikSDK_Path}/lib64"
"${HikSDK_Path}/lib/arm"
"${HikSDK_Path}/lib/arm64"
"${HikSDK_Path}/lib/aarch64"
"${HikSDK_Path}/lib/x86"
"${HikSDK_Path}/lib/x64"
"${HikSDK_Path}/lib/64"
"${HikSDK_Path}/lib/32"
NO_DEFAULT_PATH
)
endif()
# =========================
# 4. Windows完整但不影响 Linux
# =========================
if(WIN32)
find_library(
HikSDK_LIB
NAMES MvCameraControl
PATHS
"${HikSDK_Path}/Libraries"
"${HikSDK_Path}/Libraries/win64"
"${HikSDK_Path}/Libraries/win32"
NO_DEFAULT_PATH
)
find_file(
HikSDK_DLL
NAMES MvCameraControl.dll
PATHS
"${HikSDK_Path}/Runtime"
"C:/Program Files (x86)/Common Files/MVS/Runtime"
NO_DEFAULT_PATH
)
endif()
# =========================
# 5. 创建导入库目标
# =========================
if(HikSDK_LIB AND HikSDK_INCLUDE_DIR)
if(NOT TARGET HikSDK::HikSDK)
add_library(HikSDK::HikSDK SHARED IMPORTED GLOBAL)
if(WIN32)
set_target_properties(HikSDK::HikSDK PROPERTIES
IMPORTED_IMPLIB "${HikSDK_LIB}"
IMPORTED_LOCATION "${HikSDK_DLL}"
INTERFACE_INCLUDE_DIRECTORIES "${HikSDK_INCLUDE_DIR}"
)
else()
set_target_properties(HikSDK::HikSDK PROPERTIES
IMPORTED_LOCATION "${HikSDK_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${HikSDK_INCLUDE_DIR}"
)
endif()
endif()
endif()
set(HikSDK_LIBS HikSDK::HikSDK)
set(HikSDK_INCLUDE_DIRS ${HikSDK_INCLUDE_DIR})
# =========================
# 6. 标准 find_package 处理
# =========================
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
HikSDK
REQUIRED_VARS HikSDK_LIB HikSDK_INCLUDE_DIR
)
# =========================
# 7. 调试输出(很重要)
# =========================
if(HikSDK_FOUND)
message(STATUS "HikSDK found:")
message(STATUS " Include dir : ${HikSDK_INCLUDE_DIR}")
message(STATUS " Library : ${HikSDK_LIB}")
else()
message(STATUS "HikSDK NOT found")
endif()
set(HikSDK_LIBS hiksdk)
set(HikSDK_INCLUDE_DIRS ${HikSDK_INCLUDE_DIR})
mark_as_advanced(HikSDK_LIB HikSDK_INCLUDE_DIR HikSDK_DLL)

View File

@@ -0,0 +1,88 @@
# --------------------------------------------------------------------------------------------
# This file is used to find the ONNX-Runtime SDK, which provides the following variables:
#
# Cache Variables:
# - Ort_HEADER_FILES: Names of SDK header files
#
# Advanced Variables:
# - Ort_INCLUDE_DIR: Directory where SDK header files are located
# - Ort_LIB: Path to the SDK library file (import library on Windows, shared library
# on Linux)
# - Ort_DLL: Path to the SDK dynamic library file (only on Windows)
#
# Local Variables:
# - Ort_LIBS: CMake target name for the SDK, which is "onnxruntime"
# - Ort_INCLUDE_DIRS: Directory where SDK header files are located
# --------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# find onnxruntime root path
# ------------------------------------------------------------------------------
if(NOT ort_root_path)
set(ort_root_path "/usr/local")
endif()
# ------------------------------------------------------------------------------
# find onnxruntime include directory
# ------------------------------------------------------------------------------
set(Ort_HEADER_FILES
cpu_provider_factory.h onnxruntime_run_options_config_keys.h
onnxruntime_c_api.h onnxruntime_session_options_config_keys.h
onnxruntime_cxx_api.h provider_options.h
onnxruntime_cxx_inline.h
CACHE INTERNAL "ONNX Runtime header files"
)
find_path(
Ort_INCLUDE_DIR
PATHS "${ort_root_path}/include"
NAMES ${Ort_HEADER_FILES}
NO_DEFAULT_PATH
)
# ------------------------------------------------------------------------------
# find onnxruntime library file
# ------------------------------------------------------------------------------
find_library(
Ort_LIB
NAMES "libonnxruntime.so"
PATHS "${ort_root_path}/lib"
NO_DEFAULT_PATH
)
# ------------------------------------------------------------------------------
# create imported target: onnxruntime
# ------------------------------------------------------------------------------
if(NOT TARGET onnxruntime)
add_library(onnxruntime SHARED IMPORTED)
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION "${Ort_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${Ort_INCLUDE_DIR}"
)
endif()
mark_as_advanced(Ort_INCLUDE_DIR Ort_LIB)
# ------------------------------------------------------------------------------
# set onnxruntime cmake variables and version variables
# ------------------------------------------------------------------------------
set(Ort_LIBS "onnxruntime")
set(Ort_INCLUDE_DIRS "${Ort_INCLUDE_DIR}")
if(Ort_INCLUDE_DIR)
file(STRINGS "${Ort_INCLUDE_DIR}/onnxruntime_c_api.h" Ort_VERSION
REGEX "#define ORT_API_VERSION [0-9]+"
)
string(REGEX REPLACE "#define ORT_API_VERSION ([0-9]+)" "1.\\1" Ort_VERSION "${Ort_VERSION}")
endif()
# ------------------------------------------------------------------------------
# handle the package
# ------------------------------------------------------------------------------
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
Ort
VERSION_VAR Ort_VERSION
REQUIRED_VARS Ort_LIB Ort_INCLUDE_DIR
)

View File

@@ -0,0 +1,66 @@
# FindTensorRT.cmake -- Locate NVIDIA TensorRT
include(FindPackageHandleStandardArgs)
# TensorRT root
if (DEFINED TensorRT_ROOT)
list(APPEND _TensorRT_SEARCH_PATHS
${TensorRT_ROOT}
"$ENV{TensorRT_ROOT}"
)
endif()
list(APPEND _TensorRT_SEARCH_PATHS /usr /usr/local)
# Header
find_path(TensorRT_INCLUDE_DIR
NAMES NvInfer.h
PATHS ${_TensorRT_SEARCH_PATHS}
PATH_SUFFIXES include
)
# Core library
find_library(TensorRT_LIBRARY
NAMES nvinfer
PATHS ${_TensorRT_SEARCH_PATHS}
PATH_SUFFIXES lib lib64 lib/x64
)
find_package_handle_standard_args(TensorRT
REQUIRED_VARS TensorRT_INCLUDE_DIR TensorRT_LIBRARY
)
if (TensorRT_FOUND)
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR})
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY})
# Optional components
foreach(_comp IN ITEMS nvinfer_plugin nvonnxparser nvparsers)
find_library(TensorRT_${_comp}_LIBRARY
NAMES ${_comp}
PATHS ${_TensorRT_SEARCH_PATHS}
PATH_SUFFIXES lib lib64 lib/x64
)
if (TensorRT_${_comp}_LIBRARY)
list(APPEND TensorRT_LIBRARIES ${TensorRT_${_comp}_LIBRARY})
endif()
endforeach()
# Core target
add_library(TensorRT::TensorRT UNKNOWN IMPORTED)
set_target_properties(TensorRT::TensorRT PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_LIBRARY}"
)
# Component targets
foreach(_comp IN ITEMS nvinfer_plugin nvonnxparser nvparsers)
if (TensorRT_${_comp}_LIBRARY)
add_library(TensorRT::${_comp} UNKNOWN IMPORTED)
set_target_properties(TensorRT::${_comp} PROPERTIES
IMPORTED_LOCATION "${TensorRT_${_comp}_LIBRARY}"
)
endif()
endforeach()
message(STATUS "Found TensorRT at ${TensorRT_INCLUDE_DIR}")
endif()

View File

@@ -0,0 +1,131 @@
armor_detect_backend: tensorrt
max_detect_armors: 11
armor_map:
SENTRY: 4
NO1: 2
NO2: 3
NO3: 1
NO4: 5
NO5: 9
OUTPOST: 6
BASE: 7
UNKNOWN: -1
armor_where:
yaw_opt:
mode: golden
golden_search_side_deg: 60
distance_fix_a2: 0.0
armor_tracker:
lost_time_thres: 2.0
tracking_thres: 10
max_yaw_diff_deg: 60.0
max_dis_diff: 3.0
match_gate: 50
qxyz_common: [200.0, 200.0, 1.0]
qyaw_common: 100.0
qxyz_output: [10.0, 10.0, 0.5]
qyaw_output: 0.01
q_r: 0.0000001
q_l: 0.0000001
q_h: 0.0000001
q_outpost_dz: 0.5
yp_r: 2e-3
dis_r_front: 0.5
dis_r_side: 2.5
dis2_r_ratio: 0.1
yaw_r_base_front: 0.2
yaw_r_base_side: 0.06
yaw_r_log_ratio: 0.02
esekf_iter_num: 5
auto_aim_fsm:
single_whole_up: 1.5
single_whole_down: 1.0
whole_pair_up: 7.5
whole_pair_down: 6.5 #1s内至少换2次板子 2pi 否则轨迹规划无意义
pair_center_up: 16.5
pair_center_down: 15.0
transfer_thresh: 50
very_aimer:
fuck_test: false
fuck_test_thresh: 0.5
type: "seg" #seg or mpc
sample_total_time: 2.0
sample_horizon: 500
control_delay: 0.2
delay_enable_fire_error: 0.0035
max_yaw_acc: 40
#全向yaw加速度测算37
max_pitch_acc: 25
prediction_delay: 0.00
comming_angle: 60
leaving_angle: 20
yaw_limit_deg: 60
shooting_range_h: 0.12
shooting_range_small_w: 0.12
shooting_range_big_w: 0.24
min_enable_pitch_deg: 0.25
min_enable_yaw_deg: 0.25
base_offset:
yaw: -1.0
pitch: -3.5
trajectory_offset:
- d_max: 3
d_min: 0
h_max: 1.5
h_min: -1
pitch_off: 0
yaw_off: 0
- d_max: 4.5
d_min: 3
h_max: 1.5
h_min: -1
pitch_off: 0
yaw_off: 0
- d_max: 4.5
d_min: 9
h_max: 1.5
h_min: -1
pitch_off: 0
yaw_off: 0
#---mpc only
max_iter: 10
Q_yaw: [7e6, 0]
R_yaw: [3.0]
Q_pitch: [7e6, 0]
R_pitch: [3.0]
trajectory_compensator:
compenstator_type: resistance
gravity: 9.8
iteration_times: 20
resistance: 0.092
k1: 0.0190 #大弹丸double k1_c = 0.47; double k1 = k1_c * 1.169 * (2 * M_PI * 0.02125 * 0.02125) / 2 / 0.041;
auto_exposure:
enable: true
target_brightness: 25.0
tolerance: 3.0
step_gain: 15.0
decay_step: 1.0
exposure_min: 100.0
exposure_max: 1500.0
control_interval_ms: 300

View File

@@ -0,0 +1,76 @@
rune_detector:
rune_center_min_area: 10
rune_center_max_area: 2000
rune_center_1x1ratio_tol: 0.7
rune_center_fill_ratio_min: 0.3
rune_target_min_area: 100
rune_target_max_area: 3000
rune_target_max_square_ratio: 1.3
rune_target_cluster_radius: 70
bin_threshold: 50
color_diff_threshold: 40
rune_where:
roll_opt:
mode: golden
golden_search_side_deg: 60
rune_tracker:
lost_time_thres: 2.0
tracking_thres: 5
max_dis_diff: 1.0
match_gate: 15.0
esekf_iter_num: 2
q_roll: 10.0
q_xyz: 0.5
q_yaw: 0.5
yp_r: 0.01
dis_r: 0.05
yaw_r: 0.1
roll_r: 0.2
big_window_sec: 2.0
aimer:
prediction_delay: -0.00
shooting_range_h: 0.1
shooting_range_w: 0.1
min_enable_pitch_deg: 0.25
min_enable_yaw_deg: 0.25
base_offset:
yaw: -3.0
pitch: -0.0
trajectory_offset:
- d_max: 9
d_min: 0
h_max: 1.5
h_min: -1
pitch_off: -0
yaw_off: -0
- d_max: 10
d_min: 9
h_max: 0.4
h_min: -1
pitch_off: -0
yaw_off: 0
trajectory_compensator:
compenstator_type: resistance
gravity: 9.8
iteration_times: 20
resistance: 0.092
k1: 0.0190 #大弹丸double k1_c = 0.47; double k1 = k1_c * 1.169 * (2 * M_PI * 0.02125 * 0.02125) / 2 / 0.041;
auto_exposure:
enable: false
target_brightness: 15.0
tolerance: 3.0
step_gain: 15.0
decay_step: 1.0
exposure_min: 100.0
exposure_max: 2500.0
control_interval_ms: 300

View File

@@ -0,0 +1,83 @@
backend: opencv
max_infer_running: 6
detector:
openvino:
model_path: /home/hy/2026_dart_guide/assets/Katrin.xml
top_k: 128
conf_threshold: 0.2
nms_threshold: 0.5
device_type: CPU
use_throughputmode: false
opencv:
gui: true
HSV:
lowH: 25
highH: 85
lowS: 50
highS: 255
lowV: 150
highV: 255
contours:
min_area: 100
max_area: 50000
min_aspect_ratio: 0.5
min_fill_ratio: 0.7
tracker:
lost_time_thres: 2.0
tracking_thres: 10
target:
xy_r: 0.05
wh_r: 0.05
q_xy: 100
q_wh: 100
iter_num: 2
max_dis_diff: 2.0
logger:
log_level: INFO
log_path: /home/hy/wust_log
use_logcli: true
use_logfile: true
use_simplelog: true
control:
control_rate: 100
communication_delay_us: 1100
pitch_avg_windows: 1
device_name: /dev/ttyACM_RMc
use_serial: true
camera:
type: hik_camera
camera_info_path: ${VISION_ROOT}/config/camera_info.yaml
hik_camera:
target_sn: DA1094119 #DA3038891
acquisition_frame_rate: 80
adc_bit_depth: Bits_12
exposure_time: 5000
gain: 16.9
gamma: 0.5 #Bayer用不了
pixel_format: BayerRG8
use_raw: false
acquisition_frame_rate_enable: false
reverse_x: false
reverse_y: false
video_player:
fps: 12
loop: true
use_cvt: true
#path: /home/hy/wust_data/video/aaa.mp4
#path: /home/hy/wust_data/video/jiao.avi
#path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi
#path: /home/hy/wust_data/record/20251111_134416_187.avi
path: /home/hy/下载/sp_vision_25/records/1.avi
#path: /home/hy/wust_data/video/Sentry.mp4
start_frame: 0

View File

@@ -0,0 +1,27 @@
map:
voxel_size: 0.1
min_pos: [-5.0,-10.0,-1.0]
max_pos: [20.0,10.0,1.0]
solver:
k1: 0.0190
g: 9.81
target_armor_z: 0.5
offset_helper:
order: 5
yaw_base_offset: 0.0
pitch_base_offset: 0.0
offset_table:
- distance: 2.0
yaw: 0.1
pitch: -0.2
- distance: 3.0
yaw: 0.15
pitch: -0.25
- distance: 4.0
yaw: 0.2
pitch: -0.3

View File

@@ -0,0 +1,50 @@
type: hik_camera
camera_info_path: ${VISION_ROOT}/config/camera_info.yaml
hik_camera:
target_sn: DA3038934
adc_bit_depth: Bits_8
pixel_format: BayerRG8 #RGB8Packed
reverse_x: false
reverse_y: false
width: 1440
height: 1080
offset_x: 0
offset_y: 0
acquisition_frame_rate_enable: true
acquisition_frame_rate: 250
exposure_time: 2000
gain: 16.9
gamma: 0.7 #Bayer用不了
trigger_type: none
trigger_source: ""
trigger_activation: 0
use_raw: false
use_rgb: false
use_ea: false
use_cuda_cvt: true
video_player:
fps: 60
loop: true
use_cvt: true
trigger_mode: false
#path: /home/hy/wust_data/video/aaa.mp4
#path: /home/hy/wust_data/video/jiao.avi
#path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi
#path: /home/hy/wust_data/record/20251111_134416_187.avi
path: /home/hy/data/video_save/5.21成都.mp4
#path: /home/hy/wust_data/video/Sentry.mp4
start_frame: 0
uvc:
device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:3.4:1.0-video-index0"
fps: 60
width: 1280
height: 720
exposure: 100.0
gain: 10.0
gamma: 100
trigger_mode: false

View File

@@ -0,0 +1,82 @@
#8mm
# image_width: 1440
# image_height: 1080
# camera_name: narrow_stereo
# camera_matrix:
# rows: 3
# cols: 3
# data: [2419.64498, 0. , 719.05623,
# 0. , 2414.38209, 538.71651,
# 0. , 0. , 1. ]
# distortion_model: plumb_bob
# distortion_coefficients:
# rows: 1
# cols: 5
# data: [-0.033521, 0.087954, -0.002045, -0.001438, 0.000000]
# rectification_matrix:
# rows: 3
# cols: 3
# data: [1., 0., 0.,
# 0., 1., 0.,
# 0., 0., 1.]
# projection_matrix:
# rows: 3
# cols: 4
# data: [2413.22632, 0. , 717.5196 , 0. ,
# 0. , 2408.83228, 537.43444, 0. ,
# 0. , 0. , 1. , 0. ]
#6mm
image_width: 1440
image_height: 1080
camera_name: narrow_stereo
camera_matrix:
rows: 3
cols: 3
data: [1814.19888, 0. , 730.27359,
0. , 1807.37177, 529.57031,
0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
rows: 1
cols: 5
data: [-0.089138, 0.114766, 0.000034, 0.000238, 0.000000]
rectification_matrix:
rows: 3
cols: 3
data: [1., 0., 0.,
0., 1., 0.,
0., 0., 1.]
projection_matrix:
rows: 3
cols: 4
data: [1793.54229, 0. , 730.67194, 0. ,
0. , 1794.48784, 529.46193, 0. ,
0. , 0. , 1. , 0. ]
#12mm
# image_width: 1440
# image_height: 1080
# camera_name: narrow_stereo
# camera_matrix:
# rows: 3
# cols: 3
# data: [3655.11292, 0. , 748.66803,
# 0. , 3648.34439, 556.97683,
# 0. , 0. , 1. ]
# distortion_model: plumb_bob
# distortion_coefficients:
# rows: 1
# cols: 5
# data: [-0.058763, 1.430269, 0.001191, 0.001641, 0.000000]
# rectification_matrix:
# rows: 3
# cols: 3
# data: [1., 0., 0.,
# 0., 1., 0.,
# 0., 0., 1.]
# projection_matrix:
# rows: 3
# cols: 4
# data: [3659.22949, 0. , 748.71496, 0. ,
# 0. , 3652.02661, 556.81908, 0. ,
# 0. , 0. , 1. , 0. ]

View File

@@ -0,0 +1,34 @@
max_infer_running: 6
detect_color: 0
attack_mode: 0
debug_fps: 60
tf:
R_camera2gimbal: [0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0]
t_camera2gimbal:
[0.000434347168653831, 0.0141232476052895, 0.00736106400231024]
control:
control_rate: 1000
communication_delay_us: 100.000
device_name: /dev/stm32_acm
use_serial: true
yaw_ramp: 0.0 #send_cmd = cmd.pos+cmd.vel*ramp
pitch_ramp: 0.0
logger:
log_level: DEBUG
log_path: log
use_logcli: true
use_logfile: false
use_simplelog: true
shoot:
bullet_speed: 26.0
rate: 3
record:
use_record: false
folder_path: record
use_rotate_reader: false
read_csv_path: ""

View File

@@ -0,0 +1,66 @@
armor_detector:
cv: #装甲板灯条完整无损坏才可以用
enable: true
armor:
max_angle: 40
max_large_center_distance: 8
max_small_center_distance: 3.5
min_large_center_distance: 3.5
min_light_ratio: 0.5
min_small_center_distance: 0.05
light:
binary_thres: 150
expand_ratio_h: 1.9
expand_ratio_w: 1.2
max_angle: 45
max_ratio: 0.4
min_ratio: 0.01
max_pts_error: 20.0
max_angle_diff: 20.0
color_diff_thresh: 40
classify:
enable: true
label_path: ${VISION_ROOT}/model/label.txt
model_path: ${VISION_ROOT}/model/mlp.onnx
backend: opencv
threshold: 0.5
tensorrt:
conf_threshold: 0.2
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
device_id: 0
nms_threshold: 0.3
top_k: 128
max_infer_running: 3
min_free_mem_ratio: 0.1
use_cuda_pre: true
log_time: false
model_type: tup
openvino:
conf_threshold: 0.2
device_name: GPU
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
nms_threshold: 0.3
top_k: 128
use_throughputmode: true
model_type: tup
onnxruntime:
conf_threshold: 0.2
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
nms_threshold: 0.3
top_k: 128
model_type: tup
provider: CUDA
ncnn:
conf_threshold: 0.2
model_path_param: ${VISION_ROOT}/model/opt-1208-001.param
model_path_bin: ${VISION_ROOT}/model/opt-1208-001.bin
input_name: images
output_name: output
use_gpu: true
use_lightmode: false
device_id: 0
cpu_threads: 20
nms_threshold: 0.3
top_k: 128
model_type: tup

View File

@@ -0,0 +1,19 @@
armor_detector:
light:
binary_thres: 120
max_angle: 40
max_ratio: 0.4
min_ratio: 0.001
color_diff_thresh: 20
max_angle_diff: 10.0
armor:
min_light_ratio: 0.8
min_small_center_distance: 0.8
max_small_center_distance: 3.5
min_large_center_distance: 3.5
max_large_center_distance: 8.0
max_angle: 40.0
classify:
label_path: ${VISION_ROOT}/model/label.txt
model_path: ${VISION_ROOT}/model/reborn_number_classifier.onnx
threshold: 0.5

View File

@@ -0,0 +1,24 @@
#!/bin/bash
TARGET="$1"
shift
ARGS=("${@:3}")
echo "[GUARD] target: $TARGET"
echo "[GUARD] args: ${ARGS[*]}"
echo "[GUARD] Starting monitor..."
while true; do
echo "[GUARD] Launching program..."
"$TARGET" "${ARGS[@]}"
RET=$?
if [ $RET -eq 0 ]; then
echo "[GUARD] Program exited normally. Stopping guard."
exit 0
fi
echo "[GUARD] Crash detected, restarting in 1 second..."
sleep 1
done

View File

@@ -0,0 +1,51 @@
yaw_in_big_yaw_deg: 90.0
type: uvc
camera_info_path: ${VISION_ROOT}/config/camera_info.yaml
hik_camera:
target_sn: DA1094119
adc_bit_depth: Bits_8
pixel_format: BayerRG8 #RGB8Packed
reverse_x: false
reverse_y: false
width: 1440
height: 1080
offset_x: 0
offset_y: 0
acquisition_frame_rate_enable: true
acquisition_frame_rate: 250
exposure_time: 2000
gain: 16.9
gamma: 0.7 #Bayer用不了
trigger_type: none
trigger_source: ""
trigger_activation: 0
use_raw: false
use_rgb: false
use_ea: false
use_cuda_cvt: false
video_player:
fps: 60
loop: true
use_cvt: true
trigger_mode: false
#path: /home/hy/wust_data/video/aaa.mp4
#path: /home/hy/wust_data/video/jiao.avi
#path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi
#path: /home/hy/wust_data/record/20251111_134416_187.avi
path: /home/hy/rune_dl/runeA.mp4
#path: /home/hy/wust_data/video/Sentry.mp4
start_frame: 0
uvc:
device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:9.4:1.0-video-index0"
fps: 30
width: 1280
height: 720
exposure: 200
gain: 10.0
gamma: 100
trigger_mode: true

View File

@@ -0,0 +1,51 @@
yaw_in_big_yaw_deg: 107.0
type: uvc
camera_info_path: ${VISION_ROOT}/config/omni/camera_info.yaml
hik_camera:
target_sn: DA1094119
adc_bit_depth: Bits_8
pixel_format: BayerRG8 #RGB8Packed
reverse_x: false
reverse_y: false
width: 1440
height: 1080
offset_x: 0
offset_y: 0
acquisition_frame_rate_enable: true
acquisition_frame_rate: 250
exposure_time: 2000
gain: 16.9
gamma: 0.7 #Bayer用不了
trigger_type: none
trigger_source: ""
trigger_activation: 0
use_raw: false
use_rgb: false
use_ea: false
use_cuda_cvt: false
video_player:
fps: 60
loop: true
use_cvt: true
trigger_mode: false
#path: /home/hy/wust_data/video/aaa.mp4
#path: /home/hy/wust_data/video/jiao.avi
#path: /home/hy/Z_LION_AutoAim2025/src/auto_aim_publisher/videos/sample.avi
#path: /home/hy/wust_data/record/20251111_134416_187.avi
path: /home/hy/rune_dl/runeA.mp4
#path: /home/hy/wust_data/video/Sentry.mp4
start_frame: 0
uvc:
device_name: "/dev/v4l/by-path/pci-0000:00:14.0-usb-0:9.1:1.0-video-index0"
fps: 30
width: 1280
height: 720
exposure: 200
gain: 10.0
gamma: 100
trigger_mode: true

View File

@@ -0,0 +1,28 @@
# camera_matrix = [759.73071, 0.0, 336.19053, 0.0, 761.16771, 231.83002, 0.0, 0.0, 1.0]
# distortion_coefficients = [0.000281, 0.144018, 0.005509, -0.004330, 0.0]
image_width: 1440
image_height: 1080
camera_name: narrow_stereo
camera_matrix:
rows: 3
cols: 3
data: [759.73071, 0.0, 336.19053, 0.0, 761.16771, 231.83002, 0.0, 0.0, 1.0 ]
distortion_model: plumb_bob
distortion_coefficients:
rows: 1
cols: 5
data: [0.000281, 0.144018, 0.005509, -0.004330, 0.0]
rectification_matrix:
rows: 3
cols: 3
data: [1., 0., 0.,
0., 1., 0.,
0., 0., 1.]
projection_matrix:
rows: 3
cols: 4
data: [1793.54229, 0. , 730.67194, 0. ,
0. , 1794.48784, 529.46193, 0. ,
0. , 0. , 1. , 0. ]

View File

@@ -0,0 +1,66 @@
armor_detector:
cv: #装甲板灯条完整无损坏才可以用
enable: false
armor:
max_angle: 40
max_large_center_distance: 8
max_small_center_distance: 3.5
min_large_center_distance: 3.5
min_light_ratio: 0.5
min_small_center_distance: 0.05
light:
binary_thres: 150
expand_ratio_h: 1.9
expand_ratio_w: 1.2
max_angle: 45
max_ratio: 0.4
min_ratio: 0.01
max_pts_error: 20.0
max_angle_diff: 20.0
color_diff_thresh: 0
classify:
enable: true
label_path: ${VISION_ROOT}/model/label.txt
model_path: ${VISION_ROOT}/model/mlp.onnx
backend: opencv
threshold: 0.5
tensorrt:
conf_threshold: 0.2
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
device_id: 0
nms_threshold: 0.3
top_k: 128
max_infer_running: 3
min_free_mem_ratio: 0.1
use_cuda_pre: true
log_time: false
model_type: tup
openvino:
conf_threshold: 0.2
device_name: GPU
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
nms_threshold: 0.3
top_k: 128
use_throughputmode: true
model_type: tup
onnxruntime:
conf_threshold: 0.2
model_path: ${VISION_ROOT}/model/opt-1208-001.onnx
nms_threshold: 0.3
top_k: 128
model_type: tup
provider: CUDA
ncnn:
conf_threshold: 0.2
model_path_param: ${VISION_ROOT}/model/opt-1208-001.param
model_path_bin: ${VISION_ROOT}/model/opt-1208-001.bin
input_name: images
output_name: output
use_gpu: true
use_lightmode: false
device_id: 0
cpu_threads: 20
nms_threshold: 0.3
top_k: 128
model_type: tup

View File

@@ -0,0 +1,19 @@
armor_detector:
light:
binary_thres: 120
max_angle: 40
max_ratio: 0.4
min_ratio: 0.001
color_diff_thresh: 20
max_angle_diff: 10.0
armor:
min_light_ratio: 0.8
min_small_center_distance: 0.8
max_small_center_distance: 3.5
min_large_center_distance: 3.5
max_large_center_distance: 8.0
max_angle: 40.0
classify:
label_path: ${VISION_ROOT}/model/label.txt
model_path: ${VISION_ROOT}/model/reborn_number_classifier.onnx
threshold: 0.5

View File

@@ -0,0 +1,37 @@
armor_detect_backend: openvino
cameras: ["${VISION_ROOT}/config/omni/camera0.yaml", "${VISION_ROOT}/config/omni/camera1.yaml"]
fps: 60
max_infer_running: 3
active_time: 1.0
min_score: 15.0
armor_where:
yaw_opt:
mode: golden
golden_search_side_deg: 60
distance_fix_a2: 0.0
armor_tracker:
lost_time_thres: 1.0
tracking_thres: 5
max_yaw_diff_deg: 90.0
max_dis_diff: 3.0
match_gate: 50
qxyz_common: [1.0, 1.0, 1.0]
qyaw_common: 1.0
qxyz_output: [1.0, 1.0, 0.5]
qyaw_output: 0.01
q_r: 0.0000001
q_l: 0.0000001
q_h: 0.0000001
q_outpost_dz: 0.5
yp_r: 2e-3
dis_r_front: 0.5
dis_r_side: 2.5
dis2_r_ratio: 0.1
yaw_r_base_front: 0.2
yaw_r_base_side: 0.06
yaw_r_log_ratio: 0.02
esekf_iter_num: 1

View File

@@ -0,0 +1,61 @@
cmake_minimum_required(VERSION 3.10)
cmake_policy(SET CMP0079 NEW)
project(cuda_infer LANGUAGES CXX CUDA)
# 设置标准
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_BUILD_TYPE "Release")
# 抑制过时 API 警告
add_compile_options(-Wno-deprecated-declarations)
# 禁用 .rsp 响应文件(避免 nvcc 报错)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS OFF)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES OFF)
# 查找依赖
find_package(CUDAToolkit REQUIRED)
find_package(OpenCV REQUIRED)
find_package(Eigen3 REQUIRED)
# 收集源码
file(GLOB_RECURSE CUDA_INFER_SRC
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/*.cu
)
# 添加静态库
add_library(cuda_infer STATIC ${CUDA_INFER_SRC})
set_target_properties(cuda_infer PROPERTIES POSITION_INDEPENDENT_CODE ON)
# 设置包含路径
target_include_directories(cuda_infer PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
${CUDAToolkit_INCLUDE_DIRS}
${TensorRT_INCLUDE_DIR}
${EIGEN3_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
)
# 设置 CUDA 编译选项
target_compile_options(cuda_infer PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:
--generate-code=arch=compute_86,code=sm_86
-Xcompiler=-fPIC
-O3
-w
-Wno-deprecated-gpu-targets
-Wno-error=deprecated-declarations
>
)
# 链接库
target_link_libraries(cuda_infer PRIVATE
${OpenCV_LIBS}
CUDA::cudart
TensorRT::TensorRT
)

View File

@@ -0,0 +1,323 @@
// armor_cuda_infer.cu
#include "armor_infer.hpp"
#include "letter_box.hpp"
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>
#include <opencv2/core/hal/interface.h>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#define CUDA_CHECK(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error at %s:%d: %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(err) \
); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace armor_cuda_infer {
__global__ void nchw_float_to_hwc_uchar4(
const float* __restrict__ src,
uchar4* __restrict__ dst,
int W,
int H,
float norm
) {
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
if (x >= W || y >= H)
return;
const int idx = y * W + x;
const int plane = W * H;
float r = __ldg(src + idx + plane * 0);
float g = __ldg(src + idx + plane * 1);
float b = __ldg(src + idx + plane * 2);
r = fminf(fmaxf(r / norm, 0.f), 255.f);
g = fminf(fmaxf(g / norm, 0.f), 255.f);
b = fminf(fmaxf(b / norm, 0.f), 255.f);
dst[idx] = make_uchar4((unsigned char)b, (unsigned char)g, (unsigned char)r, 255);
}
cv::Mat CudaInfer::tensorToMat(float* d_nchw, int W, int H, float norm, cudaStream_t stream) const {
static uchar4* d_hwc = nullptr;
static size_t cap = 0;
const size_t need = W * H * sizeof(uchar4);
if (cap < need) {
if (d_hwc)
cudaFree(d_hwc);
cudaMalloc(&d_hwc, need);
cap = need;
}
const dim3 block(TILE_W, TILE_H);
const dim3 grid((W + block.x - 1) / block.x, (H + block.y - 1) / block.y);
nchw_float_to_hwc_uchar4<<<grid, block, 0, stream>>>(d_nchw, d_hwc, W, H, norm);
cv::Mat img(H, W, CV_8UC4);
cudaMemcpyAsync(img.data, d_hwc, need, cudaMemcpyDeviceToHost, stream);
// cudaStreamSynchronize(stream);
return img;
}
CudaInfer::CudaInfer() = default;
CudaInfer::~CudaInfer() {
release();
}
void CudaInfer::init(int max_src_w, int max_src_h, int input_w, int input_h) {
input_w_ = input_w;
input_h_ = input_h;
max_src_h_ = max_src_h;
max_src_w_ = max_src_w;
rellocMem();
}
void CudaInfer::rellocMem() {
CUDA_CHECK(cudaMalloc(&d_input_bgr_, max_src_w_ * max_src_h_ * 3 * sizeof(unsigned char)));
CUDA_CHECK(cudaMallocPitch(
&d_input_bgr_pitched_,
&input_pitch_bytes_,
max_src_w_ * 3 * sizeof(unsigned char),
max_src_h_
));
CUDA_CHECK(cudaMalloc(&d_nchw_, input_w_ * input_h_ * 3 * sizeof(float)));
printf("Relloc memory for CudaInfer\n");
}
void CudaInfer::getOutEnoughMem(int img_w, int img_h) {
if (img_w > max_src_w_ || img_h > max_src_h_) {
if (img_w > max_src_w_) {
max_src_w_ = img_w;
}
if (img_h > max_src_h_) {
max_src_h_ = img_h;
}
rellocMem();
}
}
void CudaInfer::release() {
if (d_input_bgr_)
cudaFree(d_input_bgr_), d_input_bgr_ = nullptr;
if (d_input_bgr_pitched_)
cudaFree(d_input_bgr_pitched_), d_input_bgr_pitched_ = nullptr;
if (d_nchw_)
cudaFree(d_nchw_), d_nchw_ = nullptr;
}
float* CudaInfer::preprocess(
const unsigned char* input_bgr_host,
int img_w,
int img_h,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
) {
if (!isInitialized()) {
throw std::runtime_error("CudaInfer not initialized properly.");
}
if (!input_bgr_host || !d_input_bgr_ || !d_nchw_) {
fprintf(stderr, "[Error] Null pointer in preprocess input\n");
return nullptr;
}
getOutEnoughMem(img_w, img_h);
float scale = fminf(input_w_ / (float)img_w, input_h_ / (float)img_h);
int rw = round(img_w * scale), rh = round(img_h * scale);
int pad_l = (input_w_ - rw) / 2, pad_t = (input_h_ - rh) / 2;
tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1;
size_t img_size = img_w * img_h * 3;
CUDA_CHECK(
cudaMemcpyAsync(d_input_bgr_, input_bgr_host, img_size, cudaMemcpyHostToDevice, stream)
);
dim3 threads(TILE_W, TILE_H);
dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H);
letterbox_kernel_shared<<<blocks, threads, 0, stream>>>(
d_input_bgr_,
img_w,
img_h,
d_nchw_,
input_w_,
input_h_,
scale,
pad_t,
pad_l,
norm,
swap_rb
);
CUDA_CHECK(cudaGetLastError());
return d_nchw_;
}
float* CudaInfer::preprocess_gpu(
const unsigned char* input_bgr_device,
int img_w,
int img_h,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
) {
if (!isInitialized()) {
throw std::runtime_error("CudaInfer not initialized properly.");
}
if (!input_bgr_device || !d_nchw_) {
fprintf(stderr, "[Error] Null pointer in preprocess input\n");
return nullptr;
}
getOutEnoughMem(img_w, img_h);
float scale = fminf(input_w_ / (float)img_w, input_h_ / (float)img_h);
int rw = round(img_w * scale), rh = round(img_h * scale);
int pad_l = (input_w_ - rw) / 2, pad_t = (input_h_ - rh) / 2;
tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1;
size_t img_size = img_w * img_h * 3;
dim3 threads(TILE_W, TILE_H);
dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H);
letterbox_kernel_shared<<<blocks, threads, 0, stream>>>(
input_bgr_device,
img_w,
img_h,
d_nchw_,
input_w_,
input_h_,
scale,
pad_t,
pad_l,
norm,
swap_rb
);
CUDA_CHECK(cudaGetLastError());
return d_nchw_;
}
float* CudaInfer::preprocess_pitched(
const unsigned char* input_bgr_host,
int img_w,
int img_h,
int host_step,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
) {
if (!isInitialized()) {
throw std::runtime_error("CudaInfer not initialized properly.");
}
if (!input_bgr_host || !d_nchw_) {
fprintf(stderr, "[Error] Null pointer in preprocess input\n");
return nullptr;
}
getOutEnoughMem(img_w, img_h);
float scale = fminf((float)input_w_ / img_w, (float)input_h_ / img_h);
int rw = round(img_w * scale);
int rh = round(img_h * scale);
int pad_l = (input_w_ - rw) / 2;
int pad_t = (input_h_ - rh) / 2;
tf_matrix << 1.f / scale, 0, -pad_l / scale, 0, 1.f / scale, -pad_t / scale, 0, 0, 1;
CUDA_CHECK(cudaMemcpy2DAsync(
d_input_bgr_pitched_,
input_pitch_bytes_,
input_bgr_host,
host_step,
img_w * 3,
img_h,
cudaMemcpyHostToDevice,
stream
));
dim3 threads(TILE_W, TILE_H);
dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H);
letterbox_kernel_pitched<<<blocks, threads, 0, stream>>>(
d_input_bgr_pitched_,
input_pitch_bytes_,
img_w,
img_h,
d_nchw_,
input_w_,
input_h_,
scale,
pad_t,
pad_l,
norm,
swap_rb
);
CUDA_CHECK(cudaGetLastError());
return d_nchw_;
}
float* CudaInfer::preprocess_pitched_gpu(
const unsigned char* input_bgr_device,
int img_w,
int img_h,
int input_step,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
) {
if (!isInitialized()) {
throw std::runtime_error("CudaInfer not initialized properly.");
}
if (!input_bgr_device || !d_nchw_) {
fprintf(stderr, "[Error] Null pointer in preprocess_pitched_gpu\n");
return nullptr;
}
getOutEnoughMem(img_w, img_h);
float scale = fminf(static_cast<float>(input_w_) / img_w, static_cast<float>(input_h_) / img_h);
int rw = static_cast<int>(roundf(img_w * scale));
int rh = static_cast<int>(roundf(img_h * scale));
int pad_l = (input_w_ - rw) / 2;
int pad_t = (input_h_ - rh) / 2;
tf_matrix << 1.f / scale, 0.f, -pad_l / scale, 0.f, 1.f / scale, -pad_t / scale, 0.f, 0.f, 1.f;
dim3 threads(TILE_W, TILE_H);
dim3 blocks((input_w_ + TILE_W - 1) / TILE_W, (input_h_ + TILE_H - 1) / TILE_H);
letterbox_kernel_pitched<<<blocks, threads, 0, stream>>>(
input_bgr_device,
input_step,
img_w,
img_h,
d_nchw_,
input_w_,
input_h_,
scale,
pad_t,
pad_l,
norm,
swap_rb
);
CUDA_CHECK(cudaGetLastError());
return d_nchw_;
}
} // namespace armor_cuda_infer

View File

@@ -0,0 +1,78 @@
// armor_cuda_infer.hpp
#pragma once
#include <Eigen/Dense>
#include <NvInferRuntime.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <opencv2/core/mat.hpp>
#include <vector>
namespace armor_cuda_infer {
class CudaInfer {
public:
CudaInfer();
~CudaInfer() noexcept;
void init(int max_src_w, int max_src_h, int input_w, int input_h);
void release();
bool isInitialized() const {
return d_input_bgr_ && d_nchw_ && d_input_bgr_pitched_;
}
void getOutEnoughMem(int img_w, int img_h);
void rellocMem();
float* preprocess(
const unsigned char* input_bgr_host,
int img_w,
int img_h,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
);
float* preprocess_pitched(
const unsigned char* input_bgr_host,
int img_w,
int img_h,
int host_step,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
);
float* preprocess_gpu(
const unsigned char* input_bgr_device,
int img_w,
int img_h,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
);
float* preprocess_pitched_gpu(
const unsigned char* input_bgr_device,
int img_w,
int img_h,
int host_step,
float norm,
bool swap_rb,
Eigen::Matrix3f& tf_matrix,
cudaStream_t stream
);
cv::Mat tensorToMat(float* d_nchw, int W, int H, float norm, cudaStream_t stream) const;
private:
CudaInfer(const CudaInfer&) = delete;
CudaInfer& operator=(const CudaInfer&) = delete;
unsigned char* d_input_bgr_ = nullptr;
float* d_nchw_ = nullptr;
unsigned char* d_input_bgr_pitched_ = nullptr;
size_t input_pitch_bytes_ = 0;
int input_w_;
int input_h_;
int max_src_w_, max_src_h_;
};
} // namespace armor_cuda_infer

View File

@@ -0,0 +1,147 @@
#include "letter_box.hpp"
__global__ void letterbox_kernel_shared(
const uchar* __restrict__ input_bgr,
int in_w,
int in_h,
float* __restrict__ output_nchw,
int out_w,
int out_h,
float scale,
int pad_t,
int pad_l,
float norm,
bool swap_rb
) {
// global x/y
int x = blockIdx.x * TILE_W + threadIdx.x;
int y = blockIdx.y * TILE_H + threadIdx.y;
if (x >= out_w || y >= out_h)
return;
// 共享内存 + halo
__shared__ uchar4 smem[TILE_H + 1][TILE_W + 1];
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int total_smem = (TILE_W + 1) * (TILE_H + 1);
int threads_per_block = blockDim.x * blockDim.y;
int iter = (total_smem + threads_per_block - 1) / threads_per_block;
float inv_scale = 1.0f / scale;
float block_start_x = blockIdx.x * TILE_W - pad_l;
float block_start_y = blockIdx.y * TILE_H - pad_t;
// load shared memory
for (int i = 0; i < iter; i++) {
int idx = tid + i * threads_per_block;
if (idx < total_smem) {
int sx = idx % (TILE_W + 1);
int sy = idx / (TILE_W + 1);
float in_x = (block_start_x + sx) * inv_scale;
float in_y = (block_start_y + sy) * inv_scale;
int ix = floorf(in_x);
int iy = floorf(in_y);
uchar4 p = make_uchar4(114, 114, 114, 0); // padding BGR
if (ix >= 0 && iy >= 0 && ix < in_w && iy < in_h) {
int offset = (iy * in_w + ix) * 3;
p.x = input_bgr[offset]; // b
p.y = input_bgr[offset + 1]; // g
p.z = input_bgr[offset + 2]; // r
}
smem[sy][sx] = p;
}
}
__syncthreads();
// 双线性插值
float in_x = (x - pad_l) * inv_scale;
float in_y = (y - pad_t) * inv_scale;
int tx = threadIdx.x;
int ty = threadIdx.y;
float dx = in_x - floorf(in_x);
float dy = in_y - floorf(in_y);
float dx1 = 1.0f - dx, dy1 = 1.0f - dy;
uchar4 p00 = smem[ty][tx];
uchar4 p01 = smem[ty][tx + 1];
uchar4 p10 = smem[ty + 1][tx];
uchar4 p11 = smem[ty + 1][tx + 1];
float out_r = dx1 * dy1 * p00.z + dx * dy1 * p01.z + dx1 * dy * p10.z + dx * dy * p11.z;
float out_g = dx1 * dy1 * p00.y + dx * dy1 * p01.y + dx1 * dy * p10.y + dx * dy * p11.y;
float out_b = dx1 * dy1 * p00.x + dx * dy1 * p01.x + dx1 * dy * p10.x + dx * dy * p11.x;
int out_idx = y * out_w + x;
if (swap_rb) {
output_nchw[out_idx + 0 * out_w * out_h] = out_r * norm;
output_nchw[out_idx + 1 * out_w * out_h] = out_g * norm;
output_nchw[out_idx + 2 * out_w * out_h] = out_b * norm;
} else {
output_nchw[out_idx + 0 * out_w * out_h] = out_b * norm;
output_nchw[out_idx + 1 * out_w * out_h] = out_g * norm;
output_nchw[out_idx + 2 * out_w * out_h] = out_r * norm;
}
}
__global__ void letterbox_kernel_pitched(
const unsigned char* __restrict__ d_input_bgr,
size_t pitch,
int src_w,
int src_h,
float* __restrict__ d_nchw,
int OUT_W,
int OUT_H,
float scale,
int pad_t,
int pad_l,
float norm,
bool swap_rb
) {
int ox = blockIdx.x * blockDim.x + threadIdx.x;
int oy = blockIdx.y * blockDim.y + threadIdx.y;
if (ox >= OUT_W || oy >= OUT_H)
return;
float fx = (ox - pad_l) / scale;
float fy = (oy - pad_t) / scale;
int out_idx = oy * OUT_W + ox;
int plane = OUT_W * OUT_H;
// clamp coordinates
fx = fmaxf(0.f, fminf(fx, src_w - 2.f));
fy = fmaxf(0.f, fminf(fy, src_h - 2.f));
int x0 = floorf(fx), y0 = floorf(fy);
int x1 = x0 + 1, y1 = y0 + 1;
float dx = fx - x0, dy = fy - y0;
float dx1 = 1.f - dx, dy1 = 1.f - dy;
// row pointers
const uchar3* row0 = (const uchar3*)((const char*)d_input_bgr + y0 * pitch);
const uchar3* row1 = (const uchar3*)((const char*)d_input_bgr + y1 * pitch);
uchar3 p00 = row0[x0];
uchar3 p01 = row0[x1];
uchar3 p10 = row1[x0];
uchar3 p11 = row1[x1];
// bilinear interpolation
float r = dx1 * dy1 * p00.z + dx * dy1 * p01.z + dx1 * dy * p10.z + dx * dy * p11.z;
float g = dx1 * dy1 * p00.y + dx * dy1 * p01.y + dx1 * dy * p10.y + dx * dy * p11.y;
float b = dx1 * dy1 * p00.x + dx * dy1 * p01.x + dx1 * dy * p10.x + dx * dy * p11.x;
if (swap_rb) {
d_nchw[out_idx + 0 * plane] = r * norm;
d_nchw[out_idx + 1 * plane] = g * norm;
d_nchw[out_idx + 2 * plane] = b * norm;
} else {
d_nchw[out_idx + 0 * plane] = b * norm;
d_nchw[out_idx + 1 * plane] = g * norm;
d_nchw[out_idx + 2 * plane] = r * norm;
}
}

View File

@@ -0,0 +1,42 @@
#pragma once
#include <Eigen/Dense>
#include <NvInferRuntime.h>
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <opencv2/core/hal/interface.h>
#include <opencv2/core/mat.hpp>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <vector>
static constexpr int TILE_W = 32;
static constexpr int TILE_H = 32;
__global__ void letterbox_kernel_shared(
const uchar* __restrict__ input_bgr,
int in_w,
int in_h,
float* __restrict__ output_nchw,
int out_w,
int out_h,
float scale,
int pad_t,
int pad_l,
float norm,
bool swap_rb
);
__global__ void letterbox_kernel_pitched(
const unsigned char* __restrict__ d_input_bgr,
size_t pitch,
int src_w,
int src_h,
float* __restrict__ d_nchw,
int OUT_W,
int OUT_H,
float scale,
int pad_t,
int pad_l,
float norm,
bool swap_rb
);

19
wust_vision-main/env.bash Normal file
View File

@@ -0,0 +1,19 @@
# #!/bin/bash
export MVCAM_SDK_PATH=/opt/MVS
export MVCAM_COMMON_RUNENV=/opt/MVS/lib
export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol
export ALLUSERSPROFILE=/opt/MVS/MVFG
export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH
export MVCAM_SDK_PATH=/opt/MVS
export MVCAM_SDK_VERSION=
export MVCAM_COMMON_RUNENV=/opt/MVS/lib
export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol
export ALLUSERSPROFILE=/opt/MVS/MVFG
export LD_LIBRARY_PATH=/opt/MVS/lib/aarch64:$LD_LIBRARY_PATH
WORK_DIR="$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
export VISION_ROOT="$WORK_DIR"

3
wust_vision-main/format.sh Executable file
View File

@@ -0,0 +1,3 @@
find . -path ./build -prune -o \
-type f \( -name '*.h' -o -name '*.hpp' -o -name '*.c' -o -name '*.cu' -o -name '*.cpp' \) \
-exec clang-format -i {} +

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,9 @@
1
2
3
4
5
outpost
guard
base
negative

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,218 @@
7767517
216 240
Input images 0 1 images
Convolution Conv_1 1 1 images input.4 0=16 1=6 11=6 2=1 12=1 3=2 13=2 4=2 14=2 15=2 16=2 5=1 6=1728
HardSwish HardSwish_2 1 1 input.4 onnx::Conv_483 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_3 1 1 onnx::Conv_483 input.12 0=16 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=144 7=16
Convolution Conv_4 1 1 input.12 input.20 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
HardSwish HardSwish_5 1 1 input.20 onnx::Conv_488 0=2.000000e-01 1=5.000000e-01
Split splitncnn_0 1 2 onnx::Conv_488 onnx::Conv_488_splitncnn_0 onnx::Conv_488_splitncnn_1
ConvolutionDepthWise Conv_6 1 1 onnx::Conv_488_splitncnn_1 input.28 0=32 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=288 7=32
Convolution Conv_7 1 1 input.28 input.36 0=16 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
HardSwish HardSwish_8 1 1 input.36 onnx::Concat_493 0=2.000000e-01 1=5.000000e-01
Convolution Conv_9 1 1 onnx::Conv_488_splitncnn_0 input.44 0=16 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
HardSwish HardSwish_10 1 1 input.44 onnx::Conv_496 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_11 1 1 onnx::Conv_496 input.52 0=16 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=144 7=16
Convolution Conv_12 1 1 input.52 input.60 0=48 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768
HardSwish HardSwish_13 1 1 input.60 onnx::Concat_501 0=2.000000e-01 1=5.000000e-01
Concat Concat_14 2 1 onnx::Concat_493 onnx::Concat_501 x 0=0
ShuffleChannel Reshape_19 1 1 x onnx::Slice_507 0=2 1=0
Split splitncnn_1 1 2 onnx::Slice_507 onnx::Slice_507_splitncnn_0 onnx::Slice_507_splitncnn_1
Crop Slice_24 1 1 onnx::Slice_507_splitncnn_1 onnx::Concat_512 -23309=1,0 -23310=1,32 -23311=1,0
Crop Slice_29 1 1 onnx::Slice_507_splitncnn_0 onnx::Conv_517 -23309=1,32 -23310=1,2147483647 -23311=1,0
Convolution Conv_30 1 1 onnx::Conv_517 input.68 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_31 1 1 input.68 onnx::Conv_520 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_32 1 1 onnx::Conv_520 input.76 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32
Convolution Conv_33 1 1 input.76 input.84 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_34 1 1 input.84 onnx::Concat_525 0=2.000000e-01 1=5.000000e-01
Concat Concat_35 2 1 onnx::Concat_512 onnx::Concat_525 x.3 0=0
ShuffleChannel Reshape_40 1 1 x.3 onnx::Slice_531 0=2 1=0
Split splitncnn_2 1 2 onnx::Slice_531 onnx::Slice_531_splitncnn_0 onnx::Slice_531_splitncnn_1
Crop Slice_45 1 1 onnx::Slice_531_splitncnn_1 onnx::Concat_536 -23309=1,0 -23310=1,32 -23311=1,0
Crop Slice_50 1 1 onnx::Slice_531_splitncnn_0 onnx::Conv_541 -23309=1,32 -23310=1,2147483647 -23311=1,0
Convolution Conv_51 1 1 onnx::Conv_541 input.92 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_52 1 1 input.92 onnx::Conv_544 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_53 1 1 onnx::Conv_544 input.100 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32
Convolution Conv_54 1 1 input.100 input.108 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_55 1 1 input.108 onnx::Concat_549 0=2.000000e-01 1=5.000000e-01
Concat Concat_56 2 1 onnx::Concat_536 onnx::Concat_549 x.7 0=0
ShuffleChannel Reshape_61 1 1 x.7 input.112 0=2 1=0
Split splitncnn_3 1 4 input.112 input.112_splitncnn_0 input.112_splitncnn_1 input.112_splitncnn_2 input.112_splitncnn_3
ConvolutionDepthWise Conv_62 1 1 input.112_splitncnn_3 input.120 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_63 1 1 input.120 input.128 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=2048
HardSwish HardSwish_64 1 1 input.128 onnx::Concat_560 0=2.000000e-01 1=5.000000e-01
Convolution Conv_65 1 1 input.112_splitncnn_2 input.136 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=2048
HardSwish HardSwish_66 1 1 input.136 onnx::Conv_563 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_67 1 1 onnx::Conv_563 input.144 0=32 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=288 7=32
Convolution Conv_68 1 1 input.144 input.152 0=96 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=3072
HardSwish HardSwish_69 1 1 input.152 onnx::Concat_568 0=2.000000e-01 1=5.000000e-01
Concat Concat_70 2 1 onnx::Concat_560 onnx::Concat_568 x.11 0=0
ShuffleChannel Reshape_75 1 1 x.11 onnx::Slice_574 0=2 1=0
Split splitncnn_4 1 2 onnx::Slice_574 onnx::Slice_574_splitncnn_0 onnx::Slice_574_splitncnn_1
Crop Slice_80 1 1 onnx::Slice_574_splitncnn_1 onnx::Concat_579 -23309=1,0 -23310=1,64 -23311=1,0
Crop Slice_85 1 1 onnx::Slice_574_splitncnn_0 onnx::Conv_584 -23309=1,64 -23310=1,2147483647 -23311=1,0
Convolution Conv_86 1 1 onnx::Conv_584 input.160 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_87 1 1 input.160 onnx::Conv_587 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_88 1 1 onnx::Conv_587 input.168 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_89 1 1 input.168 input.176 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_90 1 1 input.176 onnx::Concat_592 0=2.000000e-01 1=5.000000e-01
Concat Concat_91 2 1 onnx::Concat_579 onnx::Concat_592 x.15 0=0
ShuffleChannel Reshape_96 1 1 x.15 onnx::Slice_598 0=2 1=0
Split splitncnn_5 1 2 onnx::Slice_598 onnx::Slice_598_splitncnn_0 onnx::Slice_598_splitncnn_1
Crop Slice_101 1 1 onnx::Slice_598_splitncnn_1 onnx::Concat_603 -23309=1,0 -23310=1,64 -23311=1,0
Crop Slice_106 1 1 onnx::Slice_598_splitncnn_0 onnx::Conv_608 -23309=1,64 -23310=1,2147483647 -23311=1,0
Convolution Conv_107 1 1 onnx::Conv_608 input.184 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_108 1 1 input.184 onnx::Conv_611 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_109 1 1 onnx::Conv_611 input.192 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_110 1 1 input.192 input.200 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_111 1 1 input.200 onnx::Concat_616 0=2.000000e-01 1=5.000000e-01
Concat Concat_112 2 1 onnx::Concat_603 onnx::Concat_616 x.19 0=0
ShuffleChannel Reshape_117 1 1 x.19 input.204 0=2 1=0
Split splitncnn_6 1 3 input.204 input.204_splitncnn_0 input.204_splitncnn_1 input.204_splitncnn_2
ConvolutionDepthWise Conv_118 1 1 input.204_splitncnn_2 input.212 0=128 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=1152 7=128
Convolution Conv_119 1 1 input.212 input.220 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192
HardSwish HardSwish_120 1 1 input.220 onnx::Concat_627 0=2.000000e-01 1=5.000000e-01
Convolution Conv_121 1 1 input.204_splitncnn_1 input.228 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192
HardSwish HardSwish_122 1 1 input.228 onnx::Conv_630 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_123 1 1 onnx::Conv_630 input.236 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_124 1 1 input.236 input.244 0=192 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=12288
HardSwish HardSwish_125 1 1 input.244 onnx::Concat_635 0=2.000000e-01 1=5.000000e-01
Concat Concat_126 2 1 onnx::Concat_627 onnx::Concat_635 x.23 0=0
ShuffleChannel Reshape_131 1 1 x.23 onnx::Slice_641 0=2 1=0
Split splitncnn_7 1 2 onnx::Slice_641 onnx::Slice_641_splitncnn_0 onnx::Slice_641_splitncnn_1
Crop Slice_136 1 1 onnx::Slice_641_splitncnn_1 onnx::Concat_646 -23309=1,0 -23310=1,128 -23311=1,0
Crop Slice_141 1 1 onnx::Slice_641_splitncnn_0 onnx::Conv_651 -23309=1,128 -23310=1,2147483647 -23311=1,0
Convolution Conv_142 1 1 onnx::Conv_651 input.252 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_143 1 1 input.252 onnx::Conv_654 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_144 1 1 onnx::Conv_654 input.260 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128
Convolution Conv_145 1 1 input.260 input.268 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_146 1 1 input.268 onnx::Concat_659 0=2.000000e-01 1=5.000000e-01
Concat Concat_147 2 1 onnx::Concat_646 onnx::Concat_659 x.27 0=0
ShuffleChannel Reshape_152 1 1 x.27 onnx::Slice_665 0=2 1=0
Split splitncnn_8 1 2 onnx::Slice_665 onnx::Slice_665_splitncnn_0 onnx::Slice_665_splitncnn_1
Crop Slice_157 1 1 onnx::Slice_665_splitncnn_1 onnx::Concat_670 -23309=1,0 -23310=1,128 -23311=1,0
Crop Slice_162 1 1 onnx::Slice_665_splitncnn_0 onnx::Conv_675 -23309=1,128 -23310=1,2147483647 -23311=1,0
Convolution Conv_163 1 1 onnx::Conv_675 input.276 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_164 1 1 input.276 onnx::Conv_678 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_165 1 1 onnx::Conv_678 input.284 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128
Convolution Conv_166 1 1 input.284 input.292 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_167 1 1 input.292 onnx::Concat_683 0=2.000000e-01 1=5.000000e-01
Concat Concat_168 2 1 onnx::Concat_670 onnx::Concat_683 x.31 0=0
ShuffleChannel Reshape_173 1 1 x.31 input.296 0=2 1=0
Convolution Conv_174 1 1 input.296 input.304 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768
HardSwish HardSwish_175 1 1 input.304 onnx::Conv_692 0=2.000000e-01 1=5.000000e-01
Split splitncnn_9 1 2 onnx::Conv_692 onnx::Conv_692_splitncnn_0 onnx::Conv_692_splitncnn_1
Convolution Conv_176 1 1 onnx::Conv_692_splitncnn_1 input.312 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192
HardSwish HardSwish_177 1 1 input.312 onnx::Resize_695 0=2.000000e-01 1=5.000000e-01
Interp Resize_178 1 1 onnx::Resize_695 onnx::Concat_700 0=1 1=2.000000e+00 2=2.000000e+00 3=0 4=0 6=0
ConvolutionDepthWise Conv_179 1 1 input.112_splitncnn_1 input.320 0=64 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_180 1 1 input.320 input.328 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_181 1 1 input.328 onnx::Concat_705 0=2.000000e-01 1=5.000000e-01
Concat Concat_182 3 1 onnx::Concat_705 input.204_splitncnn_0 onnx::Concat_700 input.332 0=0
Split splitncnn_10 1 2 input.332 input.332_splitncnn_0 input.332_splitncnn_1
Convolution Conv_183 1 1 input.332_splitncnn_1 input.340 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_184 1 1 input.340 onnx::Conv_709 0=2.000000e-01 1=5.000000e-01
Convolution Conv_185 1 1 input.332_splitncnn_0 input.348 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_186 1 1 input.348 onnx::Concat_712 0=2.000000e-01 1=5.000000e-01
Convolution Conv_187 1 1 onnx::Conv_709 input.356 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_188 1 1 input.356 onnx::Conv_715 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_189 1 1 onnx::Conv_715 input.364 0=64 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=576 7=64
Convolution Conv_190 1 1 input.364 input.372 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_191 1 1 input.372 onnx::Concat_720 0=2.000000e-01 1=5.000000e-01
Concat Concat_192 2 1 onnx::Concat_720 onnx::Concat_712 input.376 0=0
Convolution Conv_193 1 1 input.376 input.384 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_194 1 1 input.384 onnx::Conv_724 0=2.000000e-01 1=5.000000e-01
Split splitncnn_11 1 3 onnx::Conv_724 onnx::Conv_724_splitncnn_0 onnx::Conv_724_splitncnn_1 onnx::Conv_724_splitncnn_2
Convolution Conv_195 1 1 onnx::Conv_724_splitncnn_2 input.392 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192
HardSwish HardSwish_196 1 1 input.392 onnx::Resize_727 0=2.000000e-01 1=5.000000e-01
Interp Resize_197 1 1 onnx::Resize_727 onnx::Concat_732 0=1 1=2.000000e+00 2=2.000000e+00 3=0 4=0 6=0
Concat Concat_198 2 1 input.112_splitncnn_0 onnx::Concat_732 input.396 0=0
Split splitncnn_12 1 2 input.396 input.396_splitncnn_0 input.396_splitncnn_1
Convolution Conv_199 1 1 input.396_splitncnn_1 input.404 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_200 1 1 input.404 onnx::Conv_736 0=2.000000e-01 1=5.000000e-01
Convolution Conv_201 1 1 input.396_splitncnn_0 input.412 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_202 1 1 input.412 onnx::Concat_739 0=2.000000e-01 1=5.000000e-01
Convolution Conv_203 1 1 onnx::Conv_736 input.420 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_204 1 1 input.420 onnx::Conv_742 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_205 1 1 onnx::Conv_742 input.428 0=32 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=288 7=32
Convolution Conv_206 1 1 input.428 input.436 0=32 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=1024
HardSwish HardSwish_207 1 1 input.436 onnx::Concat_747 0=2.000000e-01 1=5.000000e-01
Concat Concat_208 2 1 onnx::Concat_747 onnx::Concat_739 input.440 0=0
Convolution Conv_209 1 1 input.440 input.448 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_210 1 1 input.448 onnx::Conv_751 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_211 1 1 onnx::Conv_724_splitncnn_1 input.456 0=128 1=3 11=3 2=1 12=1 3=2 13=2 4=1 14=1 15=1 16=1 5=1 6=1152 7=128
Convolution Conv_212 1 1 input.456 input.464 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_213 1 1 input.464 onnx::Concat_756 0=2.000000e-01 1=5.000000e-01
Concat Concat_214 2 1 onnx::Concat_756 onnx::Conv_692_splitncnn_0 input.468 0=0
Split splitncnn_13 1 2 input.468 input.468_splitncnn_0 input.468_splitncnn_1
Convolution Conv_215 1 1 input.468_splitncnn_1 input.476 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768
HardSwish HardSwish_216 1 1 input.476 onnx::Conv_760 0=2.000000e-01 1=5.000000e-01
Convolution Conv_217 1 1 input.468_splitncnn_0 input.484 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=32768
HardSwish HardSwish_218 1 1 input.484 onnx::Concat_763 0=2.000000e-01 1=5.000000e-01
Convolution Conv_219 1 1 onnx::Conv_760 input.492 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_220 1 1 input.492 onnx::Conv_766 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_221 1 1 onnx::Conv_766 input.500 0=128 1=3 11=3 2=1 12=1 3=1 13=1 4=1 14=1 15=1 16=1 5=1 6=1152 7=128
Convolution Conv_222 1 1 input.500 input.508 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_223 1 1 input.508 onnx::Concat_771 0=2.000000e-01 1=5.000000e-01
Concat Concat_224 2 1 onnx::Concat_771 onnx::Concat_763 input.512 0=0
Convolution Conv_225 1 1 input.512 input.520 0=256 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=65536
HardSwish HardSwish_226 1 1 input.520 onnx::Conv_775 0=2.000000e-01 1=5.000000e-01
Convolution Conv_227 1 1 onnx::Conv_751 input.528 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096
HardSwish HardSwish_228 1 1 input.528 onnx::Conv_778 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_229 1 1 onnx::Conv_778 input.536 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64
ConvolutionDepthWise Conv_230 1 1 input.536 input.544 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2
HardSwish HardSwish_231 1 1 input.544 onnx::Conv_783 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_232 1 1 onnx::Conv_783 input.552 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128
ConvolutionDepthWise Conv_233 1 1 input.552 input.560 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2
HardSwish HardSwish_234 1 1 input.560 onnx::Slice_788 0=2.000000e-01 1=5.000000e-01
Split splitncnn_14 1 2 onnx::Slice_788 onnx::Slice_788_splitncnn_0 onnx::Slice_788_splitncnn_1
Crop Slice_239 1 1 onnx::Slice_788_splitncnn_1 onnx::Conv_793 -23309=1,0 -23310=1,64 -23311=1,0
Split splitncnn_15 1 2 onnx::Conv_793 onnx::Conv_793_splitncnn_0 onnx::Conv_793_splitncnn_1
Crop Slice_244 1 1 onnx::Slice_788_splitncnn_0 onnx::Conv_798 -23309=1,64 -23310=1,2147483647 -23311=1,0
Convolution Conv_245 1 1 onnx::Conv_798 onnx::Sigmoid_799 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768
Convolution Conv_246 1 1 onnx::Conv_793_splitncnn_1 onnx::Concat_800 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
Convolution Conv_247 1 1 onnx::Conv_793_splitncnn_0 onnx::Sigmoid_801 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64
Sigmoid Sigmoid_248 1 1 onnx::Sigmoid_801 onnx::Concat_802
Sigmoid Sigmoid_249 1 1 onnx::Sigmoid_799 onnx::Concat_803
Concat Concat_250 3 1 onnx::Concat_800 onnx::Concat_802 onnx::Concat_803 onnx::Shape_804 0=0
Convolution Conv_251 1 1 onnx::Conv_724_splitncnn_0 input.568 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192
HardSwish HardSwish_252 1 1 input.568 onnx::Conv_807 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_253 1 1 onnx::Conv_807 input.576 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64
ConvolutionDepthWise Conv_254 1 1 input.576 input.584 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2
HardSwish HardSwish_255 1 1 input.584 onnx::Conv_812 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_256 1 1 onnx::Conv_812 input.592 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128
ConvolutionDepthWise Conv_257 1 1 input.592 input.600 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2
HardSwish HardSwish_258 1 1 input.600 onnx::Slice_817 0=2.000000e-01 1=5.000000e-01
Split splitncnn_16 1 2 onnx::Slice_817 onnx::Slice_817_splitncnn_0 onnx::Slice_817_splitncnn_1
Crop Slice_263 1 1 onnx::Slice_817_splitncnn_1 onnx::Conv_822 -23309=1,0 -23310=1,64 -23311=1,0
Split splitncnn_17 1 2 onnx::Conv_822 onnx::Conv_822_splitncnn_0 onnx::Conv_822_splitncnn_1
Crop Slice_268 1 1 onnx::Slice_817_splitncnn_0 onnx::Conv_827 -23309=1,64 -23310=1,2147483647 -23311=1,0
Convolution Conv_269 1 1 onnx::Conv_827 onnx::Sigmoid_828 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768
Convolution Conv_270 1 1 onnx::Conv_822_splitncnn_1 onnx::Concat_829 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
Convolution Conv_271 1 1 onnx::Conv_822_splitncnn_0 onnx::Sigmoid_830 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64
Sigmoid Sigmoid_272 1 1 onnx::Sigmoid_830 onnx::Concat_831
Sigmoid Sigmoid_273 1 1 onnx::Sigmoid_828 onnx::Concat_832
Concat Concat_274 3 1 onnx::Concat_829 onnx::Concat_831 onnx::Concat_832 onnx::Shape_833 0=0
Convolution Conv_275 1 1 onnx::Conv_775 input.608 0=64 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=16384
HardSwish HardSwish_276 1 1 input.608 onnx::Conv_836 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_277 1 1 onnx::Conv_836 input.616 0=64 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=1600 7=64
ConvolutionDepthWise Conv_278 1 1 input.616 input.624 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=4096 7=2
HardSwish HardSwish_279 1 1 input.624 onnx::Conv_841 0=2.000000e-01 1=5.000000e-01
ConvolutionDepthWise Conv_280 1 1 onnx::Conv_841 input.632 0=128 1=5 11=5 2=1 12=1 3=1 13=1 4=2 14=2 15=2 16=2 5=1 6=3200 7=128
ConvolutionDepthWise Conv_281 1 1 input.632 input.640 0=128 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=8192 7=2
HardSwish HardSwish_282 1 1 input.640 onnx::Slice_846 0=2.000000e-01 1=5.000000e-01
Split splitncnn_18 1 2 onnx::Slice_846 onnx::Slice_846_splitncnn_0 onnx::Slice_846_splitncnn_1
Crop Slice_287 1 1 onnx::Slice_846_splitncnn_1 onnx::Conv_851 -23309=1,0 -23310=1,64 -23311=1,0
Split splitncnn_19 1 2 onnx::Conv_851 onnx::Conv_851_splitncnn_0 onnx::Conv_851_splitncnn_1
Crop Slice_292 1 1 onnx::Slice_846_splitncnn_0 onnx::Conv_856 -23309=1,64 -23310=1,2147483647 -23311=1,0
Convolution Conv_293 1 1 onnx::Conv_856 onnx::Sigmoid_857 0=12 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=768
Convolution Conv_294 1 1 onnx::Conv_851_splitncnn_1 onnx::Concat_858 0=8 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=512
Convolution Conv_295 1 1 onnx::Conv_851_splitncnn_0 onnx::Sigmoid_859 0=1 1=1 11=1 2=1 12=1 3=1 13=1 4=0 14=0 15=0 16=0 5=1 6=64
Sigmoid Sigmoid_296 1 1 onnx::Sigmoid_859 onnx::Concat_860
Sigmoid Sigmoid_297 1 1 onnx::Sigmoid_857 onnx::Concat_861
Concat Concat_298 3 1 onnx::Concat_858 onnx::Concat_860 onnx::Concat_861 output.1 0=0
Reshape Reshape_306 1 1 onnx::Shape_804 onnx::Concat_870 0=-1 1=21
Reshape Reshape_314 1 1 onnx::Shape_833 onnx::Concat_878 0=-1 1=21
Reshape Reshape_322 1 1 output.1 onnx::Concat_886 0=-1 1=21
Concat Concat_323 3 1 onnx::Concat_870 onnx::Concat_878 onnx::Concat_886 onnx::Transpose_887 0=1
Permute Transpose_324 1 1 onnx::Transpose_887 output 0=1

Binary file not shown.

View File

@@ -0,0 +1,193 @@
7767517
191 215
Input in0 0 1 in0
Convolution conv_69 1 1 in0 1 0=16 1=6 11=6 12=1 13=2 14=2 2=1 3=2 4=2 5=1 6=1728
HardSwish hswish_0 1 1 1 2 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_128 1 1 2 3 0=16 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=144 7=16
Convolution conv_70 1 1 3 4 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
HardSwish hswish_1 1 1 4 5 0=1.666667e-01 1=5.000000e-01
Split splitncnn_0 1 2 5 6 7
ConvolutionDepthWise convdw_129 1 1 7 8 0=32 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=288 7=32
Convolution conv_71 1 1 8 9 0=16 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
HardSwish hswish_2 1 1 9 10 0=1.666667e-01 1=5.000000e-01
Convolution conv_72 1 1 6 11 0=16 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
HardSwish hswish_3 1 1 11 12 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_130 1 1 12 13 0=16 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=144 7=16
Convolution conv_73 1 1 13 14 0=48 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768
HardSwish hswish_4 1 1 14 15 0=1.666667e-01 1=5.000000e-01
Concat cat_0 2 1 10 15 16 0=0
ShuffleChannel channelshuffle_60 1 1 16 17 0=2 1=0
Slice tensor_split_0 1 2 17 18 19 -23300=2,32,-233 1=0
Convolution conv_74 1 1 19 20 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_5 1 1 20 21 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_131 1 1 21 22 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32
Convolution conv_75 1 1 22 23 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_6 1 1 23 24 0=1.666667e-01 1=5.000000e-01
Concat cat_1 2 1 18 24 25 0=0
ShuffleChannel channelshuffle_61 1 1 25 26 0=2 1=0
Slice tensor_split_1 1 2 26 27 28 -23300=2,32,-233 1=0
Convolution conv_76 1 1 28 29 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_7 1 1 29 30 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_132 1 1 30 31 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32
Convolution conv_77 1 1 31 32 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_8 1 1 32 33 0=1.666667e-01 1=5.000000e-01
Concat cat_2 2 1 27 33 34 0=0
ShuffleChannel channelshuffle_62 1 1 34 35 0=2 1=0
Split splitncnn_1 1 4 35 36 37 38 39
ConvolutionDepthWise convdw_133 1 1 39 40 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64
Convolution conv_78 1 1 40 41 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=2048
HardSwish hswish_9 1 1 41 42 0=1.666667e-01 1=5.000000e-01
Convolution conv_79 1 1 37 43 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=2048
HardSwish hswish_10 1 1 43 44 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_134 1 1 44 45 0=32 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=288 7=32
Convolution conv_80 1 1 45 46 0=96 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=3072
HardSwish hswish_11 1 1 46 47 0=1.666667e-01 1=5.000000e-01
Concat cat_3 2 1 42 47 48 0=0
ShuffleChannel channelshuffle_63 1 1 48 49 0=2 1=0
Slice tensor_split_2 1 2 49 50 51 -23300=2,64,-233 1=0
Convolution conv_81 1 1 51 52 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_12 1 1 52 53 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_135 1 1 53 54 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64
Convolution conv_82 1 1 54 55 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_13 1 1 55 56 0=1.666667e-01 1=5.000000e-01
Concat cat_4 2 1 50 56 57 0=0
ShuffleChannel channelshuffle_64 1 1 57 58 0=2 1=0
Slice tensor_split_3 1 2 58 59 60 -23300=2,64,-233 1=0
Convolution conv_83 1 1 60 61 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_14 1 1 61 62 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_136 1 1 62 63 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64
Convolution conv_84 1 1 63 64 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_15 1 1 64 65 0=1.666667e-01 1=5.000000e-01
Concat cat_5 2 1 59 65 66 0=0
ShuffleChannel channelshuffle_65 1 1 66 67 0=2 1=0
Split splitncnn_2 1 3 67 68 69 70
ConvolutionDepthWise convdw_137 1 1 70 71 0=128 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=1152 7=128
Convolution conv_85 1 1 71 72 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192
HardSwish hswish_16 1 1 72 73 0=1.666667e-01 1=5.000000e-01
Convolution conv_86 1 1 69 74 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192
HardSwish hswish_17 1 1 74 75 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_138 1 1 75 76 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64
Convolution conv_87 1 1 76 77 0=192 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=12288
HardSwish hswish_18 1 1 77 78 0=1.666667e-01 1=5.000000e-01
Concat cat_6 2 1 73 78 79 0=0
ShuffleChannel channelshuffle_66 1 1 79 80 0=2 1=0
Slice tensor_split_4 1 2 80 81 82 -23300=2,128,-233 1=0
Convolution conv_88 1 1 82 83 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_19 1 1 83 84 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_139 1 1 84 85 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128
Convolution conv_89 1 1 85 86 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_20 1 1 86 87 0=1.666667e-01 1=5.000000e-01
Concat cat_7 2 1 81 87 88 0=0
ShuffleChannel channelshuffle_67 1 1 88 89 0=2 1=0
Slice tensor_split_5 1 2 89 90 91 -23300=2,128,-233 1=0
Convolution conv_90 1 1 91 92 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_21 1 1 92 93 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_140 1 1 93 94 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128
Convolution conv_91 1 1 94 95 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_22 1 1 95 96 0=1.666667e-01 1=5.000000e-01
Concat cat_8 2 1 90 96 97 0=0
ShuffleChannel channelshuffle_68 1 1 97 98 0=2 1=0
Convolution conv_92 1 1 98 99 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768
HardSwish hswish_23 1 1 99 100 0=1.666667e-01 1=5.000000e-01
Split splitncnn_3 1 2 100 101 102
Convolution conv_93 1 1 102 103 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192
HardSwish hswish_24 1 1 103 104 0=1.666667e-01 1=5.000000e-01
Interp interpolate_52 1 1 104 105 0=1 1=2.000000e+00 2=2.000000e+00 6=0
ConvolutionDepthWise convdw_141 1 1 38 106 0=64 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=576 7=64
Convolution conv_94 1 1 106 107 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_25 1 1 107 108 0=1.666667e-01 1=5.000000e-01
Concat cat_9 3 1 108 68 105 109 0=0
Split splitncnn_4 1 2 109 110 111
Convolution conv_95 1 1 111 112 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_26 1 1 112 113 0=1.666667e-01 1=5.000000e-01
Convolution conv_96 1 1 110 114 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_27 1 1 114 115 0=1.666667e-01 1=5.000000e-01
Convolution conv_97 1 1 113 116 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_28 1 1 116 117 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_142 1 1 117 118 0=64 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=576 7=64
Convolution conv_98 1 1 118 119 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_29 1 1 119 120 0=1.666667e-01 1=5.000000e-01
Concat cat_10 2 1 120 115 121 0=0
Convolution conv_99 1 1 121 122 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_30 1 1 122 123 0=1.666667e-01 1=5.000000e-01
Split splitncnn_5 1 3 123 124 125 126
Convolution conv_100 1 1 125 127 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192
HardSwish hswish_31 1 1 127 128 0=1.666667e-01 1=5.000000e-01
Interp interpolate_53 1 1 128 129 0=1 1=2.000000e+00 2=2.000000e+00 6=0
Concat cat_11 2 1 36 129 130 0=0
Split splitncnn_6 1 2 130 131 132
Convolution conv_101 1 1 132 133 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_32 1 1 133 134 0=1.666667e-01 1=5.000000e-01
Convolution conv_102 1 1 131 135 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_33 1 1 135 136 0=1.666667e-01 1=5.000000e-01
Convolution conv_103 1 1 134 137 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_34 1 1 137 138 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_143 1 1 138 139 0=32 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=288 7=32
Convolution conv_104 1 1 139 140 0=32 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=1024
HardSwish hswish_35 1 1 140 141 0=1.666667e-01 1=5.000000e-01
Concat cat_12 2 1 141 136 142 0=0
Convolution conv_105 1 1 142 143 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_36 1 1 143 144 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_144 1 1 126 145 0=128 1=3 11=3 12=1 13=2 14=1 2=1 3=2 4=1 5=1 6=1152 7=128
Convolution conv_106 1 1 145 146 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_37 1 1 146 147 0=1.666667e-01 1=5.000000e-01
Concat cat_13 2 1 147 101 148 0=0
Split splitncnn_7 1 2 148 149 150
Convolution conv_107 1 1 150 151 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768
HardSwish hswish_38 1 1 151 152 0=1.666667e-01 1=5.000000e-01
Convolution conv_108 1 1 149 153 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=32768
HardSwish hswish_39 1 1 153 154 0=1.666667e-01 1=5.000000e-01
Convolution conv_109 1 1 152 155 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_40 1 1 155 156 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_145 1 1 156 157 0=128 1=3 11=3 12=1 13=1 14=1 2=1 3=1 4=1 5=1 6=1152 7=128
Convolution conv_110 1 1 157 158 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_41 1 1 158 159 0=1.666667e-01 1=5.000000e-01
Concat cat_14 2 1 159 154 160 0=0
Convolution conv_111 1 1 160 161 0=256 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=65536
HardSwish hswish_42 1 1 161 162 0=1.666667e-01 1=5.000000e-01
Convolution conv_112 1 1 144 163 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096
HardSwish hswish_43 1 1 163 164 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_146 1 1 164 165 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64
ConvolutionDepthWise convdw_147 1 1 165 166 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2
HardSwish hswish_44 1 1 166 167 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_148 1 1 167 168 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128
ConvolutionDepthWise convdw_149 1 1 168 169 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2
HardSwish hswish_45 1 1 169 170 0=1.666667e-01 1=5.000000e-01
Slice tensor_split_6 1 2 170 171 172 -23300=2,64,-233 1=0
Split splitncnn_8 1 2 171 173 174
Convolution conv_114 1 1 174 175 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
Convolution convsigmoid_0 1 1 173 176 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4
Convolution convsigmoid_1 1 1 172 177 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4
Concat cat_15 3 1 175 176 177 178 0=0
Convolution conv_116 1 1 124 179 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192
HardSwish hswish_46 1 1 179 180 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_150 1 1 180 181 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64
ConvolutionDepthWise convdw_151 1 1 181 182 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2
HardSwish hswish_47 1 1 182 183 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_152 1 1 183 184 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128
ConvolutionDepthWise convdw_153 1 1 184 185 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2
HardSwish hswish_48 1 1 185 186 0=1.666667e-01 1=5.000000e-01
Slice tensor_split_7 1 2 186 187 188 -23300=2,64,-233 1=0
Split splitncnn_9 1 2 187 189 190
Convolution conv_118 1 1 190 191 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
Convolution convsigmoid_2 1 1 189 192 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4
Convolution convsigmoid_3 1 1 188 193 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4
Concat cat_16 3 1 191 192 193 194 0=0
Convolution conv_120 1 1 162 195 0=64 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=16384
HardSwish hswish_49 1 1 195 196 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_154 1 1 196 197 0=64 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=1600 7=64
ConvolutionDepthWise convdw_155 1 1 197 198 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=4096 7=2
HardSwish hswish_50 1 1 198 199 0=1.666667e-01 1=5.000000e-01
ConvolutionDepthWise convdw_156 1 1 199 200 0=128 1=5 11=5 12=1 13=1 14=2 2=1 3=1 4=2 5=1 6=3200 7=128
ConvolutionDepthWise convdw_157 1 1 200 201 0=128 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=8192 7=2
HardSwish hswish_51 1 1 201 202 0=1.666667e-01 1=5.000000e-01
Slice tensor_split_8 1 2 202 203 204 -23300=2,64,-233 1=0
Split splitncnn_10 1 2 203 205 206
Convolution conv_122 1 1 206 207 0=8 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=512
Convolution convsigmoid_4 1 1 205 208 0=1 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=64 9=4
Convolution convsigmoid_5 1 1 204 209 0=12 1=1 11=1 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=768 9=4
Concat cat_17 3 1 207 208 209 210 0=0
Reshape reshape_125 1 1 178 211 0=2704 1=21
Reshape reshape_126 1 1 194 212 0=676 1=21
Reshape reshape_127 1 1 210 213 0=169 1=21
Concat cat_18 3 1 211 212 213 out0 0=1

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python3
"""
read_shm_image_mmap_only.py
仅使用 mmap 读取共享内存的 reader
- 共享内存 /dev/shm/debug_frame 前 4 bytes = uint32 jpg_size (little endian)
紧随其后的是 jpg_bytes (jpg_size bytes)
- 不再回退到文件读取;若 mmap 不可用则按 MMAP_RETRY_SECONDS 重试打开
- 原子写入 out_path.tmp -> out_path
- 支持环境变量覆盖: SHM_PATH, SHM_SIZE, OUT_PATH, FPS, DEBUG, MMAP_RETRY_SECONDS
"""
import os
import sys
import time
import struct
import hashlib
import signal
try:
import mmap as _mmap
except Exception:
_mmap = None
from io import BytesIO
# ---------- 配置(环境变量覆盖) ----------
SHM_PATH = os.environ.get("SHM_PATH", "/dev/shm/debug_frame")
SHM_SIZE = int(os.environ.get("SHM_SIZE", str(2 * 1024 * 1024))) # bytes
OUT_PATH = os.environ.get("OUT_PATH", "/dev/shm/frame_preview.jpg")
FPS = float(os.environ.get("FPS", "10.0"))
DEBUG = os.environ.get("DEBUG", "0") != "0"
MMAP_RETRY_SECONDS = float(os.environ.get("MMAP_RETRY_SECONDS", "5.0"))
# -------------------------------------------------------------------
sleep_interval = 1.0 / max(1.0, FPS)
def log(*a, **k):
if DEBUG:
print(*a, **k)
def is_valid_image_bytes(b: bytes) -> bool:
"""尝试用 Pillow 验证图像是否完整(若 Pillow 不存在,返回 True 以不阻塞)"""
try:
from PIL import Image
im = Image.open(BytesIO(b))
im.verify()
return True
except ImportError:
log("Pillow not installed; skipping image verify")
return True
except Exception as e:
log("image verify failed:", e)
return False
def open_mmap():
"""尝试打开并返回 (fd, mmap_obj) 或 (None, None)"""
if _mmap is None:
return None, None
try:
fd = os.open(SHM_PATH, os.O_RDONLY)
mm = _mmap.mmap(fd, SHM_SIZE, access=_mmap.ACCESS_READ)
log("mmap opened:", SHM_PATH, "size", SHM_SIZE)
return fd, mm
except Exception as e:
log("open_mmap failed:", e)
try:
if 'fd' in locals() and fd:
os.close(fd)
except:
pass
return None, None
def close_mmap(fd, mm):
try:
if mm:
mm.close()
except:
pass
try:
if fd:
os.close(fd)
except:
pass
def read_from_mmap(mm):
"""
从 mmap 读取一帧:
结构: [uint32 little-endian jpg_size][jpg_bytes...]
返回 bytes 或 None
"""
try:
mm.seek(0)
hdr = mm.read(4)
if len(hdr) < 4:
return None
jpg_size = struct.unpack("<I", hdr)[0]
if jpg_size <= 0 or jpg_size > (SHM_SIZE - 4):
log("invalid jpg_size:", jpg_size)
return None
jpg_bytes = mm.read(jpg_size)
if len(jpg_bytes) != jpg_size:
log(f"mmap truncated: expected {jpg_size} got {len(jpg_bytes)}")
return None
# quick header check
if not (jpg_bytes.startswith(b"\xff\xd8\xff")):
log("jpg header mismatch")
return None
return jpg_bytes
except ValueError as e:
log("mmap read ValueError:", e)
return None
except Exception as e:
log("mmap read exception:", e)
raise
def atomic_write(out_path, data):
tmp = out_path + ".tmp"
with open(tmp, "wb") as f:
f.write(data)
os.replace(tmp, out_path)
def loop():
if _mmap is None:
print("ERROR: mmap module not available in this Python environment. Exiting.", file=sys.stderr)
sys.exit(2)
fd = None
mm = None
last_hash = None
last_mmap_try = 0.0
# 初始尝试打开 mmap若失败则进入重试循环
fd, mm = open_mmap()
print(f"watching (mmap_only) '{SHM_PATH}' -> '{OUT_PATH}' @ {FPS} fps")
while True:
try:
jpg_bytes = None
# 尝试 mmap 读取(若 mm 为 None 则尝试重建)
if mm is not None:
try:
jpg_bytes = read_from_mmap(mm)
if jpg_bytes is None:
# 写端可能尚未写入完整帧,短暂等待
pass
else:
log("read frame from mmap, len=", len(jpg_bytes))
except Exception as e:
log("mmap error, closing and will retry:", e)
close_mmap(fd, mm)
fd, mm = None, None
last_mmap_try = time.time()
# 如果没有读到数据并且 mmap 不可用或不可读,则重试打开 mmap按间隔
if jpg_bytes is None:
now = time.time()
if mm is None and (now - last_mmap_try) >= MMAP_RETRY_SECONDS:
fd, mm = open_mmap()
last_mmap_try = now
time.sleep(sleep_interval)
continue
# 校验完整性PIL verify
if not is_valid_image_bytes(jpg_bytes):
log("image bytes failed verify, skipping")
time.sleep(sleep_interval)
continue
# 内容哈希,避免重复写磁盘
h = hashlib.sha256(jpg_bytes).hexdigest()
if h != last_hash:
atomic_write(OUT_PATH, jpg_bytes)
last_hash = h
log("wrote", OUT_PATH, "len=", len(jpg_bytes))
# else: skip (no change)
time.sleep(sleep_interval)
except KeyboardInterrupt:
print("terminated by user")
break
except Exception as e:
log("unexpected error:", e)
try:
if mm is not None:
close_mmap(fd, mm)
except:
pass
fd, mm = None, None
time.sleep(1.0)
# 退出前清理
close_mmap(fd, mm)
def _handle_sigterm(signum, frame):
print("terminated", file=sys.stderr)
sys.exit(0)
if __name__ == "__main__":
signal.signal(signal.SIGINT, _handle_sigterm)
signal.signal(signal.SIGTERM, _handle_sigterm)
loop()

View File

@@ -0,0 +1,136 @@
#pragma once
#include "geometry_msgs/msg/transform_stamped.hpp"
#include "rclcpp/rclcpp.hpp"
#include "tf2_ros/buffer.h"
#include "tf2_ros/transform_broadcaster.h"
#include "tf2_ros/transform_listener.h"
#include <chrono>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <typeindex>
#include <unordered_map>
#include <vector>
class Ros2Node: public rclcpp::Node {
public:
explicit Ros2Node(const std::string& node_name = "vision"):
Node(node_name),
tf_buffer_(this->get_clock()),
tf_listener_(tf_buffer_),
tf_broadcaster_(std::make_shared<tf2_ros::TransformBroadcaster>(this)) {}
using Ptr = std::shared_ptr<Ros2Node>;
static Ptr instance() {
static Ptr inst = std::make_shared<Ros2Node>();
return inst;
}
~Ros2Node() {
stop();
}
template<typename MsgT>
void add_subscription(
const std::string& topic_name,
std::function<void(const typename MsgT::SharedPtr)> callback,
const rclcpp::QoS& qos = rclcpp::QoS(rclcpp::KeepLast(10))
) {
auto sub = this->create_subscription<MsgT>(topic_name, qos, callback);
std::lock_guard<std::mutex> lock(map_mutex_);
subscriptions_[topic_name] = sub;
}
template<typename MsgT>
void add_publisher(
const std::string& topic_name,
const rclcpp::QoS& qos = rclcpp::QoS(rclcpp::KeepLast(10))
) {
auto pub = this->create_publisher<MsgT>(topic_name, qos);
std::lock_guard<std::mutex> lock(map_mutex_);
publishers_[topic_name] = pub;
}
template<typename MsgT>
void publish(const std::string& topic_name, const MsgT& msg) {
std::lock_guard<std::mutex> lock(map_mutex_);
auto it = publishers_.find(topic_name);
if (it != publishers_.end()) {
auto typed_pub = std::dynamic_pointer_cast<rclcpp::Publisher<MsgT>>(it->second);
if (typed_pub)
typed_pub->publish(msg);
}
}
template<typename Rep, typename Period>
void
add_timer(const std::chrono::duration<Rep, Period>& interval, std::function<void()> callback) {
auto timer = this->create_wall_timer(interval, callback);
std::lock_guard<std::mutex> lock(map_mutex_);
timers_.push_back(timer);
}
bool lookup_transform(
const std::string& target_frame,
const std::string& source_frame,
geometry_msgs::msg::TransformStamped& out_tf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(200)
) const {
try {
out_tf =
tf_buffer_.lookupTransform(target_frame, source_frame, tf2::TimePointZero, timeout);
return true;
} catch (const tf2::TransformException& ex) {
RCLCPP_WARN(
this->get_logger(),
"lookup_transform failed %s -> %s: %s",
source_frame.c_str(),
target_frame.c_str(),
ex.what()
);
return false;
}
}
void broadcast_tf(const geometry_msgs::msg::TransformStamped& tf_msg) {
tf_broadcaster_->sendTransform(tf_msg);
}
template<typename Rep, typename Period>
void broadcast_tf_periodic(
const geometry_msgs::msg::TransformStamped& tf_msg,
const std::chrono::duration<Rep, Period>& interval
) {
add_timer(interval, [this, tf_msg]() { broadcast_tf(tf_msg); });
}
void start() {
if (!spin_thread_.joinable()) {
auto node = shared_from_this();
spin_thread_ = std::thread([node]() { rclcpp::spin(node); });
}
}
void stop() {
rclcpp::shutdown();
if (spin_thread_.joinable()) {
spin_thread_.join();
}
}
private:
mutable std::mutex map_mutex_;
std::unordered_map<std::string, rclcpp::SubscriptionBase::SharedPtr> subscriptions_;
std::unordered_map<std::string, rclcpp::PublisherBase::SharedPtr> publishers_;
std::vector<rclcpp::TimerBase::SharedPtr> timers_;
std::thread spin_thread_;
tf2_ros::Buffer tf_buffer_;
tf2_ros::TransformListener tf_listener_;
std::shared_ptr<tf2_ros::TransformBroadcaster> tf_broadcaster_;
};

View File

@@ -0,0 +1,88 @@
#pragma once
#include "rclcpp/rclcpp.hpp"
#include <Eigen/Dense>
#include <tf2_geometry_msgs/tf2_geometry_msgs.hpp>
#include <tf2_ros/buffer.h>
#include <tf2_ros/transform_broadcaster.h>
#include <tf2_ros/transform_listener.h>
template<typename PublisherT>
inline bool publisherSubscribed(const PublisherT& publisher) noexcept {
return publisher && publisher->get_subscription_count() > 0;
}
inline Eigen::Isometry3f tf2ToEigen(const geometry_msgs::msg::TransformStamped& tf) noexcept {
Eigen::Isometry3f T = Eigen::Isometry3f::Identity();
const auto& t = tf.transform.translation;
T.translation() =
Eigen::Vector3f(static_cast<float>(t.x), static_cast<float>(t.y), static_cast<float>(t.z));
const auto& q = tf.transform.rotation;
Eigen::Quaternionf Q(
static_cast<float>(q.w),
static_cast<float>(q.x),
static_cast<float>(q.y),
static_cast<float>(q.z)
);
Q.normalize();
T.linear() = Q.toRotationMatrix();
return T;
}
class TF {
public:
using Ptr = std::shared_ptr<TF>;
TF(rclcpp::Node& n) {
tf_buffer_ = std::make_unique<tf2_ros::Buffer>(n.get_clock());
tf_listener_ = std::make_unique<tf2_ros::TransformListener>(*tf_buffer_);
tf_broadcaster_ = std::make_unique<tf2_ros::TransformBroadcaster>(n);
node_ = &n;
}
static Ptr create(rclcpp::Node& n) {
return std::make_shared<TF>(n);
}
std::optional<Eigen::Isometry3f>
getTransform(const std::string& target, const std::string& source, rclcpp::Time t)
const noexcept {
Eigen::Isometry3f T_out = Eigen::Isometry3f::Identity();
try {
geometry_msgs::msg::TransformStamped tf_msg =
tf_buffer_->lookupTransform(target, source, t, rclcpp::Duration::from_seconds(0.1));
T_out = tf2ToEigen(tf_msg);
return T_out;
} catch (tf2::TransformException& ex) {
RCLCPP_WARN(rclcpp::get_logger("tf"), "TF lookup failed: %s", ex.what());
return std::nullopt;
}
return T_out;
}
void publishTransform(
const Eigen::Isometry3d& transform,
const std::string& parent_frame,
const std::string& child_frame,
const rclcpp::Time& stamp
) const noexcept {
geometry_msgs::msg::TransformStamped tmsg;
tmsg.header.stamp = stamp;
tmsg.header.frame_id = parent_frame;
tmsg.child_frame_id = child_frame;
const Eigen::Vector3d tr = transform.translation();
const Eigen::Quaterniond q(transform.rotation());
tmsg.transform.translation.x = tr.x();
tmsg.transform.translation.y = tr.y();
tmsg.transform.translation.z = tr.z();
tmsg.transform.rotation.x = q.x();
tmsg.transform.rotation.y = q.y();
tmsg.transform.rotation.z = q.z();
tmsg.transform.rotation.w = q.w();
tf_broadcaster_->sendTransform(tmsg);
}
rclcpp::Node* node_;
std::unique_ptr<tf2_ros::Buffer> tf_buffer_;
std::unique_ptr<tf2_ros::TransformListener> tf_listener_;
std::unique_ptr<tf2_ros::TransformBroadcaster> tf_broadcaster_;
};

161
wust_vision-main/run.sh Executable file
View File

@@ -0,0 +1,161 @@
#!/bin/bash
WORK_DIR="$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
BUILD_DIR="$WORK_DIR/build"
CONFIG_DIR="$WORK_DIR/config"
BIN_DIR="$WORK_DIR/bin"
source "$WORK_DIR/env.bash"
export VISION_ROOT="$WORK_DIR"
export MVCAM_SDK_PATH=/opt/MVS
export MVCAM_COMMON_RUNENV=/opt/MVS/lib
export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol
export ALLUSERSPROFILE=/opt/MVS/MVFG
export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH
blue="\033[1;34m"
yellow="\033[1;33m"
reset="\033[0m"
red="\033[1;31m"
if [ "$EUID" -eq 0 ]; then
USER_HOME=$(getent passwd $SUDO_USER | cut -d: -f6)
COPY_BASHRC="$WORK_DIR/user_bashrc_copy.bash"
if [ -f "$USER_HOME/.bashrc" ]; then
# 复制 bashrc 到 WORK_DIR并删除前10行
tail -n +11 "$USER_HOME/.bashrc" > "$COPY_BASHRC"
# 设置权限,普通用户可读
chmod 644 "$COPY_BASHRC"
chown $SUDO_USER:$SUDO_USER "$COPY_BASHRC"
echo -e "${yellow}Copied ~/.bashrc to $COPY_BASHRC with first 10 lines removed${reset}"
# 加载复制的 bashrc
source "$COPY_BASHRC"
echo -e "${yellow}Loaded bashrc from copy${reset}"
else
echo -e "${red}Original ~/.bashrc not found: $USER_HOME/.bashrc${reset}"
source "$COPY_BASHRC"
fi
else
# 普通用户直接加载原 bashrc
if [ -f "$HOME/.bashrc" ]; then
source "$HOME/.bashrc"
echo -e "${yellow}Loaded bashrc from $HOME/.bashrc${reset}"
fi
fi
chmod 777 /dev/shm/debug_frame
rm -rf "$BIN_DIR/config"
ln -sf "$CONFIG_DIR" "$BIN_DIR/config"
ln -sf "$WORK_DIR/env.bash" "$BUILD_DIR/env.bash"
if [ "$1" == "rebuild" ]; then
echo -e "${yellow}<--- Rebuilding: This will REMOVE the entire build directory --->${reset}"
read -p "Are you sure you want to rebuild? [y/N]: " confirm
confirm=${confirm,,}
if [[ "$confirm" != "y" && "$confirm" != "yes" ]]; then
echo -e "${red}Rebuild cancelled.${reset}"
exit 0
fi
echo -e "${yellow}<--- Removing build directory --->${reset}"
rm -rf "$BUILD_DIR"
mkdir -p "$BUILD_DIR"
else
mkdir -p "$BUILD_DIR"
fi
if [[ "$1" == "build" || "$1" == "rebuild" || "$1" == "run" ]]; then
echo -e "${yellow}<--- Start CMake (Ninja) --->${reset}"
cmake -S "$WORK_DIR" -B "$BUILD_DIR" \
-G Ninja \
if [ $? -ne 0 ]; then
echo -e "${red}\n--- CMake Failed ---${reset}"
exit 1
fi
SECONDS=0
echo -e "${yellow}\n<--- Start Ninja Build --->${reset}"
ninja -C "$BUILD_DIR"
if [ $? -ne 0 ]; then
echo -e "${red}\n--- Ninja Build Failed ---${reset}"
exit 1
fi
build_time=$SECONDS
printf "${blue}\n<--- Build Time --->\n %02d:%02d (mm:ss)\n${reset}" \
$((build_time / 60)) $((build_time % 60))
echo -e "${yellow}\n<--- Total Lines --->${reset}"
total=$(find "$WORK_DIR" \
-type d \( \
-path "$BUILD_DIR" -o \
-path "$WORK_DIR/model" -o \
-path "$WORK_DIR/3rdparty" -o \
-path "$WORK_DIR/.cache" \
\) -prune -o \
-type f \( \
-name "*.cpp" -o -name "*.hpp" -o -name "*.c" -o -name "*.h" \
-o -name "*.py" -o -name "*.html" -o -name "*.sh" -o -name "*.md" \
-o -name "*.yaml" -o -name "*.json" -o -name "*.css" -o -name "*.js" \
-o -name "*.cu" -o -name "*.txt" \
\) -exec wc -l {} + | awk 'END{print $1}')
echo -e "${blue} $total${reset}"
# Only build
if [ "$1" == "build" ] || [ "$1" == "rebuild" ]; then
echo -e "${yellow}\n<--- Only building --->${reset}"
echo -e "${yellow}<----- OVER ----->${reset}"
exit 0
fi
# Run mode
if [ "$1" == "run" ]; then
echo -e "${yellow}\n<--- Running WUST_VISION --->${reset}"
RUN_PROGRAM="$BIN_DIR/$2"
sudo pkill -9 $2
ORIGINAL_ARGS=("$@")
shift 2
"$RUN_PROGRAM" "$@"
RET=$?
set -- "${ORIGINAL_ARGS[@]}"
if [ $RET -ne 0 ]; then
echo -e "${red}\n--- Program crashed, running guard.sh ---${reset}"
pkill "$2"
timeout=10
while pgrep "$2" > /dev/null; do
sleep 0.5
timeout=$((timeout - 1))
if [ $timeout -le 0 ]; then
echo "$2 did not exit after 10 seconds, forcing kill"
pkill -9 "$2"
break
fi
done
GUARD_SCRIPT="$CONFIG_DIR/guard.sh"
TARGET_PATH="$RUN_PROGRAM"
if [ ! -f "$GUARD_SCRIPT" ]; then
echo -e "${red}guard.sh not found: $GUARD_SCRIPT${reset}"
exit 1
fi
echo -e "${yellow}Starting guard.sh ...${reset}"
exec "$GUARD_SCRIPT" "$TARGET_PATH" "$@"
fi
fi
echo -e "${yellow}<----- OVER ----->${reset}"
else
echo -e "${yellow}Warning:${reset} Invalid argument '$1'."
echo -e "${yellow}Usage:${reset} $0 {build|rebuild|run <program> [args...]}"
echo -e "${yellow}No action performed.${reset}"
exit 0
fi

View File

@@ -0,0 +1,6 @@
#!/usr/bin/env bash
sudo apt update && sudo apt install -y \
rsync ninja-build \
libfmt-dev libceres-dev libeigen3-dev \
nlohmann-json3-dev libyaml-cpp-dev

View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
WORK_DIR="$(dirname "$SCRIPT_DIR")"
echo "脚本目录: $SCRIPT_DIR"
echo "wust_vision目录: $WORK_DIR"
# 用法说明
if [ $# -ne 2 ]; then
echo "Usage: $0 <remote_user> <remote_ip>"
echo "Example:"
echo " $0 nvidiaa 192.168.10.100"
exit 1
fi
REMOTE_USER="$1"
REMOTE_IP="$2"
TARGET_PATH="/home/${REMOTE_USER}/wust_vision"
rsync -avz \
--exclude='.cache/' \
--exclude='.vscode/' \
--exclude='.git/' \
--exclude='bin/' \
--exclude='build/' \
--exclude='model/' \
--exclude='config/' \
--exclude='CMakeLists.txt' \
"${WORK_DIR}/" \
"${REMOTE_USER}@${REMOTE_IP}:${TARGET_PATH}/"

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env bash
set -euo pipefail
BASHRC="$HOME/.bashrc"
MARKER_START="# >>> wust_vision dev >>>"
MARKER_END="# <<< wust_vision dev <<<"
BACKUP="$BASHRC.bak.$(date +%Y%m%d_%H%M%S)"
# 备份
cp "$BASHRC" "$BACKUP"
echo "💾 已备份 $BASHRC$BACKUP"
# 临时文件
TMP=$(mktemp)
# 写入 bashrc 除去旧 block
awk -v start="$MARKER_START" -v end="$MARKER_END" '
$0 == start {inside=1; next}
inside && $0 == end {inside=0; next}
!inside {print}
' "$BASHRC" > "$TMP"
# 使用 cat <<EOF 方式追加 block保证特殊字符安全
cat <<'EOF' >> "$TMP"
# >>> wust_vision dev >>>
py() {
source ~/anaconda3/etc/profile.d/conda.sh
echo "已激活 conda, conda activate激活环境conda deactivate退出环境 conda info --envs查看环境"
}
buildme() {
colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=Release
}
builddebug() {
colcon build --packages-select "$1" --cmake-args -DCMAKE_BUILD_TYPE=Debug
}
killros() {
pkill -f ros && ros2 daemon stop && ros2 daemon start
}
format() {
find . -path ./build -prune -o \
-type f \( -name '*.h' -o -name '*.hpp' -o -name '*.c' -o -name '*.cu' -o -name '*.cpp' \) \
-exec clang-format -i {} +
}
s() {
source install/setup.bash
}
hik() {
export MVCAM_SDK_PATH=/opt/MVS
export MVCAM_COMMON_RUNENV=/opt/MVS/lib
export MVCAM_GENICAM_CLPROTOCOL=/opt/MVS/lib/CLProtocol
export ALLUSERSPROFILE=/opt/MVS/MVFG
export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH
}
# <<< wust_vision dev <<<
EOF
# 原子替换
mv "$TMP" "$BASHRC"
echo "✅ 已安全替换或追加 block 到 $BASHRC"

View File

@@ -0,0 +1,109 @@
#!/usr/bin/env bash
# 自动管理 systemd service 文件
# 工作目录 = 脚本所在路径的上一级目录
SERVICE_NAME="wust_vision"
SERVICE_FILE="/etc/systemd/system/${SERVICE_NAME}.service"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
WORK_DIR="$(dirname "$SCRIPT_DIR")"
ACTION=$1
echo "正在操作: standard"
echo "📂 脚本路径: $SCRIPT_DIR"
echo "🏗 工作区路径: $WORK_DIR"
echo "🧾 目标 Service 文件: $SERVICE_FILE"
echo
uninstall_service() {
if [[ ! -f "$SERVICE_FILE" ]]; then
echo "⚠️ Service 文件不存在,无需卸载。"
exit 0
fi
echo "🛑 停止并卸载 Service..."
sudo systemctl stop "${SERVICE_NAME}.service" 2>/dev/null || true
sudo systemctl disable "${SERVICE_NAME}.service" 2>/dev/null || true
sudo rm -f "$SERVICE_FILE"
sudo systemctl daemon-reload
echo "✅ Service 已成功卸载。"
exit 0
}
install_service() {
if [[ -f "$SERVICE_FILE" ]]; then
echo "⚠️ 检测到已存在的 Service 文件:$SERVICE_FILE"
read -p "是否覆盖?(y/N): " confirm
if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then
echo "🚫 已取消安装。"
exit 0
else
echo "🧹 删除旧版本 Service..."
sudo systemctl stop ${SERVICE_NAME}.service 2>/dev/null || true
sudo systemctl disable ${SERVICE_NAME}.service 2>/dev/null || true
sudo rm -f "$SERVICE_FILE"
fi
fi
echo "✏️ 正在生成新的 Service 文件..."
sudo tee "$SERVICE_FILE" > /dev/null <<EOF
[Unit]
Description=Wust Vision Standard Service
After=network.target
[Service]
Type=simple
User=root
WorkingDirectory=${WORK_DIR}
ExecStart=${WORK_DIR}/run.sh run standard false
Restart=always
RestartSec=5
StandardOutput=journal
StandardError=journal
[Install]
WantedBy=multi-user.target
EOF
sudo chmod 644 "$SERVICE_FILE"
sudo systemctl daemon-reload
sudo systemctl enable "${SERVICE_NAME}.service"
sudo systemctl restart "${SERVICE_NAME}.service"
echo
echo "✅ Service 已生成并启动成功!"
echo "🔍 查看状态sudo systemctl status ${SERVICE_NAME}.service"
echo "📜 实时日志journalctl -u ${SERVICE_NAME}.service -f"
}
if [[ -z "$ACTION" ]]; then
echo "❌ 请输入要执行的操作:$0 [install|uninstall|start|stop|restart|status|journal]"
exit 1
fi
case "$ACTION" in
install) install_service ;;
uninstall) uninstall_service ;;
start)
sudo systemctl start "${SERVICE_NAME}.service"
echo "✅ Service 已启动"
;;
stop)
sudo systemctl stop "${SERVICE_NAME}.service"
echo "🛑 Service 已停止"
;;
restart)
sudo systemctl restart "${SERVICE_NAME}.service"
echo "🔄 Service 已重启"
;;
status)
sudo systemctl status "${SERVICE_NAME}.service"
;;
journal)
sudo journalctl -u "${SERVICE_NAME}.service" -f
;;
*)
echo "❌ 参数错误: $0 [install|uninstall|start|stop|restart|status|journal]"
exit 1
;;
esac

View File

@@ -0,0 +1,109 @@
#!/usr/bin/env bash
# 自动管理 systemd service 文件
# 工作目录 = 脚本所在路径的上一级目录
SERVICE_NAME="wust_vision"
SERVICE_FILE="/etc/systemd/system/${SERVICE_NAME}.service"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
WORK_DIR="$(dirname "$SCRIPT_DIR")"
ACTION=$1
echo "正在操作: sentry"
echo "📂 脚本路径: $SCRIPT_DIR"
echo "🏗 工作区路径: $WORK_DIR"
echo "🧾 目标 Service 文件: $SERVICE_FILE"
echo
uninstall_service() {
if [[ ! -f "$SERVICE_FILE" ]]; then
echo "⚠️ Service 文件不存在,无需卸载。"
exit 0
fi
echo "🛑 停止并卸载 Service..."
sudo systemctl stop "${SERVICE_NAME}.service" 2>/dev/null || true
sudo systemctl disable "${SERVICE_NAME}.service" 2>/dev/null || true
sudo rm -f "$SERVICE_FILE"
sudo systemctl daemon-reload
echo "✅ Service 已成功卸载。"
exit 0
}
install_service() {
if [[ -f "$SERVICE_FILE" ]]; then
echo "⚠️ 检测到已存在的 Service 文件:$SERVICE_FILE"
read -p "是否覆盖?(y/N): " confirm
if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then
echo "🚫 已取消安装。"
exit 0
else
echo "🧹 删除旧版本 Service..."
sudo systemctl stop ${SERVICE_NAME}.service 2>/dev/null || true
sudo systemctl disable ${SERVICE_NAME}.service 2>/dev/null || true
sudo rm -f "$SERVICE_FILE"
fi
fi
echo "✏️ 正在生成新的 Service 文件..."
sudo tee "$SERVICE_FILE" > /dev/null <<EOF
[Unit]
Description=Wust Vision Standard Service
After=network.target
[Service]
Type=simple
User=root
WorkingDirectory=${WORK_DIR}
ExecStart=${WORK_DIR}/run.sh run sentry false
Restart=always
RestartSec=5
StandardOutput=journal
StandardError=journal
[Install]
WantedBy=multi-user.target
EOF
sudo chmod 644 "$SERVICE_FILE"
sudo systemctl daemon-reload
sudo systemctl enable "${SERVICE_NAME}.service"
sudo systemctl restart "${SERVICE_NAME}.service"
echo
echo "✅ Service 已生成并启动成功!"
echo "🔍 查看状态sudo systemctl status ${SERVICE_NAME}.service"
echo "📜 实时日志journalctl -u ${SERVICE_NAME}.service -f"
}
if [[ -z "$ACTION" ]]; then
echo "❌ 请输入要执行的操作:$0 [install|uninstall|start|stop|restart|status|journal]"
exit 1
fi
case "$ACTION" in
install) install_service ;;
uninstall) uninstall_service ;;
start)
sudo systemctl start "${SERVICE_NAME}.service"
echo "✅ Service 已启动"
;;
stop)
sudo systemctl stop "${SERVICE_NAME}.service"
echo "🛑 Service 已停止"
;;
restart)
sudo systemctl restart "${SERVICE_NAME}.service"
echo "🔄 Service 已重启"
;;
status)
sudo systemctl status "${SERVICE_NAME}.service"
;;
journal)
sudo journalctl -u "${SERVICE_NAME}.service" -f
;;
*)
echo "❌ 参数错误: $0 [install|uninstall|start|stop|restart|status|journal]"
exit 1
;;
esac

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
sudo tee /etc/udev/rules.d/99-stm32-acm.rules > /dev/null <<EOF
SUBSYSTEM=="tty", ATTRS{idVendor}=="0483", ATTRS{idProduct}=="5740", ATTRS{serial}=="2069398C464D", SYMLINK+="stm32_acm"
EOF
sudo udevadm control --reload-rules
sudo udevadm trigger

View File

@@ -0,0 +1,204 @@
#include "tasks/auto_guidance/auto_guidance.hpp"
#include "tasks/utils/main_base.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/concurrency/ThreadPool.h"
#include "wust_vl/common/drivers/serial_driver.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include "wust_vl/common/utils/timer.hpp"
#include "wust_vl/video/camera.hpp"
ENABLE_BACKWARD()
namespace wust_vision {
struct SendDartCmdData {
float diff_center_norm = 0;
} __attribute__((packed));
class vision {
public:
bool init(bool debug_mode) {
debug_mode_ = debug_mode;
const char* v = std::getenv("VISION_ROOT");
if (v)
std::cout << "[env] VISION_ROOT = " << v << "\n";
else
std::cout << "[env] VISION_ROOT not set in this process\n";
config_ = YAML::LoadFile("/home/hy/wust_vision/config/auto_guidance.yaml");
std::string log_level_ = config_["logger"]["log_level"].as<std::string>("INFO");
std::string log_path_ = config_["logger"]["log_path"].as<std::string>("wust_log");
bool use_logcli = config_["logger"]["use_logcli"].as<bool>();
bool use_logfile = config_["logger"]["use_logfile"].as<bool>();
bool use_simplelog = config_["logger"]["use_simplelog"].as<bool>();
wust_vl::initLogger(log_level_, log_path_, use_logcli, use_logfile, use_simplelog);
YAML::Node camera_config = config_["camera"];
camera_ = std::make_unique<wust_vl::video::Camera>();
camera_->init(camera_config);
camera_->setFrameCallback(std::bind(&vision::frameCallback, this, std::placeholders::_1));
std::string camera_info_path =
utils::expandEnv(camera_config["camera_info_path"].as<std::string>());
YAML::Node config_camera_info = YAML::LoadFile(camera_info_path);
std::vector<double> camera_k =
config_camera_info["camera_matrix"]["data"].as<std::vector<double>>();
std::vector<double> camera_d =
config_camera_info["distortion_coefficients"]["data"].as<std::vector<double>>();
assert(camera_k.size() == 9);
assert(camera_d.size() == 5);
cv::Mat K(3, 3, CV_64F);
std::memcpy(K.data, camera_k.data(), 9 * sizeof(double));
cv::Mat D(1, 5, CV_64F);
std::memcpy(D.data, camera_d.data(), 5 * sizeof(double));
auto camera_info = std::make_pair(K.clone(), D.clone());
camera_info_ = camera_info;
auto_guidance_ = auto_guidance::AutoGuidance::create();
auto_guidance_->setDebug(debug_mode);
auto_guidance_->init(config_, camera_info);
max_infer_running_ = config_["max_infer_running"].as<int>();
thread_pool_ = std::make_unique<wust_vl::common::concurrency::ThreadPool>(
std::thread::hardware_concurrency() * 2
);
std::string device_name = config_["control"]["device_name"].as<std::string>();
serial_ = std::make_shared<wust_vl::common::drivers::SerialDriver>();
bool use_serial = config_["control"]["use_serial"].as<bool>();
if (use_serial) {
wust_vl::common::drivers::SerialDriver::SerialPortConfig cfg {
/*baud*/ 115200,
/*csize*/ 8,
boost::asio::serial_port_base::parity::none,
boost::asio::serial_port_base::stop_bits::one,
boost::asio::serial_port_base::flow_control::none
};
serial_->init_port(device_name, cfg);
serial_->set_receive_callback(std::bind(
&vision::serialCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
std::cout << "Serial port opened" << std::endl;
serial_->set_error_callback([&](const boost::system::error_code& ec) {
WUST_ERROR("serial") << "serial error: " << ec.message();
});
}
timer_ = std::make_unique<wust_vl::common::utils::Timer>();
WUST_MAIN("vision") << "starting vision task";
return true;
}
~vision() {
run_flag_ = false;
camera_->stop();
thread_pool_->waitUntilEmpty();
if (debug_thread_.joinable()) {
debug_thread_.join();
}
}
void start() {
run_flag_ = true;
camera_->start();
auto_guidance_->start();
if (serial_) {
serial_->start();
}
if (timer_) {
auto timercallback = std::bind(&vision::timerCallback, this, std::placeholders::_1);
double rate_hz = static_cast<double>(config_["control"]["control_rate"].as<int>());
timer_->start(rate_hz, timercallback);
}
if (debug_mode_) {
debug_thread_ = std::thread([this]() { this->debugThread(); });
}
}
void serialCallback(const uint8_t* data, std::size_t len) {}
void frameCallback(wust_vl::video::ImageFrame& frame) {
if (!run_flag_ || infer_running_count_ >= max_infer_running_) {
return;
}
CommonFrame common_frame;
if (frame.src_img.empty()) {
return;
}
common_frame.img_frame = std::move(frame);
common_frame.expanded = cv::Rect(
0,
0,
common_frame.img_frame.src_img.cols,
common_frame.img_frame.src_img.rows
);
common_frame.offset = cv::Point2f(0, 0);
thread_pool_->enqueue([this, frame = std::move(common_frame)]() mutable {
infer_running_count_++;
if (frame.img_frame.src_img.data == nullptr) {
return;
}
if (frame.img_frame.src_img.empty()) {
return;
}
if (auto_guidance_) {
auto_guidance_->pushInput(frame);
}
infer_running_count_--;
});
}
void timerCallback(double dt_ms) {
if (!run_flag_) {
return;
}
auto target = auto_guidance_->getTarget();
cx_norm_ = target.center().x / target.image_size_.width * 2.0 - 1.0;
SendDartCmdData send_data;
// send_data.cmd_ID = ID_ROBOT_CMD;
// send_data.time_stamp = std::chrono::duration_cast<std::chrono::milliseconds>(
// std::chrono::steady_clock::now().time_since_epoch()
// )
// .count();
// send_data.appear = target.is_tracking_;
send_data.diff_center_norm = (target.is_tracking_) ? cx_norm_ : 0;
if (serial_) {
serial_->write(std::move(wust_vl::common::drivers::toVector(send_data)));
}
}
void debugThread() {
using namespace std::chrono;
double us_interval = 1e6 / static_cast<double>(30.0);
auto kInterval = std::chrono::microseconds(static_cast<int64_t>(us_interval));
while (run_flag_) {
auto start_time = steady_clock::now();
try {
auto dbg = auto_guidance_->getDebug();
drawDebugOverlayShm(dbg, false);
debuglog(dbg.target);
} catch (std::exception& e) {
std::cout << "debug thread error: " << e.what() << std::endl;
}
auto elapsed = steady_clock::now() - start_time;
if (elapsed < kInterval) {
std::this_thread::sleep_for(kInterval - elapsed);
}
}
}
void checkStateMatchMode() {}
YAML::Node config_;
std::unique_ptr<wust_vl::common::concurrency::ThreadPool> thread_pool_;
std::unique_ptr<auto_guidance::AutoGuidance> auto_guidance_;
std::unique_ptr<wust_vl::video::Camera> camera_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
std::shared_ptr<wust_vl::common::drivers::SerialDriver> serial_;
std::atomic<int> infer_running_count_ { 0 };
bool run_flag_ = false;
int max_infer_running_;
bool debug_mode_ = false;
std::pair<cv::Mat, cv::Mat> camera_info_;
std::thread debug_thread_;
double cx_norm_;
};
} // namespace wust_vision
VISION_MAIN(wust_vision::vision)

View File

@@ -0,0 +1,35 @@
#include "ros2/ros2.hpp"
#include "tasks/auto_sniper/auto_sniper.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/config.hpp"
#include "tasks/utils/main_base.hpp"
#include "tasks/vision_base.hpp"
ENABLE_BACKWARD()
namespace wust_vision {
class vision: public VisionBase<HeroMode> {
public:
vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {}
bool init(bool debug_mode) {
VisionBase::init(debug_mode);
auto auto_aim =
auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode);
modules_.emplace(HeroMode::AttackMode::ARMOR, auto_aim);
modules_.emplace(HeroMode::AttackMode::UNKNOWN, auto_aim);
rclcpp::init(0, nullptr);
ros2_ = std::make_shared<Ros2Node>("vison_node");
auto auto_sniper = auto_sniper::AutoSniper::create(*ros2_, motion_buffer_);
modules_.emplace(HeroMode::AttackMode::SNIPER, auto_sniper);
return true;
}
void start() {
VisionBase::start();
ros2_->start();
}
std::shared_ptr<Ros2Node> ros2_;
};
} // namespace wust_vision
VISION_MAIN(wust_vision::vision)

View File

@@ -0,0 +1,302 @@
#include "geometry_msgs/msg/twist.hpp"
#include "ros2/ros2.hpp"
#include "sentry_interfaces/msg/detail/mode__struct.hpp"
#include "sentry_interfaces/msg/detail/robo_state__struct.hpp"
#include "sentry_interfaces/msg/detail/target__struct.hpp"
#include "sentry_interfaces/msg/mode.hpp"
#include "sentry_interfaces/msg/robo_state.hpp"
#include "sentry_interfaces/msg/target.hpp"
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/auto_aim/armor_omni/armor_omni.hpp"
#include "tasks/auto_aim/auto_aim.hpp"
#include "tasks/packet_typedef.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/config.hpp"
#include "tasks/utils/main_base.hpp"
#include "tasks/vision_base.hpp"
#include "visualization_msgs/msg/marker.hpp"
#include <cmath>
#include <memory>
#include <wust_vl/common/utils/timer.hpp>
#include <yaml-cpp/node/parse.h>
ENABLE_BACKWARD()
namespace wust_vision {
class vision: public VisionBase<InfantryMode> {
public:
static constexpr const char* TARGET_MARKER = "target_marker";
vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {}
~vision() {
armor_omni_.reset();
}
bool init(bool debug_mode) {
VisionBase::init(debug_mode);
auto auto_aim =
auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode);
modules_.emplace(InfantryMode::AttackMode::ARMOR, auto_aim);
modules_.emplace(InfantryMode::AttackMode::UNKNOWN, auto_aim);
auto auto_buff =
auto_buff::AutoBuff::create(auto_buff_config_, tf_config_, camera_info_, debug_mode);
modules_.emplace(InfantryMode::AttackMode::BIG_RUNE, auto_buff);
modules_.emplace(InfantryMode::AttackMode::SMALL_RUNE, auto_buff);
serial_->set_receive_callback(
std::bind(&vision::serialCallback, this, std::placeholders::_1, std::placeholders::_2)
);
timer_B_ = std::make_unique<wust_vl::common::utils::Timer>("10hz");
big_yaw_motion_buffer_ =
std::make_shared<wust_vl::common::utils::MotionBufferGeneric<BigYaw, 1024>>();
// auto very_aimer_copy = std::make_shared<auto_aim::VeryAimer>(*auto_aim->getVeryAimer());
auto_aim::ArmorOmni::Ctx omni_ctx = {
.car_motion_buffer = motion_buffer_,
.big_yaw_motion_buffer = big_yaw_motion_buffer_,
.very_aimer = auto_aim->getVeryAimer(),
};
armor_omni_ = auto_aim::ArmorOmni::create(detect_color_, omni_ctx);
rclcpp::init(0, nullptr);
ros2_ = std::make_shared<Ros2Node>("vison_node");
ros2_->add_subscription<geometry_msgs::msg::Twist>(
"cmd_vel",
std::bind(&vision::twistCb, this, std::placeholders::_1),
rclcpp::QoS(10)
);
ros2_->add_subscription<sentry_interfaces::msg::Mode>(
MODE_TOPIC,
std::bind(&vision::modeCb, this, std::placeholders::_1)
);
ros2_->add_publisher<visualization_msgs::msg::Marker>(TARGET_MARKER);
ros2_->add_publisher<sentry_interfaces::msg::Target>(TARGET_TOPIC);
ros2_->add_publisher<sentry_interfaces::msg::RoboState>(ROBO_STATE_TOPIC);
return true;
}
void start() {
VisionBase::start();
if (timer_) {
const auto timercallback =
std::bind(&vision::timerCallback, this, std::placeholders::_1);
const double rate_hz = control_config_->control_rate_param.get();
timer_->start(rate_hz, timercallback);
}
if (timer_B_) {
const auto timercallback =
std::bind(&vision::timerBCallback, this, std::placeholders::_1);
const double rate_hz = 10.0;
timer_B_->start(rate_hz, timercallback);
}
armor_omni_->start();
ros2_->start();
}
void timerCallback(double dt_ms) {
if (!run_flag_) {
return;
}
GimbalCmd cmd;
try {
InfantryMode::AttackMode mode = InfantryMode::toAttackMode(attack_mode_);
auto module = modules_.at(mode);
if (!module) {
return;
}
cmd = module->solve(bullet_speed_);
} catch (const std::exception& e) {
std::cout << "solve error: " << e.what() << std::endl;
}
if (!cmd.isValid()) {
return;
}
if (!cmd.appear && armor_omni_) {
cmd = armor_omni_->solve(bullet_speed_);
}
last_cmd_ = cmd;
double cmd_pitch = cmd.pitch;
double cmd_yaw = cmd.yaw;
SendRobotCmdData send_data;
send_data.cmd_ID = ID_ROBOT_CMD;
send_data.time_stamp = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch()
)
.count();
if (cmd.distance > 0.5) {
send_data.appear = cmd.appear;
} else {
send_data.appear = false;
}
send_data.detect_color = detect_color_;
send_data.pitch = cmd_pitch + cmd.v_pitch * control_config_->pitch_ramp_param.get();
send_data.yaw = cmd_yaw + cmd.v_yaw * control_config_->yaw_ramp_param.get();
send_data.v_pitch = cmd.v_pitch;
send_data.v_yaw = cmd.v_yaw;
send_data.a_pitch = cmd.a_pitch;
send_data.a_yaw = cmd.a_yaw;
send_data.target_yaw = cmd.target_yaw;
send_data.target_pitch = cmd.target_pitch;
send_data.enable_pitch_diff = cmd.enable_pitch_diff;
send_data.enable_yaw_diff = cmd.enable_yaw_diff;
send_data.shoot_rate = shoot_config_->rate_param.get();
if (serial_) {
serial_->write(std::move(wust_vl::common::drivers::toVector(send_data)));
}
}
void timerBCallback(double dt_ms) {
InfantryMode::AttackMode mode = InfantryMode::toAttackMode(attack_mode_);
do {
if (mode == InfantryMode::AttackMode::ARMOR) {
auto auto_aim = auto_aim::toAutoAim(modules_.at(mode));
if (auto_aim) {
auto target = auto_aim->getTarget();
if (armor_omni_) {
armor_omni_->setDetectColor(detect_color_);
armor_omni_->updateMainTracking(target.checkTargetAppear());
}
if (!target.checkTargetAppear()) {
break;
}
auto target_pos = target.target_state_.pos(); // world frame
double big_yaw = 0.0;
Eigen::Rotation2Dd rot(-big_yaw);
Eigen::Vector2d pos_world(target_pos.x(), target_pos.y());
Eigen::Vector2d pos_bigyaw = rot * pos_world;
publishMarker(pos_bigyaw);
sentry_interfaces::msg::Target target_msg;
target_msg.pos.x = pos_bigyaw.x();
target_msg.pos.y = pos_bigyaw.y();
target_msg.pos.z = 0.0;
target_msg.color = detect_color_;
target_msg.id = 0;
target_msg.header.frame_id = "gimbal_yaw";
target_msg.header.stamp = ros2_->now();
ros2_->publish(TARGET_TOPIC, target_msg);
}
}
} while (0);
}
void serialCallback(const uint8_t* data, std::size_t len) {
if (len < 1)
return;
uint8_t cmd = data[0];
try {
if (cmd == ID_AIM_INFO) {
if (len != sizeof(ReceiveAimINFO))
return;
const std::vector<uint8_t> buf(data, data + len);
auto aim_data = wust_vl::common::drivers::fromVector<ReceiveAimINFO>(buf);
processAimData(aim_data);
}
else if (cmd == ID_REFEREE_INFO)
{
if (len != sizeof(ReceiveReferee))
return;
const std::vector<uint8_t> buf(data, data + len);
auto referee_data = wust_vl::common::drivers::fromVector<ReceiveReferee>(buf);
processRefereeData(referee_data);
}
else
{
std::cerr << "Unknown cmd_ID: " << int(cmd) << std::endl;
}
} catch (const std::exception& e) {
std::cerr << "serialCallback exception: " << e.what() << std::endl;
}
}
void processRefereeData(const ReceiveReferee& ref) {
if (debug_mode_) {
updateSerialLog(ref);
flushSerialLog();
}
const auto now = std::chrono::steady_clock::now();
double big_yaw_rad = ref.big_yaw_in_world / 180.0 * M_PI;
if (big_yaw_motion_buffer_) {
BigYaw big_yaw { .big_yaw = big_yaw_rad };
big_yaw_motion_buffer_->push(big_yaw, now);
}
sentry_interfaces::msg::RoboState robo_state;
robo_state.game_time = ref.game_time;
robo_state.cur_hp = ref.cur_health;
robo_state.max_hp = ref.max_health;
robo_state.cur_bullet = ref.cur_bullet;
robo_state.center_state = ref.center_state;
ros2_->publish(ROBO_STATE_TOPIC, robo_state);
}
void publishMarker(const Eigen::Vector2d& pos) {
visualization_msgs::msg::Marker marker;
marker.header.frame_id = "gimbal_yaw";
marker.header.stamp = ros2_->now();
marker.ns = "target";
marker.id = 0;
marker.type = visualization_msgs::msg::Marker::SPHERE;
marker.action = visualization_msgs::msg::Marker::ADD;
marker.pose.position.x = pos.x();
marker.pose.position.y = pos.y();
marker.pose.position.z = 0.0;
marker.pose.orientation.w = 1.0;
marker.scale.x = 0.15;
marker.scale.y = 0.15;
marker.scale.z = 0.15;
marker.color.a = 1.0;
marker.color.r = 1.0;
marker.color.g = 0.0;
marker.color.b = 0.0;
marker.lifetime = rclcpp::Duration(0, 0);
ros2_->publish(TARGET_MARKER, marker);
}
void modeCb(const sentry_interfaces::msg::Mode::SharedPtr msg) {
last_mode_ = msg->mode;
}
int last_mode_ = 0;
void twistCb(const geometry_msgs::msg::Twist::SharedPtr msg) {
NavRobotCmdData send_data;
send_data.cmd_ID = ID_NAV_CMD;
send_data.packet_type = ID_NAV_CONTROL;
send_data.time_stamp =
static_cast<uint32_t>(std::chrono::duration_cast<std::chrono::milliseconds>(
wust_vl::common::utils::time_utils::now().time_since_epoch()
)
.count());
send_data.vx = -msg->linear.y;
send_data.vy = msg->linear.x;
send_data.wz = msg->angular.z;
if (serial_) {
serial_->write(std::move(wust_vl::common::drivers::toVector(send_data)));
}
last_cmd_time_ = wust_vl::common::utils::time_utils::now();
}
std::shared_ptr<Ros2Node> ros2_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_B_;
auto_aim::ArmorOmni::Ptr armor_omni_;
std::chrono::steady_clock::time_point last_cmd_time_;
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<BigYaw, 1024>>
big_yaw_motion_buffer_;
bool use_sim_ = false;
};
} // namespace wust_vision
VISION_MAIN(wust_vision::vision)

View File

@@ -0,0 +1,488 @@
// #include "ros2/ros2.hpp"
// #include "sensor_msgs/msg/camera_info.hpp"
// #include "sensor_msgs/msg/image.hpp"
// #include "tasks/utils/config.hpp"
// #include "tasks/utils/main_base.hpp"
// #include "tasks/vision_base.hpp"
// ENABLE_BACKWARD()
// namespace wust_vision {
// class vision {
// public:
// vision() {
// common_config_ = COMMON_CONFIG;
// camera_config_ = CAMERA_CONFIG;
// auto_aim_config_ = AUTO_AIM_CONFIG;
// auto_buff_config_ = AUTO_BUFF_CONFIG;
// };
// ~vision() {
// run_flag_ = false;
// has_camera_info_ = true;
// if (debug_thread_.joinable()) {
// debug_thread_.join();
// }
// WUST_MAIN("main") << "vision stop already!";
// }
// bool init(bool debug_mode) {
// try {
// rclcpp::init(0, nullptr);
// ros2_ = std::make_shared<Ros2Node>("vison_node");
// ros2_->add_subscription<sensor_msgs::msg::Image>(
// "image_raw",
// std::bind(&vision::imageCallback, this, std::placeholders::_1),
// rclcpp::SensorDataQoS()
// );
// ros2_->add_subscription<sensor_msgs::msg::CameraInfo>(
// "camera_info",
// [this](sensor_msgs::msg::CameraInfo::ConstSharedPtr _camera_info) {
// if (has_camera_info_)
// return;
// std::cout << "camera info received" << std::endl;
// auto& msg = *_camera_info;
// cv::Mat K(3, 3, CV_64F);
// std::memcpy(K.data, msg.k.data(), 9 * sizeof(double));
// cv::Mat D(1, msg.d.size(), CV_64F);
// std::memcpy(D.data, msg.d.data(), msg.d.size() * sizeof(double));
// camera_info_ = std::make_pair(K.clone(), D.clone());
// has_camera_info_ = true;
// }
// );
// ros2_->start();
// while (rclcpp::ok() && !has_camera_info_) {
// std::this_thread::sleep_for(std::chrono::milliseconds(1000));
// WUST_INFO("sim") << "Waiting for camera info...";
// }
// const char* v = std::getenv("VISION_ROOT");
// if (v)
// std::cout << "[env] VISION_ROOT = " << v << "\n";
// else
// std::cout << "[env] VISION_ROOT not set in this process\n";
// debug_mode_ = debug_mode;
// control_config_ = ControlConfig::create(this);
// shoot_config_ = ShootConfig::create();
// logger_config_ = LoggerConfig::create();
// tf_config_ = TFConfig::create();
// max_infer_running_config_ = MaxInferRunningConfig::create();
// common_config_parameter_.registerGroup(*control_config_);
// common_config_parameter_.registerGroup(*shoot_config_);
// common_config_parameter_.registerGroup(*logger_config_);
// common_config_parameter_.registerGroup(*tf_config_);
// common_config_parameter_.registerGroup(*max_infer_running_config_);
// common_config_parameter_.loadFromFile(common_config_);
// auto config = common_config_parameter_.getConfig();
// debug_fps_ = config["debug_fps"].as<int>();
// attack_mode_ = config["attack_mode"].as<int>();
// detect_color_ = config["detect_color"].as<int>();
// wust_vl::common::utils::ParameterManager::instance().registerParameter(
// common_config_parameter_
// );
// auto_aim_ = std::make_unique<auto_aim::AutoAim>(
// auto_aim_config_,
// tf_config_,
// camera_info_,
// debug_mode
// );
// auto_buff_ = std::make_unique<auto_buff::AutoBuff>(
// auto_buff_config_,
// tf_config_,
// camera_info_,
// debug_mode
// );
// thread_pool_ = std::make_unique<wust_vl::common::concurrency::ThreadPool>(
// std::thread::hardware_concurrency() * 2
// );
// motion_buffer_ =
// std::make_shared<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>>();
// timer_ = std::make_unique<wust_vl::common::utils::Timer>();
// } catch (std::exception& e) {
// std::cerr << "init exception: " << e.what() << std::endl;
// }
// return true;
// }
// void imageCallback(const sensor_msgs::msg::Image::ConstSharedPtr img_msg) {
// if (!run_flag_) {
// return;
// }
// cv::Mat image(
// img_msg->height,
// img_msg->width,
// CV_8UC3,
// const_cast<unsigned char*>(img_msg->data.data()), // raw pointer
// img_msg->step
// );
// CommonFrame common_frame;
// common_frame.img_frame.src_img = std::move(image);
// common_frame.detect_color = detect_color_;
// common_frame.img_frame.timestamp = std::chrono::steady_clock::now();
// common_frame.img_frame.pixel_format = wust_vl::video::PixelFormat::BGR;
// common_frame.expanded = cv::Rect(
// 0,
// 0,
// common_frame.img_frame.src_img.cols,
// common_frame.img_frame.src_img.rows
// );
// common_frame.any_ctx = VisionCtx { .motion_buffer = motion_buffer_,
// .communication_delay_μs =
// control_config_->communication_delay_us_param.get(),
// .attack_mode = toAttackMode(attack_mode_) };
// common_frame.offset = cv::Point2f(0, 0);
// thread_pool_->enqueue([this, frame = std::move(common_frame)]() mutable {
// infer_running_count_++;
// if (frame.img_frame.src_img.empty()) {
// return;
// }
// AttackMode mode = toAttackMode(attack_mode_);
// switch (mode) {
// case AttackMode::ARMOR: {
// auto_aim_->pushInput(frame);
// } break;
// case AttackMode::SMALL_RUNE: {
// auto_buff_->pushInput(frame);
// } break;
// case AttackMode::BIG_RUNE: {
// auto_buff_->pushInput(frame);
// } break;
// case AttackMode::UNKNOWN: {
// auto_aim_->pushInput(frame);
// } break;
// }
// infer_running_count_--;
// });
// }
// void start() {
// run_flag_ = true;
// auto_aim_->start();
// auto_buff_->start();
// if (timer_) {
// auto timercallback = std::bind(&vision::timerCallback, this, std::placeholders::_1);
// double rate_hz = control_config_->control_rate_param.get();
// timer_->start(rate_hz, timercallback);
// }
// if (debug_mode_) {
// debug_thread_ = std::thread([this]() { this->debugThread(); });
// }
// }
// void timerCallback(double dt_ms) {
// if (!run_flag_) {
// return;
// }
// geometry_msgs::msg::TransformStamped tf;
// if (!ros2_->lookup_transform("odom", "gimbal_link", tf)) {
// RCLCPP_WARN(ros2_->get_logger(), "TF lookup failed");
// } else {
// Eigen::Quaterniond q(
// tf.transform.rotation.w,
// tf.transform.rotation.x,
// tf.transform.rotation.y,
// tf.transform.rotation.z
// );
// // Convert to rotation matrix
// Eigen::Matrix3d R = q.toRotationMatrix();
// // Euler angles (ZYX order -> yaw pitch roll)
// Eigen::Vector3d euler = R.eulerAngles(2, 1, 0); // yaw, pitch, roll
// double yaw = euler[0];
// double pitch = euler[1];
// double roll = euler[2];
// if (motion_buffer_) {
// Motion motion { yaw, -pitch, roll, 0.0, 0, 0, 0, 0, 0 };
// motion_buffer_->push(motion, std::chrono::steady_clock::now());
// }
// }
// GimbalCmd cmd;
// try {
// AttackMode mode = toAttackMode(attack_mode_);
// switch (mode) {
// case AttackMode::ARMOR: {
// cmd = auto_aim_->solve(shoot_config_->bullet_speed_param.get());
// } break;
// case AttackMode::SMALL_RUNE:
// case AttackMode::BIG_RUNE: {
// cmd = auto_buff_->solve(shoot_config_->bullet_speed_param.get());
// } break;
// case AttackMode::UNKNOWN: {
// cmd = auto_aim_->solve(shoot_config_->bullet_speed_param.get());
// } break;
// }
// } catch (const std::exception& e) {
// std::cout << "auto_aim_solve error: " << e.what() << std::endl;
// }
// last_cmd_ = cmd;
// }
// void debugThread() {
// const double us_interval = 1e6 / static_cast<double>(debug_fps_);
// const auto kInterval = std::chrono::microseconds(static_cast<int64_t>(us_interval));
// while (run_flag_) {
// const auto start_time = std::chrono::steady_clock::now();
// do {
// try {
// if (!auto_aim_ || !auto_buff_) {
// break;
// }
// auto dbg_armor = auto_aim_->getDebugFrame();
// auto dbg_rune = auto_buff_->getDebugFrame();
// AttackMode mode = toAttackMode(attack_mode_);
// switch (mode) {
// case AttackMode::UNKNOWN:
// case AttackMode::ARMOR: {
// drawDebugOverlayShm(dbg_armor, camera_info_, false);
// } break;
// case AttackMode::SMALL_RUNE:
// case AttackMode::BIG_RUNE: {
// drawDebugOverlayShm(dbg_rune, camera_info_, false);
// } break;
// }
// std::pair<double, double> gimbal_py;
// if (motion_buffer_) {
// auto last_att = motion_buffer_->get_last();
// if (last_att) {
// gimbal_py.first = last_att->data.pitch;
// gimbal_py.second = last_att->data.yaw;
// }
// }
// debuglog(dbg_armor, dbg_rune, last_cmd_, gimbal_py);
// utils::XSecOnce(
// [this]() {
// wust_vl::common::utils::ParameterManager::instance()
// .allReloadFromOldPath();
// },
// 1.0
// );
// } catch (std::exception& e) {
// std::cout << "debug thread error: " << e.what() << std::endl;
// }
// } while (0);
// const auto elapsed = std::chrono::steady_clock::now() - start_time;
// if (elapsed < kInterval) {
// std::this_thread::sleep_for(kInterval - elapsed);
// }
// }
// }
// void checkStateMatchMode() {
// AttackMode mode = toAttackMode(attack_mode_);
// switch (mode) {
// case AttackMode::ARMOR: {
// if (!auto_aim_->isActive()) {
// auto_aim_->processingUp();
// }
// if (auto_buff_->isActive()) {
// auto_buff_->processingWait();
// }
// } break;
// case AttackMode::SMALL_RUNE:
// case AttackMode::BIG_RUNE: {
// if (auto_aim_->isActive()) {
// auto_aim_->processingWait();
// }
// if (!auto_buff_->isActive()) {
// auto_buff_->processingUp();
// }
// } break;
// case AttackMode::UNKNOWN: {
// if (!auto_aim_->isActive()) {
// auto_aim_->processingUp();
// }
// if (auto_buff_->isActive()) {
// auto_buff_->processingWait();
// }
// } break;
// }
// }
// struct ControlConfig: wust_vl::common::utils::ParamGroup {
// public:
// static constexpr const char* kKey = "control";
// static constexpr const char* Logger = "Config: common::control";
// const char* key() const override {
// return kKey;
// }
// using Ptr = std::shared_ptr<ControlConfig>;
// ControlConfig(vision* b) {
// base = b;
// communication_delay_us_param.onChange([this](int o, int n) {
// if (isBaseACtive()) {
// WUST_DEBUG(Logger)
// << "communication_delay_μs from: " << o << " to: " << n << " us";
// }
// });
// }
// static Ptr create(vision* b) {
// return std::make_shared<ControlConfig>(b);
// }
// GEN_PARAM(double, communication_delay_us);
// GEN_PARAM(double, yaw_ramp);
// GEN_PARAM(double, pitch_ramp);
// GEN_PARAM(double, control_rate);
// vision* base;
// bool first_load = false;
// bool isBaseACtive() {
// return base != nullptr;
// }
// void loadSelf(const YAML::Node& node) override {
// if (!isBaseACtive())
// return;
// if (!first_load) {
// communication_delay_us_param.set(node["communication_delay_us"].as<double>());
// yaw_ramp_param.set(node["yaw_ramp"].as<double>());
// pitch_ramp_param.set(node["pitch_ramp"].as<double>());
// control_rate_param.set(node["control_rate"].as<double>());
// first_load = true;
// } else {
// communication_delay_us_param.load(node);
// yaw_ramp_param.load(node);
// pitch_ramp_param.load(node);
// control_rate_param.load(node);
// }
// }
// };
// ControlConfig::Ptr control_config_;
// struct ShootConfig: wust_vl::common::utils::ParamGroup {
// public:
// static constexpr const char* kKey = "shoot";
// static constexpr const char* Logger = "Config: common::shoot";
// const char* key() const override {
// return kKey;
// }
// using Ptr = std::shared_ptr<ShootConfig>;
// ShootConfig() {
// rate_param.onChange([this](int o, int n) {
// WUST_DEBUG(Logger) << "shoot_rate from: " << o << " to: " << n << " HZ";
// });
// }
// static Ptr create() {
// return std::make_shared<ShootConfig>();
// }
// GEN_PARAM(int, rate);
// GEN_PARAM(double, bullet_speed);
// bool first_load = false;
// void loadSelf(const YAML::Node& node) override {
// if (!first_load) {
// rate_param.set(node["rate"].as<int>());
// bullet_speed_param.set(node["bullet_speed"].as<double>());
// first_load = true;
// } else {
// rate_param.load(node);
// }
// }
// };
// ShootConfig::Ptr shoot_config_;
// LoggerConfig::Ptr logger_config_;
// TFConfig::Ptr tf_config_;
// MaxInferRunningConfig::Ptr max_infer_running_config_;
// int attack_mode_;
// int debug_fps_;
// bool detect_color_;
// wust_vl::common::utils::Parameter common_config_parameter_;
// std::unique_ptr<wust_vl::common::concurrency::ThreadPool> thread_pool_;
// std::unique_ptr<auto_aim::AutoAim> auto_aim_;
// std::unique_ptr<auto_buff::AutoBuff> auto_buff_;
// std::unique_ptr<wust_vl::common::utils::Timer> timer_;
// std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer_;
// std::thread debug_thread_;
// GimbalCmd last_cmd_;
// std::pair<cv::Mat, cv::Mat> camera_info_;
// bool run_flag_ = false;
// bool debug_mode_ = false;
// std::atomic<int> infer_running_count_ { 0 };
// std::string common_config_;
// std::string camera_config_;
// std::string auto_aim_config_;
// std::string auto_buff_config_;
// bool has_camera_info_ = false;
// std::shared_ptr<Ros2Node> ros2_;
// };
// } // namespace wust_vision
// // VISION_MAIN(wust_vision::vision)
// int main(int argc, char** argv) {
// wust_vision::printBanner();
// bool debug = false;
// if (argc > 1) {
// std::string firstArg = argv[1];
// debug = (firstArg == "true" || firstArg == "1");
// std::cout << "debug: " << firstArg << std::endl;
// }
// std::set_terminate([]() {
// std::cerr << "Uncaught exception, terminating program.\n";
// if (auto e = std::current_exception()) {
// try {
// std::rethrow_exception(e);
// } catch (const std::exception& ex) {
// std::cerr << "Exception: " << ex.what() << std::endl;
// } catch (...) {
// std::cerr << "Unknown exception" << std::endl;
// }
// }
// std::abort();
// });
// try {
// int exit_code = 0;
// {
// wust_vision::vision v;
// v.init(debug);
// v.start();
// wust_vl::common::utils::SignalHandler sig;
// sig.start([&] { rclcpp::shutdown(); });
// bool exit_flag = false;
// while (!sig.shouldExit() && !exit_flag) {
// wust_vl::common::concurrency::ThreadManager::instance().printStatus();
// auto all_status =
// wust_vl::common::concurrency::ThreadManager::instance().getAllThreadStatuses();
// v.checkStateMatchMode();
// for (auto& status: all_status) {
// if (status.second
// == wust_vl::common::concurrency::MonitoredThread::Status::Hung) {
// std::cerr << status.first << " is Hunging! Exiting program..." << std::endl;
// exit_flag = true;
// exit_code = -1;
// sig.requestExit();
// break;
// }
// }
// std::this_thread::sleep_for(std::chrono::milliseconds(1000));
// }
// }
// std::cout << "Exiting program..." << std::endl;
// return exit_code;
// } catch (const std::exception& e) {
// std::cerr << "Caught exception in main: " << e.what() << "\n";
// throw;
// return -1;
// } catch (...) {
// std::cerr << "Unknown exception caught in main!\n";
// return -1;
// }
// }

View File

@@ -0,0 +1,27 @@
#include "tasks/imodule.hpp"
#include "tasks/utils/config.hpp"
#include "tasks/utils/main_base.hpp"
#include "tasks/vision_base.hpp"
ENABLE_BACKWARD()
namespace wust_vision {
class vision: public VisionBase<InfantryMode> {
public:
vision(): VisionBase(COMMON_CONFIG, CAMERA_CONFIG, AUTO_AIM_CONFIG, AUTO_BUFF_CONFIG) {}
bool init(bool debug_mode) {
if (!VisionBase::init(debug_mode)) {
return false;
}
auto auto_aim =
auto_aim::AutoAim::create(auto_aim_config_, tf_config_, camera_info_, debug_mode);
modules_.emplace(InfantryMode::AttackMode::ARMOR, auto_aim);
modules_.emplace(InfantryMode::AttackMode::UNKNOWN, auto_aim);
auto auto_buff =
auto_buff::AutoBuff::create(auto_buff_config_, tf_config_, camera_info_, debug_mode);
modules_.emplace(InfantryMode::AttackMode::BIG_RUNE, auto_buff);
modules_.emplace(InfantryMode::AttackMode::SMALL_RUNE, auto_buff);
return true;
}
};
} // namespace wust_vision
VISION_MAIN(wust_vision::vision)

View File

@@ -0,0 +1,424 @@
:root {
--primary-color: #00bfff;
--primary-dark: #004f8c;
--accent-color: #00d4ff;
--bg-dark: #121720;
--text-light: #a0d6f5;
--shadow-glow: #00d4ffcc;
}
/* 基础样式 */
html, body {
height: 100%;
margin: 0;
padding: 0;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, var(--bg-dark) 0%, #0f1526 100%);
color: var(--text-light);
-webkit-user-select: text;
user-select: text;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-shadow: 0 0 3px rgba(0, 120, 255, 0.4);
}
/* 头部区域 */
#header {
display: flex;
align-items: center;
padding: 12px 24px;
background: rgba(0,31,63,0.85);
-webkit-backdrop-filter: saturate(180%) blur(6px);
backdrop-filter: saturate(180%) blur(6px);
background-image: linear-gradient(90deg, var(--primary-dark), #003366);
border-bottom: 3px solid var(--primary-color);
box-shadow: 0 0 14px var(--shadow-glow);
height: 100px;
box-sizing: border-box;
}
#header img {
width: 52px;
height: 52px;
margin-right: 20px;
border-radius: 12px;
box-shadow: 0 0 12px var(--primary-color);
transition: transform 0.3s ease;
cursor: pointer;
}
#header img:hover {
animation: glowPulse 2s infinite ease-in-out;
transform: scale(1.1);
}
/* 标签按钮栏 */
#tab-buttons {
padding: 10px 24px;
background: #101425;
border-bottom: 2px solid #004466;
box-shadow: inset 0 -2px 8px #00334daa;
height: 44px;
box-sizing: border-box;
display: flex;
align-items: center;
}
#tab-buttons button {
background: linear-gradient(145deg, #0088cc, #005f7f);
border: none;
color: #ccf8ff;
padding: 8px 18px;
margin-right: 14px;
border-radius: 10px;
font-weight: 600;
cursor: pointer;
box-shadow:
inset 0 -2px 4px #004a6d99,
0 0 12px var(--primary-color);
transition: all 0.3s ease;
-webkit-user-select: text;
user-select: text;
text-shadow: 0 0 4px var(--primary-color);
}
#tab-buttons button:hover, #tab-buttons button:focus {
background: linear-gradient(145deg, var(--primary-color), #0077aa);
box-shadow:
inset 0 -2px 6px #006db6cc,
0 0 22px var(--accent-color);
outline: none;
}
#tab-buttons button:active {
transform: scale(0.97);
box-shadow:
inset 0 2px 4px #002d43cc,
0 0 16px #0088ccdd;
}
/* 视频区域 */
#video-tab {
height: calc(100vh - 60px - 44px);
width: 100vw;
box-sizing: border-box;
background: #0f131f;
padding: 12px 20px;
overflow: hidden;
box-shadow: inset 0 0 40px #003554aa;
}
/* 主网格布局 */
#main-grid {
display: grid;
grid-template-columns: 1fr 1fr;
grid-template-rows: 1fr 1fr;
height: 100%;
width: 100%;
gap: 6px;
box-sizing: border-box;
background: #121b2b;
border-radius: 14px;
box-shadow:
0 0 12px #007acc88,
inset 0 0 20px #002b4daa;
}
/* 面板通用样式 */
.panel {
background: linear-gradient(145deg, #1b2237, #121a2e);
padding: 16px;
box-sizing: border-box;
height: 100%;
width: 100%;
display: flex;
flex-direction: column;
border-radius: 14px;
border: 1.8px solid transparent;
background-clip: padding-box;
position: relative;
overflow: hidden;
cursor: default;
box-shadow:
0 0 15px var(--primary-color),
inset 0 0 25px var(--primary-dark);
}
.panel::before {
content: "";
position: absolute;
inset: 0;
border-radius: 14px;
padding: 1.8px;
pointer-events: none;
background: linear-gradient(60deg, var(--primary-color), var(--primary-dark), var(--accent-color), #005fbc);
-webkit-mask:
linear-gradient(#fff 0 0) content-box,
linear-gradient(#fff 0 0);
mask:
linear-gradient(#fff 0 0) content-box,
linear-gradient(#fff 0 0);
-webkit-mask-composite: destination-out;
mask-composite: exclude;
animation: pulseBorder 3s ease-in infinite;
}
/* 边框发光动画 */
@keyframes pulseBorder {
0% {
filter: drop-shadow(0 0 5px var(--primary-color));
}
100% {
filter: drop-shadow(0 0 15px var(--accent-color));
}
}
/* 四个区域微调 */
#video-panel { padding-right: 10px; }
#json-info { padding-left: 10px; overflow-y: auto; }
#main-chart-panel { padding-right: 10px; overflow-y: auto; }
#individual-chart-panel { padding-left: 10px; overflow-y: auto; }
/* 视频容器 */
.video-container {
flex-grow: 1;
display: flex;
justify-content: center;
align-items: center;
overflow: hidden;
border-radius: 12px;
box-shadow: inset 0 0 40px #005a9e88;
background: #12203f;
position: relative;
}
.video-container img {
max-width: 100%;
max-height: 100%;
border-radius: 12px;
border: 2px solid #006db6;
box-shadow:
0 0 20px var(--primary-color),
inset 0 0 30px #004975cc;
transition: transform 0.4s ease, box-shadow 0.4s ease;
cursor: pointer;
}
.video-container img:hover {
transform: scale(1.05);
box-shadow:
0 0 30px var(--accent-color),
inset 0 0 40px #0069c9cc;
}
/* JSON信息容器 */
#json-info {
color: #a0cbdc;
font-family: "Source Code Pro", monospace, monospace;
overflow-y: auto;
text-shadow: 0 0 3px #005e8a;
-webkit-user-select: text;
user-select: text;
}
/* JSON子容器 */
.json-container {
background: linear-gradient(135deg, #12273e 0%, #0c1c33 100%);
border-radius: 12px;
padding: 14px 18px;
margin-bottom: 14px;
height: 48%;
overflow-y: auto;
box-shadow: inset 0 0 12px #004573;
border: 1.2px solid #0068ac;
}
/* 主图表和独立图表容器撑满 */
#main-chart-panel,
#individual-chart-panel {
height: 100%;
overflow-y: auto;
display: flex;
flex-direction: column;
}
/* 主图表样式 */
#mainChart {
flex-grow: 1;
width: 100% !important;
border-radius: 14px;
background: linear-gradient(135deg, #1a243a, #0f1526);
padding: 12px;
box-sizing: border-box;
margin-bottom: 14px;
box-shadow:
0 0 12px #0099ffbb,
inset 0 0 18px #004575aa;
border: 2px solid #0088ee;
transition: box-shadow 0.3s ease;
cursor: default;
}
#mainChart:hover {
box-shadow:
0 0 24px var(--primary-color),
inset 0 0 30px #0077ccbb;
}
/* 独立图表容器 */
#individualCharts {
flex-grow: 1;
overflow-y: auto;
display: flex;
flex-wrap: wrap;
gap: 18px;
padding-bottom: 6px;
}
/* 交互控件 */
.chart-controls {
margin-bottom: 10px;
color: #a0d4f7;
text-shadow: 0 0 3px #005e8a;
-webkit-user-select: text;
user-select: text;
}
/* 按钮、下拉框等统一风格 */
button, select, input[type="checkbox"] {
background: linear-gradient(145deg, #00bcd4, #0088aa);
color: #001a25;
border: none;
padding: 8px 14px;
border-radius: 10px;
margin-right: 10px;
cursor: pointer;
font-weight: 700;
box-shadow: 0 0 14px #00cfffcc;
transition: all 0.25s ease;
-webkit-user-select: text;
user-select: text;
text-shadow: 0 0 6px #00dfffcc;
}
button:hover, select:hover {
background: linear-gradient(145deg, #00e3ff, #0099cc);
box-shadow: 0 0 24px #00eaffee;
}
input[type="checkbox"] {
transform: scale(1.3);
cursor: pointer;
}
/* 标签文字 */
label {
color: #9bc8f7;
margin-right: 12px;
-webkit-user-select: text;
user-select: text;
text-shadow: 0 0 3px #004477;
}
/* 配置面板 */
#config-tab {
padding: 12px;
background: linear-gradient(135deg, #132237, #0b1528);
border-radius: 14px;
border: 1.5px solid #005f8f;
box-shadow: 0 0 12px #0077bbaa;
color: #a0cde8;
}
/* 响应式调整:小屏幕时改为单列布局 */
@media (max-width: 1024px) {
#video-tab {
height: auto;
}
#main-grid {
grid-template-columns: 1fr;
grid-template-rows: auto auto auto auto;
height: auto;
gap: 18px;
}
.panel {
height: auto;
}
}
/* 头部图片发光动画 */
@keyframes glowPulse {
0%, 100% {
box-shadow: 0 0 10px var(--primary-color);
}
50% {
box-shadow: 0 0 20px var(--accent-color);
}
}
/* 全屏模式下,仅显示视频容器,隐藏其他所有元素 */
.fullscreen-mode #header,
.fullscreen-mode #tab-buttons,
.fullscreen-mode #json-info,
.fullscreen-mode #main-chart-panel,
.fullscreen-mode #individual-chart-panel,
.fullscreen-mode .fullscreen-btn {
display: none !important;
}
.fullscreen-mode #main-grid {
padding: 0 !important;
margin: 0 !important;
gap: 0 !important;
background: black !important;
box-shadow: none !important;
border-radius: 0 !important;
}
.fullscreen-mode #video-panel {
grid-column: 1 / -1;
grid-row: 1 / -1;
padding: 0 !important;
}
/* 全屏时视频容器铺满全屏 */
.fullscreen-mode .video-container {
position: fixed;
top: 0; left: 0;
width: 100vw !important;
height: 100vh !important;
margin: 0; padding: 0;
background: black !important;
border-radius: 0 !important;
box-shadow: none !important;
overflow: hidden;
z-index: 9999;
display: flex !important;
justify-content: center;
align-items: center;
}
/* 图片铺满容器,保持比例,可能裁剪 */
.fullscreen-mode .video-container img {
height: 90% !important;
width: auto !important;
max-width: 100vw !important;
object-fit: contain !important;
border: none !important;
box-shadow: none !important;
border-radius: 0 !important;
}
/* 全屏按钮样式 */
.fullscreen-btn {
position: absolute;
bottom: 12px;
right: 14px;
padding: 8px 12px;
font-size: 14px;
background: rgba(0, 136, 204, 0.85);
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
box-shadow: 0 0 10px #00bfff88;
z-index: 10;
transition: background 0.2s ease;
}
.fullscreen-btn:hover {
background: rgba(0, 180, 255, 0.95);
}

View File

@@ -0,0 +1,249 @@
const UPDATE_HZ = 100;
const UPDATE_INTERVAL_MS = 1000 / UPDATE_HZ;
let updateTimer = null;
function startUpdateLoop() {
if (updateTimer) return;
updateTimer = setInterval(fetchDataAndUpdateCharts, UPDATE_INTERVAL_MS);
}
startUpdateLoop();
const chartMap = {
raw_yaw: { label: "Raw Yaw" },
raw_pitch: { label: "Raw Pitch" },
yaw: { label: "Yaw" },
pitch: { label: "Pitch" },
armor_dis: { label: "Armor Distance" },
armor_x: { label: "Armor X" },
armor_y: { label: "Armor Y" },
armor_z: { label: "Armor Z" },
armor_yaw: { label: "Armor Yaw" },
ypd_p: { label: "Ypd Pitch" },
ypd_y: { label: "Ypd Yaw" },
rune_obs: { label: "Rune Obs" },
rune_pre: { label: "Rune Pre" },
rune_obsv: { label: "Rune ObsV" },
rune_fitv: { label: "Rune FitV" },
gimbal_yaw: { label: "Gimbal Yaw" },
gimbal_pitch: { label: "Gimbal Pitch" },
target_v_yaw: { label: "Target V Yaw" },
control_v_yaw: { label: "Control V Yaw" },
control_v_pitch: { label: "Control V Pitch" },
yaw_diff: { label: "Yaw Diff" },
fire: { label: "Fire" },
rune_dis: { label: "Rune Distance" },
fly_time: { label: "Fly Time" },
control_a_yaw: { label: "Control A Yaw" },
control_a_pitch: { label: "Control A Pitch" },
};
const mainCtx = document.getElementById("mainChart").getContext("2d");
let individualCharts = {};
let individualRanges = {};
const commonChartOptions = {
animation: false,
responsive: false,
interaction: {
mode: "nearest",
axis: "x",
intersect: false,
},
elements: {
line: {
tension: 0,
},
point: {
radius: 0,
hoverRadius: 0,
},
},
plugins: {
tooltip: {
enabled: true,
mode: "nearest",
intersect: false,
backgroundColor: "rgba(0, 188, 212, 0.85)",
padding: 8,
cornerRadius: 6,
titleFont: {
size: 12,
family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
},
bodyFont: {
size: 16,
family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
weight: "bold",
},
callbacks: {
title: () => "",
label: (context) => `值: ${context.parsed.y.toFixed(3)}`,
},
},
legend: {
labels: {
font: {
size: 14,
family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
},
color: "#00bcd4",
},
},
},
scales: {
x: {
display: false, // ✅ 关键:彻底关闭 X 轴显示
},
y: {
title: { display: true, text: "Value" },
ticks: {
color: "#00bcd4",
font: {
size: 14,
family: "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
},
},
grid: { color: "#2f3241" },
},
},
};
const mainChart = new Chart(mainCtx, {
type: "line",
data: { labels: [], datasets: [] },
options: commonChartOptions,
});
function updateMainRange() {
const maxPts = parseInt(document.getElementById("mainMaxPts").value) || 100;
mainChart._maxPoints = maxPts;
}
function updateCharts() {
const showMulti = document.getElementById("multiLineChart").checked;
const selected = Array.from(
document.querySelectorAll(
'.chart-select-controls input[type="checkbox"]:checked'
)
)
.map((cb) => cb.dataset.key)
.filter(Boolean);
const container = document.getElementById("individualCharts");
container.innerHTML = "";
individualCharts = {};
individualRanges = {};
if (showMulti) {
mainChart.data.datasets = selected.map((key) => ({
label: chartMap[key]?.label || key,
data: [],
fill: false,
}));
} else {
mainChart.data.datasets = [];
}
mainChart.update();
selected.forEach((key) => {
const box = document.createElement("div");
box.className = "chart-box";
const title = document.createElement("h4");
title.textContent = chartMap[key]?.label || key;
box.appendChild(title);
const rdiv = document.createElement("div");
rdiv.className = "range-controls";
rdiv.innerHTML = `
<label><input type="checkbox" class="childEnable" /> 固定范围</label>
min:<input type="number" class="childMin" value="0" step="0.1" />
max:<input type="number" class="childMax" value="1" step="0.1" />
<button type="button" class="applyRange">应用</button>
`;
box.appendChild(rdiv);
const canvas = document.createElement("canvas");
canvas.width = 440;
canvas.height = 280;
box.appendChild(canvas);
container.appendChild(box);
const ctx = canvas.getContext("2d");
const chart = new Chart(ctx, {
type: "line",
data: {
labels: [],
datasets: [
{
label: chartMap[key]?.label || key,
data: [],
fill: false,
},
],
},
options: commonChartOptions,
});
individualCharts[key] = chart;
const applyBtn = rdiv.querySelector(".applyRange");
applyBtn.addEventListener("click", () => {
const enabled = rdiv.querySelector(".childEnable").checked;
const minVal = parseFloat(rdiv.querySelector(".childMin").value);
const maxVal = parseFloat(rdiv.querySelector(".childMax").value);
chart.options.scales.y.min = enabled ? minVal : undefined;
chart.options.scales.y.max = enabled ? maxVal : undefined;
chart.update("none");
});
});
}
async function fetchDataAndUpdateCharts() {
try {
const res = await fetch("/data");
const json = await res.json();
const time = json.time;
if (!time) return;
const maxPts = mainChart._maxPoints || 100;
const start = time.length > maxPts ? time.length - maxPts : 0;
const slicedTime = time.slice(start);
if (document.getElementById("multiLineChart").checked) {
mainChart.data.labels = slicedTime;
const keys = Object.keys(individualCharts);
mainChart.data.datasets.forEach((ds, i) => {
const key = keys[i];
ds.data = json[key]?.slice(start) || [];
});
const allValues = mainChart.data.datasets.flatMap(ds => ds.data);
if (allValues.length > 0) {
const minVal = Math.min(...allValues);
const maxVal = Math.max(...allValues);
const padding = (maxVal - minVal) * 0.1 || 1; // 至少留1的余量
mainChart.options.scales.y.min = minVal - padding;
mainChart.options.scales.y.max = maxVal + padding;
}
mainChart.update();
}
Object.entries(individualCharts).forEach(([key, ch]) => {
ch.data.labels = slicedTime;
ch.data.datasets[0].data = json[key]?.slice(start) || [];
// 每个子图也加 margin避免线条贴边
const arr = ch.data.datasets[0].data;
if (arr.length > 0) {
const minVal = Math.min(...arr);
const maxVal = Math.max(...arr);
const padding = (maxVal - minVal) * 0.1 || 1;
ch.options.scales.y.min = minVal - padding;
ch.options.scales.y.max = maxVal + padding;
}
ch.update();
});
} catch (e) {
console.error("fetch error:", e);
}
}

View File

@@ -0,0 +1,61 @@
function jsonToHtml(data, container) {
const title = container.querySelector("h3");
container.innerHTML = "";
if (title) container.appendChild(title);
const contentDiv = document.createElement("div");
container.appendChild(contentDiv);
function buildTree(d, parent) {
if (typeof d !== "object" || d === null) {
parent.textContent = String(d);
return;
}
const ul = document.createElement("ul");
ul.className = "json-tree";
const entries = Array.isArray(d) ? d.map((v, i) => [i, v]) : Object.entries(d);
entries.forEach(([k, v]) => {
const li = document.createElement("li");
if (typeof v === "object" && v !== null) {
const details = document.createElement("details");
details.open = true;
const summary = document.createElement("summary");
summary.textContent = k;
details.appendChild(summary);
const childUl = document.createElement("ul");
childUl.className = "json-tree";
buildTree(v, childUl);
details.appendChild(childUl);
li.appendChild(details);
} else {
li.textContent = `${k}: ${v}`;
}
ul.appendChild(li);
});
parent.appendChild(ul);
}
buildTree(data, contentDiv);
}
let lastSerialData = null, lastTargetData = null;
async function fetchAndDisplayJsonWithTree(id, url) {
const parent = document.getElementById(id + "-container");
const cont = document.getElementById(id);
try {
parent.classList.add("json-updating");
const res = await fetch(url);
if (!res.ok) throw new Error(res.statusText);
const data = await res.json();
const prev = id === "json-serial" ? lastSerialData : lastTargetData;
if (JSON.stringify(data) !== JSON.stringify(prev)) {
jsonToHtml(data, cont);
if (id === "json-serial") lastSerialData = data;
else lastTargetData = data;
}
} catch (e) {
console.warn(`请求失败(${url}): ${e.message}`);
} finally {
parent.classList.remove("json-updating");
}
}

View File

@@ -0,0 +1,35 @@
function showTab(tab) {
document.getElementById("video-tab").style.display = tab === "video" ? "flex" : "none";
}
document.addEventListener("DOMContentLoaded", () => {
updateCharts();
updateMainRange();
setInterval(() => {
fetchDataAndUpdateCharts();
fetchAndDisplayJsonWithTree("json-target", "/target_log");
fetchAndDisplayJsonWithTree("json-serial", "/serial_log");
}, 200);
});
function toggleFullscreen() {
const container = document.querySelector('.video-container');
if (!document.fullscreenElement) {
container.requestFullscreen().then(() => {
document.body.classList.add('fullscreen-mode');
}).catch((err) => {
alert(`无法进入全屏模式: ${err.message}`);
});
} else {
document.exitFullscreen().then(() => {
document.body.classList.remove('fullscreen-mode');
});
}
}
// 监听用户按 ESC 或其他方式退出全屏
document.addEventListener('fullscreenchange', () => {
if (!document.fullscreenElement) {
document.body.classList.remove('fullscreen-mode');
}
});

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

View File

@@ -0,0 +1,43 @@
add_library(tinympcstatic STATIC
admm.cpp
tiny_api.cpp
codegen.cpp
rho_benchmark.cpp
)
set_property(TARGET tinympcstatic PROPERTY POSITION_INDEPENDENT_CODE ON)
# target_link_libraries(tinympcstatic PUBLIC Eigen)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include/Eigen)
if(USING_CODEGEN) # Defined in top-level CMakeLists.txt
# Files that are needed for embedded code generation
list( APPEND EMBEDDED_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/admm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/admm.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/types.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api_constants.hpp" )
foreach( f ${EMBEDDED_FILES} )
get_filename_component( fname ${f} NAME )
set( dest_file "${EMBEDDED_BUILD_TINYMPC_DIR}/${fname}" )
list( APPEND EMBEDDED_BUILD_TINYMPC_FILES "${dest_file}" )
add_custom_command(OUTPUT ${dest_file}
COMMAND ${CMAKE_COMMAND} -E copy "${f}" "${dest_file}"
DEPENDS ${f}
COMMENT "Copying ${fname}")
endforeach()
add_custom_target( copy_codegen_tinympc_files DEPENDS ${EMBEDDED_BUILD_TINYMPC_FILES} )
add_dependencies( copy_codegen_files copy_codegen_tinympc_files )
endif(USING_CODEGEN)

View File

@@ -0,0 +1,399 @@
#include <iostream>
#include "admm.hpp"
#include "rho_benchmark.hpp"
#define DEBUG_MODULE "TINYALG"
extern "C" {
/**
* Update linear terms from Riccati backward pass
*/
void backward_pass_grad(TinySolver* solver) {
for (int i = solver->work->N - 2; i >= 0; i--) {
(solver->work->d.col(i)).noalias() = solver->cache->Quu_inv
* (solver->work->Bdyn.transpose() * solver->work->p.col(i + 1) + solver->work->r.col(i)
+ solver->cache->BPf);
(solver->work->p.col(i)).noalias() = solver->work->q.col(i)
+ solver->cache->AmBKt.lazyProduct(solver->work->p.col(i + 1))
- (solver->cache->Kinf.transpose()).lazyProduct(solver->work->r.col(i))
+ solver->cache->APf;
}
}
/**
* Use LQR feedback policy to roll out trajectory
*/
void forward_pass(TinySolver* solver) {
for (int i = 0; i < solver->work->N - 1; i++) {
(solver->work->u.col(i)).noalias() =
-solver->cache->Kinf.lazyProduct(solver->work->x.col(i)) - solver->work->d.col(i);
(solver->work->x.col(i + 1)).noalias() =
solver->work->Adyn.lazyProduct(solver->work->x.col(i))
+ solver->work->Bdyn.lazyProduct(solver->work->u.col(i)) + solver->work->fdyn;
}
}
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu) {
tinytype u0 = s(Eigen::placeholders::last) * mu;
tinyVector u1 = s.head(s.rows() - 1);
float a = u1.norm();
tinyVector cone_origin(s.rows());
cone_origin.setZero();
if (a <= -u0) { // below cone
return cone_origin;
} else if (a <= u0) { // in cone
return s;
} else if (a >= abs(u0)) { // outside cone
Matrix<tinytype, 3, 1> u2(u1.size() + 1);
u2 << u1, a / mu;
return 0.5 * (1 + u0 / a) * u2;
} else {
return cone_origin;
}
}
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b) {
tinytype dist = (a.dot(z) - b) / a.squaredNorm();
return z - dist * a;
}
/**
* Project slack (auxiliary) variables into their feasible domain, defined by
* projection functions related to each constraint
* TODO: pass in meta information with each constraint assigning it to a
* projection function
*/
void update_slack(TinySolver* solver) {
// Update bound constraint slack variables for state
solver->work->vnew = solver->work->x + solver->work->g;
// Update bound constraint slack variables for input
solver->work->znew = solver->work->u + solver->work->y;
// Box constraints on state
if (solver->settings->en_state_bound) {
solver->work->vnew =
solver->work->x_max.cwiseMin(solver->work->x_min.cwiseMax(solver->work->vnew));
}
// Box constraints on input
if (solver->settings->en_input_bound) {
solver->work->znew =
solver->work->u_max.cwiseMin(solver->work->u_min.cwiseMax(solver->work->znew));
}
// Update second order cone slack variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x + solver->work->gc;
}
// Update second order cone slack variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u + solver->work->yc;
}
// Cone constraints on state
if (solver->settings->en_state_soc) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateCones; k++) {
int start = solver->work->Acx(k);
int num_xs = solver->work->qcx(k);
tinytype mu = solver->work->cx(k);
tinyVector col = solver->work->vcnew.block(start, i, num_xs, 1);
solver->work->vcnew.block(start, i, num_xs, 1) = project_soc(col, mu);
}
}
}
// Cone constraints on input
if (solver->settings->en_input_soc) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputCones; k++) {
int start = solver->work->Acu(k);
int num_us = solver->work->qcu(k);
tinytype mu = solver->work->cu(k);
tinyVector col = solver->work->zcnew.block(start, i, num_us, 1);
solver->work->zcnew.block(start, i, num_us, 1) = project_soc(col, mu);
}
}
}
// Update linear constraint slack variables for state
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x + solver->work->gl;
}
// Update linear constraint slack variables for input
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u + solver->work->yl;
}
// Linear constraints on state
if (solver->settings->en_state_linear) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateLinear; k++) {
tinyVector a = solver->work->Alin_x.row(k);
tinytype b = solver->work->blin_x(k);
tinytype constraint_value = a.dot(solver->work->vlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->vlnew.col(i) =
project_hyperplane(solver->work->vlnew.col(i), a, b);
}
}
}
}
// Linear constraints on input
if (solver->settings->en_input_linear) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputLinear; k++) {
tinyVector a = solver->work->Alin_u.row(k);
tinytype b = solver->work->blin_u(k);
tinytype constraint_value = a.dot(solver->work->zlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->zlnew.col(i) =
project_hyperplane(solver->work->zlnew.col(i), a, b);
}
}
}
}
}
/**
* Update next iteration of dual variables by performing the augmented
* lagrangian multiplier update
*/
void update_dual(TinySolver* solver) {
// Update bound constraint dual variables for state
solver->work->g = solver->work->g + solver->work->x - solver->work->vnew;
// Update bound constraint dual variables for input
solver->work->y = solver->work->y + solver->work->u - solver->work->znew;
// Update second order cone dual variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->gc = solver->work->gc + solver->work->x - solver->work->vcnew;
}
// Update second order cone dual variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->yc = solver->work->yc + solver->work->u - solver->work->zcnew;
}
// Update linear constraint dual variables for state
if (solver->settings->en_state_linear) {
solver->work->gl = solver->work->gl + solver->work->x - solver->work->vlnew;
}
// Update linear constraint dual variables for input
if (solver->settings->en_input_linear) {
solver->work->yl = solver->work->yl + solver->work->u - solver->work->zlnew;
}
}
/**
* Update linear control cost terms in the Riccati feedback using the changing
* slack and dual variables from ADMM
*/
void update_linear_cost(TinySolver* solver) {
// Update state cost terms
solver->work->q = -(solver->work->Xref.array().colwise() * solver->work->Q.array());
(solver->work->q).noalias() -= solver->cache->rho * (solver->work->vnew - solver->work->g);
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vcnew - solver->work->gc);
}
if (solver->settings->en_state_linear) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vlnew - solver->work->gl);
}
// Update input cost terms
solver->work->r = -(solver->work->Uref.array().colwise() * solver->work->R.array());
(solver->work->r).noalias() -= solver->cache->rho * (solver->work->znew - solver->work->y);
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zcnew - solver->work->yc);
}
if (solver->settings->en_input_linear) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zlnew - solver->work->yl);
}
// Update terminal cost
solver->work->p.col(solver->work->N - 1) =
-(solver->work->Xref.col(solver->work->N - 1).transpose().lazyProduct(solver->cache->Pinf));
(solver->work->p.col(solver->work->N - 1)).noalias() -= solver->cache->rho
* (solver->work->vnew.col(solver->work->N - 1) - solver->work->g.col(solver->work->N - 1));
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vcnew.col(solver->work->N - 1)
- solver->work->gc.col(solver->work->N - 1));
}
if (solver->settings->en_state_linear) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vlnew.col(solver->work->N - 1)
- solver->work->gl.col(solver->work->N - 1));
}
}
/**
* Check for termination condition by evaluating whether the largest absolute
* primal and dual residuals for states and inputs are below threhold.
*/
bool termination_condition(TinySolver* solver) {
if (solver->work->iter % solver->settings->check_termination == 0) {
solver->work->primal_residual_state =
(solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
solver->work->dual_residual_state =
((solver->work->v - solver->work->vnew).cwiseAbs().maxCoeff()) * solver->cache->rho;
solver->work->primal_residual_input =
(solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
solver->work->dual_residual_input =
((solver->work->z - solver->work->znew).cwiseAbs().maxCoeff()) * solver->cache->rho;
if (solver->work->primal_residual_state < solver->settings->abs_pri_tol
&& solver->work->primal_residual_input < solver->settings->abs_pri_tol
&& solver->work->dual_residual_state < solver->settings->abs_dua_tol
&& solver->work->dual_residual_input < solver->settings->abs_dua_tol)
{
return true;
}
}
return false;
}
int solve(TinySolver* solver) {
// Initialize variables
solver->solution->solved = 0;
solver->solution->iter = 0;
solver->work->status = 11; // TINY_UNSOLVED
solver->work->iter = 0;
// Setup for adaptive rho
RhoAdapter adapter;
adapter.rho_min = solver->settings->adaptive_rho_min;
adapter.rho_max = solver->settings->adaptive_rho_max;
adapter.clip = solver->settings->adaptive_rho_enable_clipping;
RhoBenchmarkResult rho_result;
// Store previous values for residuals
tinyMatrix v_prev = solver->work->vnew;
tinyMatrix z_prev = solver->work->znew;
// Initialize SOC slack variables if needed
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x;
}
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u;
}
// Initialize linear constraint slack variables if needed
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x;
}
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u;
}
for (int i = 0; i < solver->settings->max_iter; i++) {
// Solve linear system with Riccati and roll out to get new trajectory
backward_pass_grad(solver);
forward_pass(solver);
// Project slack variables into feasible domain
update_slack(solver);
// Compute next iteration of dual variables
update_dual(solver);
// Update linear control cost terms using reference trajectory, duals, and slack variables
update_linear_cost(solver);
solver->work->iter += 1;
// Handle adaptive rho if enabled
if (solver->settings->adaptive_rho) {
// Calculate residuals for adaptive rho
tinytype pri_res_input = (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
tinytype pri_res_state = (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
tinytype dua_res_input =
solver->cache->rho * (solver->work->znew - z_prev).cwiseAbs().maxCoeff();
tinytype dua_res_state =
solver->cache->rho * (solver->work->vnew - v_prev).cwiseAbs().maxCoeff();
// Update rho every 5 iterations
if (i > 0 && i % 5 == 0) {
benchmark_rho_adaptation(
&adapter,
solver->work->x,
solver->work->u,
solver->work->vnew,
solver->work->znew,
solver->work->g,
solver->work->y,
solver->cache,
solver->work,
solver->work->N,
&rho_result
);
// Update matrices using Taylor expansion
update_matrices_with_derivatives(solver->cache, rho_result.final_rho);
}
}
// Store previous values for next iteration
z_prev = solver->work->znew;
v_prev = solver->work->vnew;
// Check for whether cost is minimized by calculating residuals
if (termination_condition(solver)) {
solver->work->status = 1; // TINY_SOLVED
// Save solution
solver->solution->iter = solver->work->iter;
solver->solution->solved = 1;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
// std::cout << "Solver converged in " << solver->work->iter << " iterations" << std::endl;
return 0;
}
// Save previous slack variables
solver->work->v = solver->work->vnew;
solver->work->z = solver->work->znew;
}
solver->solution->iter = solver->work->iter;
solver->solution->solved = 0;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
return 1;
}
} /* extern "C" */

View File

@@ -0,0 +1,37 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int solve(TinySolver* solver);
void update_primal(TinySolver* solver);
void backward_pass_grad(TinySolver* solver);
void forward_pass(TinySolver* solver);
void update_slack(TinySolver* solver);
void update_dual(TinySolver* solver);
void update_linear_cost(TinySolver* solver);
bool termination_condition(TinySolver* solver);
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu);
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,466 @@
#include <ctype.h>
#include <dirent.h>
#include <stdio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>
//#include <error.h>
#include "error.hpp"
#include <Eigen/Dense>
#include <iostream>
// #include "types.hpp"
#include "codegen.hpp"
#ifdef __MINGW32__
#include <direct.h>
inline int mkdir(const char* pathname, int flags) {
return _mkdir(pathname);
}
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* Define the maximum allowed length of the path (directory + filename + extension) */
#define PATH_LENGTH 2048
using namespace Eigen;
static void print_matrix(FILE* f, MatrixXd mat, int num_elements) {
// Check if matrix is uninitialized or too small
if (mat.size() == 0 || mat.size() < num_elements) {
// Print zeros for all elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)0.0000000000000000");
if (i < num_elements - 1)
fprintf(f, ",");
}
return;
}
// Matrix is properly initialized and has enough elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)%.16f", mat.reshaped<RowMajor>()[i]);
if (i < num_elements - 1)
fprintf(f, ",");
}
}
static void create_directory(const char* dir, int verbose) {
// Attempt to create directory
if (mkdir(dir, S_IRWXU | S_IRWXG | S_IROTH)) {
if (errno == EEXIST) { // Skip if directory already exists
if (verbose)
std::cout << dir << " already exists, skipping." << std::endl;
} else {
ERROR_MSG(EXIT_FAILURE, "Failed to create directory %s", dir);
}
}
}
// TODO: Make this fail if tiny_setup has not already been called
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose) {
if (!solver) {
std::cout << "Error in tiny_codegen: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= codegen_create_directories(output_dir, verbose);
status |= codegen_data_header(output_dir, verbose);
status |= codegen_data_source(solver, output_dir, verbose);
status |= codegen_example(output_dir, verbose);
return status;
}
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
) {
if (!solver) {
std::cout << "Error in tiny_codegen_with_sensitivity: solver is nullptr" << std::endl;
return 1;
}
// Only store sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
// Store the sensitivity matrices in the solver's cache
solver->cache->dKinf_drho = *dK;
solver->cache->dPinf_drho = *dP;
solver->cache->dC1_drho = *dC1;
solver->cache->dC2_drho = *dC2;
}
// Call the regular codegen function which will now include the sensitivity matrices if adaptive_rho is enabled
return tiny_codegen(solver, output_dir, verbose);
}
// Create code generation folder structure in whichever directory the executable calling tiny_codegen was called
int codegen_create_directories(const char* output_dir, int verbose) {
// Create output folder (root folder for code generation)
create_directory(output_dir, verbose);
// Create src folder
char src_dir[PATH_LENGTH];
sprintf(src_dir, "%s/src/", output_dir);
create_directory(src_dir, verbose);
// Create tinympc folder
char tinympc_dir[PATH_LENGTH];
sprintf(tinympc_dir, "%s/tinympc/", output_dir);
create_directory(tinympc_dir, verbose);
// // Create include folder
// char inc_dir[PATH_LENGTH];
// sprintf(inc_dir, "%s/include/", output_dir);
// create_directory(inc_dir, verbose);
return EXIT_SUCCESS;
}
// Create inc/tiny_data.hpp file
int codegen_data_header(const char* output_dir, int verbose) {
char data_hpp_fname[PATH_LENGTH];
FILE* data_hpp_f;
sprintf(data_hpp_fname, "%s/tinympc/tiny_data.hpp", output_dir);
// Open data header file
data_hpp_f = fopen(data_hpp_fname, "w+");
if (data_hpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_hpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_hpp_f, "/*\n");
fprintf(data_hpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_hpp_f, " */\n\n");
fprintf(data_hpp_f, "#pragma once\n\n");
fprintf(data_hpp_f, "#include \"types.hpp\"\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "extern \"C\" {\n");
fprintf(data_hpp_f, "#endif\n\n");
fprintf(data_hpp_f, "extern TinySolver tiny_solver;\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "}\n");
fprintf(data_hpp_f, "#endif\n");
// Close codegen data header file
fclose(data_hpp_f);
if (verbose) {
printf("Data header generated in %s\n", data_hpp_fname);
}
return 0;
}
// Create src/tiny_data.cpp file
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose) {
char data_cpp_fname[PATH_LENGTH];
FILE* data_cpp_f;
int nx = solver->work->nx;
int nu = solver->work->nu;
int N = solver->work->N;
sprintf(data_cpp_fname, "%s/src/tiny_data.cpp", output_dir);
// Open data source file
data_cpp_f = fopen(data_cpp_fname, "w+");
if (data_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_cpp_f, "/*\n");
fprintf(data_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_cpp_f, " */\n\n");
// Open extern C
fprintf(data_cpp_f, "#include \"tinympc/tiny_data.hpp\"\n\n");
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "extern \"C\" {\n");
fprintf(data_cpp_f, "#endif\n\n");
// Solution
fprintf(data_cpp_f, "/* Solution */\n");
fprintf(data_cpp_f, "TinySolution solution = {\n");
fprintf(data_cpp_f, "\t%d,\t\t// iter\n", solver->solution->iter);
fprintf(data_cpp_f, "\t%d,\t\t// solved\n", solver->solution->solved);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x solution
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// x\n"); // u solution
fprintf(data_cpp_f, "};\n\n");
// Cache
fprintf(data_cpp_f, "/* Matrices that must be recomputed with changes in time step, rho */\n");
fprintf(data_cpp_f, "TinyCache cache = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// rho (step size/penalty)\n", solver->cache->rho);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->Kinf, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// Kinf\n"); // Kinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->Pinf, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Pinf\n"); // Pinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nu);
print_matrix(data_cpp_f, solver->cache->Quu_inv, nu * nu);
fprintf(data_cpp_f, ").finished(),\t// Quu_inv\n"); // Quu_inv
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->AmBKt, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// AmBKt\n"); // AmBKt
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C1, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// C1\n"); // C1
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C2, nx * nx);
fprintf(data_cpp_f, ").finished()"); // C2, no comma if no sensitivity matrices
// Only print sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
fprintf(data_cpp_f, ",\t// C2\n"); // Add comma and comment for C2 if we have more matrices
// Add sensitivity matrices within the struct initialization
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->dKinf_drho, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// dKinf_drho\n"); // dKinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dPinf_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dPinf_drho\n"); // dPinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC1_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dC1_drho\n"); // dC1_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC2_drho, nx * nx);
fprintf(data_cpp_f, ").finished()\t// dC2_drho\n"); // dC2_drho
} else {
fprintf(data_cpp_f, "\t// C2\n"); // Just add comment for C2
}
fprintf(data_cpp_f, "};\n\n");
// Settings
fprintf(data_cpp_f, "/* User settings */\n");
fprintf(data_cpp_f, "TinySettings settings = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// primal tolerance\n", solver->settings->abs_pri_tol);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// dual tolerance\n", solver->settings->abs_dua_tol);
fprintf(data_cpp_f, "\t%d,\t\t// max iterations\n", solver->settings->max_iter);
fprintf(
data_cpp_f,
"\t%d,\t\t// iterations per termination check\n",
solver->settings->check_termination
);
fprintf(data_cpp_f, "\t%d,\t\t// enable state constraints\n", solver->settings->en_state_bound);
fprintf(data_cpp_f, "\t%d\t\t// enable input constraints\n", solver->settings->en_input_bound);
fprintf(data_cpp_f, "};\n\n");
// Workspace
fprintf(data_cpp_f, "/* Problem variables */\n");
fprintf(data_cpp_f, "TinyWorkspace work = {\n");
fprintf(data_cpp_f, "\t%d,\t// Number of states\n", nx);
fprintf(data_cpp_f, "\t%d,\t// Number of control inputs\n", nu);
fprintf(data_cpp_f, "\t%d,\t// Number of knotpoints in the horizon\n", N);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u\n"); // u
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// q\n"); // q
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// r\n"); // r
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// p\n"); // p
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// d\n"); // d
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// v\n"); // v
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// vnew\n"); // vnew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// z\n"); // z
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// znew\n"); // znew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// g\n"); // g
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// y\n"); // y
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nx);
print_matrix(data_cpp_f, solver->work->Q, nx);
fprintf(data_cpp_f, ").finished(),\t// Q\n"); // Q
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, solver->work->R, nu);
fprintf(data_cpp_f, ").finished(),\t// R\n"); // R
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->work->Adyn, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Adyn\n"); // Adyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nu);
print_matrix(data_cpp_f, solver->work->Bdyn, nx * nu);
fprintf(data_cpp_f, ").finished(),\t// Bdyn\n"); // Bdyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_min, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_min\n"); // x_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_max, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_max\n"); // x_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_min, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_min\n"); // u_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_max, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_max\n"); // u_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// Xref\n"); // Xref
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// Uref\n"); // Uref
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, 1), nu);
fprintf(data_cpp_f, ").finished(),\t// Qu\n"); // Qu
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state dual residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input dual residual\n", 0.0);
fprintf(data_cpp_f, "\t%d,\t// solve status\n", 0);
fprintf(data_cpp_f, "\t%d,\t// solve iteration\n", 0);
fprintf(data_cpp_f, "};\n\n");
// Write solver struct definition to workspace file
fprintf(data_cpp_f, "TinySolver tiny_solver = {&solution, &settings, &cache, &work};\n\n");
// Close extern C
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "}\n");
fprintf(data_cpp_f, "#endif\n\n");
// Close codegen data file
fclose(data_cpp_f);
if (verbose) {
printf("Data generated in %s\n", data_cpp_fname);
}
return 0;
}
int codegen_example(const char* output_dir, int verbose) {
char example_cpp_fname[PATH_LENGTH];
FILE* example_cpp_f;
sprintf(example_cpp_fname, "%s/src/tiny_main.cpp", output_dir);
// Open example file
example_cpp_f = fopen(example_cpp_fname, "w+");
if (example_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", example_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(example_cpp_f, "/*\n");
fprintf(example_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(example_cpp_f, " */\n\n");
fprintf(example_cpp_f, "#include <iostream>\n\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_api.hpp>\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_data.hpp>\n\n");
fprintf(example_cpp_f, "using namespace Eigen;\n");
fprintf(example_cpp_f, "IOFormat TinyFmt(4, 0, \", \", \"\\n\", \"[\", \"]\");\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "extern \"C\" {\n");
fprintf(example_cpp_f, "#endif\n\n");
fprintf(example_cpp_f, "int main()\n");
fprintf(example_cpp_f, "{\n");
fprintf(example_cpp_f, "\tint exitflag = 1;\n");
fprintf(example_cpp_f, "\t// Double check some data\n");
fprintf(example_cpp_f, "\tstd::cout << \"rho: \" << tiny_solver.cache->rho << std::endl;\n");
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nmax iters: \" << tiny_solver.settings->max_iter << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nState transition matrix:\\n\" << tiny_solver.work->Adyn.format(TinyFmt) << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nInput/control matrix:\\n\" << tiny_solver.work->Bdyn.format(TinyFmt) << std::endl;\n\n"
);
fprintf(
example_cpp_f,
"\t// Visit https://tinympc.org/ to see how to set the initial condition and update the reference trajectory.\n\n"
);
fprintf(example_cpp_f, "\tstd::cout << \"\\nSolving...\\n\" << std::endl;\n\n");
fprintf(example_cpp_f, "\texitflag = tiny_solve(&tiny_solver);\n\n");
fprintf(example_cpp_f, "\tif (exitflag == 0) printf(\"Hooray! Solved with no error!\\n\");\n");
fprintf(example_cpp_f, "\telse printf(\"Oops! Something went wrong!\\n\");\n");
fprintf(example_cpp_f, "\treturn 0;\n");
fprintf(example_cpp_f, "}\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "} /* extern \"C\" */\n");
fprintf(example_cpp_f, "#endif\n");
// Close codegen example main file
fclose(example_cpp_f);
if (verbose) {
printf("Example tinympc main generated in %s\n", example_cpp_fname);
}
return 0;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,28 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose);
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
);
int codegen_create_directories(const char* output_dir, int verbose);
int codegen_data_header(const char* output_dir, int verbose);
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose);
int codegen_example(const char* output_dir, int verbose);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,29 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// #if defined(__linux__) || defined(__unix__)// Check if Linux
// #include <error.h>
// #define ERROR_MSG(exit_code, format, ...) error(exit_code, errno, format, ##__VA_ARGS__)
// #elif defined(__APPLE__) || defined(__MACH__) // Check if macOS
#define ERROR_MSG(exit_code, format, ...) \
{ \
fprintf(stderr, format ": %s\n", ##__VA_ARGS__, strerror(errno)); \
exit(exit_code); \
}
// #else
// #error "Unsupported operating system"
// #endif
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,252 @@
#include "rho_benchmark.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#ifdef ARDUINO
#include <Arduino.h>
#else
// For non-Arduino platforms
uint32_t micros() {
return 0; // Replace with appropriate timing function
}
#endif
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N) {
// Calculate dimensions
int x_decision_size = nx * N + nu * (N - 1);
int constraint_rows = (nx + nu) * (N - 1);
// Pre-allocate matrices
adapter->A_matrix = tinyMatrix::Zero(constraint_rows, x_decision_size);
adapter->z_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->y_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->x_decision = tinyMatrix::Zero(x_decision_size, 1);
// Pre-compute P matrix structure
adapter->P_matrix = tinyMatrix::Zero(x_decision_size, x_decision_size);
adapter->q_vector = tinyMatrix::Zero(x_decision_size, 1);
// Pre-allocate residual computation matrices
adapter->Ax_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_prim_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_dual_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->Px_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->ATy_vector = tinyMatrix::Zero(x_decision_size, 1);
// Store dimensions
adapter->format_nx = nx;
adapter->format_nu = nu;
adapter->format_N = N;
adapter->matrices_initialized = true;
}
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
) {
if (!adapter->matrices_initialized) {
initialize_format_matrices(adapter, x_prev.rows(), u_prev.rows(), N);
}
int nx = adapter->format_nx;
int nu = adapter->format_nu;
// Fill x_decision
int x_idx = 0;
for (int i = 0; i < N; i++) {
adapter->x_decision.block(x_idx, 0, nx, 1) = x_prev.col(i);
x_idx += nx;
if (i < N - 1) {
adapter->x_decision.block(x_idx, 0, nu, 1) = u_prev.col(i);
x_idx += nu;
}
}
// Clear A matrix for reuse
adapter->A_matrix.setZero();
// Fill A matrix with dynamics and input constraints
for (int i = 0; i < N - 1; i++) {
// Input constraints
int row_start = i * nu;
int col_start = i * (nx + nu) + nx;
adapter->A_matrix.block(row_start, col_start, nu, nu) = tinyMatrix::Identity(nu, nu);
// Dynamics constraints
row_start = (N - 1) * nu + i * nx;
col_start = i * (nx + nu);
adapter->A_matrix.block(row_start, col_start, nx, nx) = work->Adyn;
adapter->A_matrix.block(row_start, col_start + nx, nx, nu) = work->Bdyn;
int next_state_idx = col_start + nx + nu;
if (next_state_idx < adapter->A_matrix.cols()) {
adapter->A_matrix.block(row_start, next_state_idx, nx, nx) =
-tinyMatrix::Identity(nx, nx);
}
}
// Fill z and y vectors
for (int i = 0; i < N - 1; i++) {
adapter->z_vector.block(i * nu, 0, nu, 1) = z_prev.col(i);
adapter->z_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = v_prev.col(i + 1);
adapter->y_vector.block(i * nu, 0, nu, 1) = y_prev.col(i);
adapter->y_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = g_prev.col(i + 1);
}
// Build P matrix (cost matrix)
adapter->P_matrix.setZero();
// Fill diagonal blocks
x_idx = 0;
for (int i = 0; i < N; i++) {
// State cost
if (i == N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = cache->Pinf;
} else {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = work->Q.asDiagonal();
}
x_idx += nx;
// Input cost
if (i < N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nu, nu) = work->R.asDiagonal();
x_idx += nu;
}
}
// Create q vector (linear cost vector)
x_idx = 0;
for (int i = 0; i < N; i++) {
// For simplicity, we'll use zero reference for now
// In a real implementation, you'd use your reference trajectory
tinyMatrix x_ref = tinyMatrix::Zero(nx, 1);
tinyMatrix delta_x = x_prev.col(i) - x_ref;
adapter->q_vector.block(x_idx, 0, nx, 1) = work->Q.asDiagonal() * delta_x;
x_idx += nx;
if (i < N - 1) {
// For simplicity, we'll use zero reference for now
tinyMatrix u_ref = tinyMatrix::Zero(nu, 1);
tinyMatrix delta_u = u_prev.col(i) - u_ref;
adapter->q_vector.block(x_idx, 0, nu, 1) = work->R.asDiagonal() * delta_u;
x_idx += nu;
}
}
}
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
) {
// Compute Ax
adapter->Ax_vector = adapter->A_matrix * adapter->x_decision;
// Compute primal residual
adapter->r_prim_vector = adapter->Ax_vector - adapter->z_vector;
*pri_res = adapter->r_prim_vector.cwiseAbs().maxCoeff();
*pri_norm =
std::max(adapter->Ax_vector.cwiseAbs().maxCoeff(), adapter->z_vector.cwiseAbs().maxCoeff());
// Compute dual residual components
adapter->Px_vector = adapter->P_matrix * adapter->x_decision;
adapter->ATy_vector = adapter->A_matrix.transpose() * adapter->y_vector;
// Compute full dual residual
adapter->r_dual_vector = adapter->Px_vector + adapter->q_vector + adapter->ATy_vector;
*dual_res = adapter->r_dual_vector.cwiseAbs().maxCoeff();
// Compute normalization
*dual_norm = std::max(
std::max(
adapter->Px_vector.cwiseAbs().maxCoeff(),
adapter->ATy_vector.cwiseAbs().maxCoeff()
),
adapter->q_vector.cwiseAbs().maxCoeff()
);
}
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
) {
const tinytype eps = 1e-10;
tinytype normalized_pri = pri_res / (pri_norm + eps);
tinytype normalized_dual = dual_res / (dual_norm + eps);
tinytype ratio = normalized_pri / (normalized_dual + eps);
tinytype new_rho = current_rho * std::sqrt(ratio);
if (adapter->clip) {
new_rho = std::min(std::max(new_rho, adapter->rho_min), adapter->rho_max);
}
return new_rho;
}
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho) {
tinytype delta_rho = new_rho - cache->rho;
cache->Kinf = cache->Kinf + delta_rho * cache->dKinf_drho;
cache->Pinf = cache->Pinf + delta_rho * cache->dPinf_drho;
cache->C1 = cache->C1 + delta_rho * cache->dC1_drho;
cache->C2 = cache->C2 + delta_rho * cache->dC2_drho;
cache->rho = new_rho;
}
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
) {
uint32_t start_time = micros();
// Format matrices
format_matrices(adapter, x_prev, u_prev, v_prev, z_prev, g_prev, y_prev, cache, work, N);
// Compute residuals
tinytype pri_res, dual_res, pri_norm, dual_norm;
compute_residuals(adapter, &pri_res, &dual_res, &pri_norm, &dual_norm);
// Predict new rho
tinytype new_rho = predict_rho(adapter, pri_res, dual_res, pri_norm, dual_norm, cache->rho);
// Update matrices
update_matrices_with_derivatives(cache, new_rho);
// Store results
result->time_us = micros() - start_time;
result->initial_rho = cache->rho;
result->final_rho = new_rho;
result->pri_res = pri_res;
result->dual_res = dual_res;
result->pri_norm = pri_norm;
result->dual_norm = dual_norm;
}

View File

@@ -0,0 +1,94 @@
#pragma once
#include "types.hpp"
#include <cstdint>
struct RhoAdapter {
tinytype rho_min;
tinytype rho_max;
bool clip;
bool matrices_initialized;
// Pre-allocated matrices for formatting
tinyMatrix A_matrix;
tinyMatrix z_vector;
tinyMatrix y_vector;
tinyMatrix x_decision;
tinyMatrix P_matrix;
tinyMatrix q_vector;
// Pre-allocated matrices for residual computation
tinyMatrix Ax_vector;
tinyMatrix r_prim_vector;
tinyMatrix r_dual_vector;
tinyMatrix Px_vector;
tinyMatrix ATy_vector;
// Dimensions
int format_nx;
int format_nu;
int format_N;
};
struct RhoBenchmarkResult {
uint32_t time_us;
tinytype initial_rho;
tinytype final_rho;
tinytype pri_res;
tinytype dual_res;
tinytype pri_norm;
tinytype dual_norm;
};
// Initialize matrices for formatting
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N);
// Format matrices for residual computation
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
);
// Compute residuals
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
);
// Predict new rho value
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
);
// Update matrices using derivatives
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho);
// Main benchmark function
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
);

View File

@@ -0,0 +1,876 @@
#include "tiny_api.hpp"
#include "tiny_api_constants.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
using namespace Eigen;
IOFormat TinyApiFmt(4, 0, ", ", "\n", "[", "]");
static int
check_dimension(std::string matrix_name, std::string rows_or_columns, int actual, int expected) {
if (actual != expected) {
std::cout << matrix_name << " has " << actual << " " << rows_or_columns << ". Expected "
<< expected << "." << std::endl;
return 1;
}
return 0;
}
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
) {
TinySolution* solution = new TinySolution();
TinyCache* cache = new TinyCache();
TinySettings* settings = new TinySettings();
TinyWorkspace* work = new TinyWorkspace();
TinySolver* solver = new TinySolver();
solver->solution = solution;
solver->cache = cache;
solver->settings = settings;
solver->work = work;
*solverp = solver;
// Initialize solution
solution->iter = 0;
solution->solved = 0;
solution->x = tinyMatrix::Zero(nx, N);
solution->u = tinyMatrix::Zero(nu, N - 1);
// Initialize settings
tiny_set_default_settings(settings);
// Initialize workspace
work->nx = nx;
work->nu = nu;
work->N = N;
// Make sure arguments are the correct shapes
int status = 0;
status |= check_dimension("State transition matrix (A)", "rows", Adyn.rows(), nx);
status |= check_dimension("State transition matrix (A)", "columns", Adyn.cols(), nx);
status |= check_dimension("Input matrix (B)", "rows", Bdyn.rows(), nx);
status |= check_dimension("Input matrix (B)", "columns", Bdyn.cols(), nu);
status |= check_dimension("Affine vector (f)", "rows", fdyn.rows(), nx);
status |= check_dimension("Affine vector (f)", "columns", fdyn.cols(), 1);
status |= check_dimension("State stage cost (Q)", "rows", Q.rows(), nx);
status |= check_dimension("State stage cost (Q)", "columns", Q.cols(), nx);
status |= check_dimension("State input cost (R)", "rows", R.rows(), nu);
status |= check_dimension("State input cost (R)", "columns", R.cols(), nu);
if (status) {
return status;
}
work->x = tinyMatrix::Zero(nx, N);
work->u = tinyMatrix::Zero(nu, N - 1);
work->q = tinyMatrix::Zero(nx, N);
work->r = tinyMatrix::Zero(nu, N - 1);
work->p = tinyMatrix::Zero(nx, N);
work->d = tinyMatrix::Zero(nu, N - 1);
// Bound constraint slack variables
work->v = tinyMatrix::Zero(nx, N);
work->vnew = tinyMatrix::Zero(nx, N);
work->z = tinyMatrix::Zero(nu, N - 1);
work->znew = tinyMatrix::Zero(nu, N - 1);
// Bound constraint dual variables
work->g = tinyMatrix::Zero(nx, N);
work->y = tinyMatrix::Zero(nu, N - 1);
// Cone constraint slack variables
work->vc = tinyMatrix::Zero(nx, N);
work->vcnew = tinyMatrix::Zero(nx, N);
work->zc = tinyMatrix::Zero(nu, N - 1);
work->zcnew = tinyMatrix::Zero(nu, N - 1);
// Cone constraint dual variables
work->gc = tinyMatrix::Zero(nx, N);
work->yc = tinyMatrix::Zero(nu, N - 1);
// Linear constraint slack variables
work->vl = tinyMatrix::Zero(nx, N);
work->vlnew = tinyMatrix::Zero(nx, N);
work->zl = tinyMatrix::Zero(nu, N - 1);
work->zlnew = tinyMatrix::Zero(nu, N - 1);
// Linear constraint dual variables
work->gl = tinyMatrix::Zero(nx, N);
work->yl = tinyMatrix::Zero(nu, N - 1);
work->Q = (Q + rho * tinyMatrix::Identity(nx, nx)).diagonal();
work->R = (R + rho * tinyMatrix::Identity(nu, nu)).diagonal();
work->Adyn = Adyn; // State transition matrix
work->Bdyn = Bdyn; // Input matrix
work->fdyn = fdyn; // Affine offset vector
work->Xref = tinyMatrix::Zero(nx, N);
work->Uref = tinyMatrix::Zero(nu, N - 1);
work->Qu = tinyVector::Zero(nu);
work->primal_residual_state = 0;
work->primal_residual_input = 0;
work->dual_residual_state = 0;
work->dual_residual_input = 0;
work->status = 0;
work->iter = 0;
// Initialize cache
status = tiny_precompute_and_set_cache(
cache,
Adyn,
Bdyn,
fdyn,
work->Q.asDiagonal(),
work->R.asDiagonal(),
nx,
nu,
rho,
verbose
);
if (status) {
return status;
}
// Initialize sensitivity matrices for adaptive rho
if (solver->settings->adaptive_rho) {
tiny_initialize_sensitivity_matrices(solver);
}
return 0;
}
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
) {
if (!solver) {
std::cout << "Error in tiny_set_bound_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all bound constraint matrix sizes are self-consistent
int status = 0;
status |= check_dimension("Lower state bounds (x_min)", "rows", x_min.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_min)", "cols", x_min.cols(), solver->work->N);
status |= check_dimension("Lower state bounds (x_max)", "rows", x_max.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_max)", "cols", x_max.cols(), solver->work->N);
status |= check_dimension("Lower input bounds (u_min)", "rows", u_min.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_min)", "cols", u_min.cols(), solver->work->N - 1);
status |= check_dimension("Lower input bounds (u_max)", "rows", u_max.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_max)", "cols", u_max.cols(), solver->work->N - 1);
solver->work->x_min = x_min;
solver->work->x_max = x_max;
solver->work->u_min = u_min;
solver->work->u_max = u_max;
return 0;
}
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acx,
VectorXi qcx,
tinyVector cx,
VectorXi Acu,
VectorXi qcu,
tinyVector cu
) {
if (!solver) {
std::cout << "Error in tiny_set_cone_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all cone constraint vector sizes are self-consistent
int num_state_cones = Acx.rows();
int num_input_cones = Acu.rows();
int status = 0;
status |= check_dimension("Cone state size (qcx)", "rows", qcx.rows(), num_state_cones);
status |= check_dimension("Cone mu value for state (cx)", "rows", cx.rows(), num_state_cones);
status |= check_dimension("Cone input size (qcu)", "rows", qcu.rows(), num_input_cones);
status |= check_dimension("Cone mu value for input (cu)", "rows", cu.rows(), num_input_cones);
if (status) {
return status;
}
solver->work->numStateCones = num_state_cones;
solver->work->numInputCones = num_input_cones;
solver->work->Acx = Acx;
solver->work->qcx = qcx;
solver->work->cx = cx;
solver->work->Acu = Acu;
solver->work->qcu = qcu;
solver->work->cu = cu;
return 0;
}
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
) {
if (!solver) {
std::cout << "Error in tiny_set_linear_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all linear constraint matrix sizes are self-consistent
int num_state_linear = Alin_x.rows();
int num_input_linear = Alin_u.rows();
int status = 0;
// Check state constraint dimensions
if (num_state_linear > 0) {
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"rows",
Alin_x.rows(),
num_state_linear
);
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"columns",
Alin_x.cols(),
solver->work->nx
);
status |= check_dimension(
"State linear constraint vector (blin_x)",
"rows",
blin_x.rows(),
num_state_linear
);
status |=
check_dimension("State linear constraint vector (blin_x)", "columns", blin_x.cols(), 1);
}
// Check input constraint dimensions
if (num_input_linear > 0) {
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"rows",
Alin_u.rows(),
num_input_linear
);
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"columns",
Alin_u.cols(),
solver->work->nu
);
status |= check_dimension(
"Input linear constraint vector (blin_u)",
"rows",
blin_u.rows(),
num_input_linear
);
status |=
check_dimension("Input linear constraint vector (blin_u)", "columns", blin_u.cols(), 1);
}
if (status) {
return status;
}
solver->work->numStateLinear = num_state_linear;
solver->work->numInputLinear = num_input_linear;
solver->work->Alin_x = Alin_x;
solver->work->blin_x = blin_x;
solver->work->Alin_u = Alin_u;
solver->work->blin_u = blin_u;
return 0;
}
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
) {
if (!cache) {
std::cout << "Error in tiny_precompute_and_set_cache: cache is nullptr" << std::endl;
return 1;
}
// Update by adding rho * identity matrix to Q, R
tinyMatrix Q1 = Q + rho * tinyMatrix::Identity(nx, nx);
tinyMatrix R1 = R + rho * tinyMatrix::Identity(nu, nu);
// Printing
if (verbose) {
std::cout << "A = " << Adyn.format(TinyApiFmt) << std::endl;
std::cout << "B = " << Bdyn.format(TinyApiFmt) << std::endl;
std::cout << "Q = " << Q1.format(TinyApiFmt) << std::endl;
std::cout << "R = " << R1.format(TinyApiFmt) << std::endl;
std::cout << "rho = " << rho << std::endl;
}
// Riccati recursion to get Kinf, Pinf
tinyMatrix Ktp1 = tinyMatrix::Zero(nu, nx);
tinyMatrix Ptp1 = rho * tinyMatrix::Ones(nx, 1).array().matrix().asDiagonal();
tinyMatrix Kinf = tinyMatrix::Zero(nu, nx);
tinyMatrix Pinf = tinyMatrix::Zero(nx, nx);
for (int i = 0; i < 1000; i++) {
Kinf = (R1 + Bdyn.transpose() * Ptp1 * Bdyn).inverse() * Bdyn.transpose() * Ptp1 * Adyn;
Pinf = Q1 + Adyn.transpose() * Ptp1 * (Adyn - Bdyn * Kinf);
// if Kinf converges, break
if ((Kinf - Ktp1).cwiseAbs().maxCoeff() < 1e-5) {
if (verbose) {
std::cout << "Kinf converged after " << i + 1 << " iterations" << std::endl;
}
break;
}
Ktp1 = Kinf;
Ptp1 = Pinf;
}
// Compute cached matrices
tinyMatrix Quu_inv = (R1 + Bdyn.transpose() * Pinf * Bdyn).inverse();
tinyMatrix AmBKt = (Adyn - Bdyn * Kinf).transpose();
// Precomputation for affine term
tinyVector APf = AmBKt * Pinf * fdyn;
tinyVector BPf = Bdyn.transpose() * Pinf * fdyn;
if (verbose) {
std::cout << "Kinf = " << Kinf.format(TinyApiFmt) << std::endl;
std::cout << "Pinf = " << Pinf.format(TinyApiFmt) << std::endl;
std::cout << "Quu_inv = " << Quu_inv.format(TinyApiFmt) << std::endl;
std::cout << "AmBKt = " << AmBKt.format(TinyApiFmt) << std::endl;
std::cout << "APf = " << APf.format(TinyApiFmt) << std::endl;
std::cout << "BPf = " << BPf.format(TinyApiFmt) << std::endl;
std::cout << "\nPrecomputation finished!\n" << std::endl;
}
cache->rho = rho;
cache->Kinf = Kinf;
cache->Pinf = Pinf;
cache->Quu_inv = Quu_inv;
cache->AmBKt = AmBKt;
cache->C1 = Quu_inv;
cache->C2 = AmBKt;
cache->APf = APf;
cache->BPf = BPf;
return 0; // return success
}
int tiny_solve(TinySolver* solver) {
return solve(solver);
}
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
) {
if (!settings) {
std::cout << "Error in tiny_update_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = abs_pri_tol;
settings->abs_dua_tol = abs_dua_tol;
settings->max_iter = max_iter;
settings->check_termination = check_termination;
settings->en_state_bound = en_state_bound;
settings->en_input_bound = en_input_bound;
settings->en_state_soc = en_state_soc;
settings->en_input_soc = en_input_soc;
settings->en_state_linear = en_state_linear;
settings->en_input_linear = en_input_linear;
return 0;
}
int tiny_set_default_settings(TinySettings* settings) {
if (!settings) {
std::cout << "Error in tiny_set_default_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = TINY_DEFAULT_ABS_PRI_TOL;
settings->abs_dua_tol = TINY_DEFAULT_ABS_DUA_TOL;
settings->max_iter = TINY_DEFAULT_MAX_ITER;
settings->check_termination = TINY_DEFAULT_CHECK_TERMINATION;
// Turn off constraints until they are set by tiny_set_bound_constraints or tiny_set_cone_constraints
settings->en_state_bound = TINY_DEFAULT_EN_STATE_BOUND;
settings->en_input_bound = TINY_DEFAULT_EN_INPUT_BOUND;
settings->en_state_soc = TINY_DEFAULT_EN_STATE_SOC;
settings->en_input_soc = TINY_DEFAULT_EN_INPUT_SOC;
settings->en_state_linear = TINY_DEFAULT_EN_STATE_LINEAR;
settings->en_input_linear = TINY_DEFAULT_EN_INPUT_LINEAR;
// Initialize adaptive rho settings
// NOTE : Adaptive rho currently supports only quadrotor system
settings->adaptive_rho = 0; // Disabled by default
settings->adaptive_rho_min = 1.0;
settings->adaptive_rho_max = 100.0;
settings->adaptive_rho_enable_clipping = 1;
return 0;
}
int tiny_set_x0(TinySolver* solver, tinyVector x0) {
if (!solver) {
std::cout << "Error in tiny_set_x0: solver is nullptr" << std::endl;
return 1;
}
if (x0.rows() != solver->work->nx) {
perror("Error in tiny_set_x0: x0 is not the correct length");
}
solver->work->x.col(0) = x0;
return 0;
}
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref) {
if (!solver) {
std::cout << "Error in tiny_set_x_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"State reference trajectory (x_ref)",
"rows",
x_ref.rows(),
solver->work->nx
);
status |= check_dimension(
"State reference trajectory (x_ref)",
"columns",
x_ref.cols(),
solver->work->N
);
solver->work->Xref = x_ref;
return 0;
}
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref) {
if (!solver) {
std::cout << "Error in tiny_set_u_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"rows",
u_ref.rows(),
solver->work->nu
);
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"columns",
u_ref.cols(),
solver->work->N - 1
);
solver->work->Uref = u_ref;
return 0;
}
void tiny_initialize_sensitivity_matrices(TinySolver* solver) {
int nu = solver->work->nu;
int nx = solver->work->nx;
// Initialize matrices with zeros
solver->cache->dKinf_drho = tinyMatrix::Zero(nu, nx);
solver->cache->dPinf_drho = tinyMatrix::Zero(nx, nx);
solver->cache->dC1_drho = tinyMatrix::Zero(nu, nu);
solver->cache->dC2_drho = tinyMatrix::Zero(nx, nx);
const float dKinf_drho[4][12] = { { 0.0001,
-0.0001,
-0.0025,
0.0003,
0.0007,
0.0050,
0.0001,
-0.0001,
-0.0008,
0.0000,
0.0001,
0.0008 },
{ -0.0001,
-0.0000,
-0.0025,
-0.0001,
-0.0006,
-0.0050,
-0.0001,
0.0000,
-0.0008,
-0.0000,
-0.0001,
-0.0008 },
{ 0.0000,
0.0000,
-0.0025,
0.0001,
0.0004,
0.0050,
0.0000,
0.0000,
-0.0008,
0.0000,
0.0000,
0.0008 },
{ -0.0000,
0.0001,
-0.0025,
-0.0003,
-0.0004,
-0.0050,
-0.0000,
0.0001,
-0.0008,
-0.0000,
-0.0000,
-0.0008 } };
const float dPinf_drho[12][12] = { { 0.0494,
-0.0045,
-0.0000,
0.0110,
0.1300,
-0.0283,
0.0280,
-0.0026,
-0.0000,
0.0004,
0.0070,
-0.0094 },
{ -0.0045,
0.0491,
0.0000,
-0.1320,
-0.0111,
0.0114,
-0.0026,
0.0279,
0.0000,
-0.0076,
-0.0004,
0.0038 },
{ -0.0000,
0.0000,
2.4450,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
1.2593,
0.0000,
0.0000,
0.0000 },
{ 0.0110,
-0.1320,
0.0000,
0.3913,
0.0592,
0.3108,
0.0080,
-0.0776,
0.0000,
0.0254,
0.0068,
0.0750 },
{ 0.1300,
-0.0111,
-0.0000,
0.0592,
0.4420,
0.7771,
0.0797,
-0.0081,
-0.0000,
0.0068,
0.0350,
0.1875 },
{ -0.0283,
0.0114,
-0.0000,
0.3108,
0.7771,
10.0441,
0.0272,
-0.0109,
0.0000,
0.0655,
0.1639,
2.6362 },
{ 0.0280,
-0.0026,
-0.0000,
0.0080,
0.0797,
0.0272,
0.0163,
-0.0016,
-0.0000,
0.0005,
0.0047,
0.0032 },
{ -0.0026,
0.0279,
0.0000,
-0.0776,
-0.0081,
-0.0109,
-0.0016,
0.0161,
0.0000,
-0.0046,
-0.0005,
-0.0013 },
{ -0.0000,
0.0000,
1.2593,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.9232,
0.0000,
0.0000,
0.0000 },
{ 0.0004,
-0.0076,
0.0000,
0.0254,
0.0068,
0.0655,
0.0005,
-0.0046,
0.0000,
0.0022,
0.0017,
0.0244 },
{ 0.0070,
-0.0004,
0.0000,
0.0068,
0.0350,
0.1639,
0.0047,
-0.0005,
0.0000,
0.0017,
0.0054,
0.0610 },
{ -0.0094,
0.0038,
0.0000,
0.0750,
0.1875,
2.6362,
0.0032,
-0.0013,
0.0000,
0.0244,
0.0610,
0.9869 } };
const float dC1_drho[4][4] = { { -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 },
{ -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 } };
const float dC2_drho[12][12] = { { 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0001,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0001,
0.0000,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000 },
{ 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0021,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0006,
0.0000,
-0.0000,
-0.0000 },
{ 0.0002,
-0.0027,
-0.0000,
0.0068,
0.0005,
-0.0005,
0.0001,
-0.0015,
-0.0000,
0.0004,
0.0000,
-0.0001 },
{ 0.0027,
-0.0002,
0.0000,
0.0005,
0.0066,
-0.0011,
0.0015,
-0.0001,
0.0000,
0.0000,
0.0004,
-0.0002 },
{ -0.0001,
0.0001,
0.0000,
-0.0000,
0.0000,
0.0041,
-0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0006 } };
// Map arrays to Eigen matrices
solver->cache->dKinf_drho = Map<const Matrix<float, 4, 12>>(dKinf_drho[0]).cast<tinytype>();
solver->cache->dPinf_drho = Map<const Matrix<float, 12, 12>>(dPinf_drho[0]).cast<tinytype>();
solver->cache->dC1_drho = Map<const Matrix<float, 4, 4>>(dC1_drho[0]).cast<tinytype>();
solver->cache->dC2_drho = Map<const Matrix<float, 12, 12>>(dC2_drho[0]).cast<tinytype>();
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,118 @@
#pragma once
#include "admm.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
);
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
);
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acu,
VectorXi qcu,
tinyVector cu,
VectorXi Acx,
VectorXi qcx,
tinyVector cx
);
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
);
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
void compute_sensitivity_matrices(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
int tiny_update_matrices_with_derivatives(TinyCache* cache, tinytype delta_rho);
int tiny_solve(TinySolver* solver);
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
);
int tiny_set_default_settings(TinySettings* settings);
int tiny_set_x0(TinySolver* solver, tinyVector x0);
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref);
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref);
/**
* Initialize sensitivity matrices for adaptive rho
*
* @param solver Pointer to solver
*/
void tiny_initialize_sensitivity_matrices(TinySolver* solver);
int tiny_setup_state_soc_constraints(
TinySolver* solver,
tinyVector Acx,
tinyVector qcx,
tinyVector cx,
int numStateCones
);
int tiny_setup_input_soc_constraints(
TinySolver* solver,
tinyVector Acu,
tinyVector qcu,
tinyVector cu,
int numInputCones
);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,13 @@
#pragma once
// Default settings
#define TINY_DEFAULT_ABS_PRI_TOL (1e-03)
#define TINY_DEFAULT_ABS_DUA_TOL (1e-03)
#define TINY_DEFAULT_MAX_ITER (1000)
#define TINY_DEFAULT_CHECK_TERMINATION (1)
#define TINY_DEFAULT_EN_STATE_BOUND (1)
#define TINY_DEFAULT_EN_INPUT_BOUND (1)
#define TINY_DEFAULT_EN_STATE_SOC (0)
#define TINY_DEFAULT_EN_INPUT_SOC (0)
#define TINY_DEFAULT_EN_STATE_LINEAR (0)
#define TINY_DEFAULT_EN_INPUT_LINEAR (0)

View File

@@ -0,0 +1,197 @@
#pragma once
#include <Eigen/Eigen>
// #include <Eigen/Core>
// #include <Eigen/LU>
using namespace Eigen;
#ifdef __cplusplus
extern "C" {
#endif
typedef double tinytype; // should be double if you want to generate code
typedef Matrix<tinytype, Dynamic, Dynamic> tinyMatrix;
typedef Matrix<tinytype, Dynamic, 1> tinyVector;
// typedef Matrix<tinytype, NSTATES, 1> tiny_VectorNx;
// typedef Matrix<tinytype, NINPUTS, 1> tiny_VectorNu;
// typedef Matrix<tinytype, NSTATES, NSTATES> tiny_MatrixNxNx;
// typedef Matrix<tinytype, NSTATES, NINPUTS> tiny_MatrixNxNu;
// typedef Matrix<tinytype, NINPUTS, NSTATES> tiny_MatrixNuNx;
// typedef Matrix<tinytype, NINPUTS, NINPUTS> tiny_MatrixNuNu;
// typedef Matrix<tinytype, NSTATES, NHORIZON> tiny_MatrixNxNh; // Nu x Nh
// typedef Matrix<tinytype, NINPUTS, NHORIZON - 1> tiny_MatrixNuNhm1; // Nu x Nh-1
/**
* Solution
*/
typedef struct {
int iter;
int solved;
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
} TinySolution;
/**
* Matrices that must be recomputed with changes in time step, rho
*/
typedef struct {
tinytype rho;
tinyMatrix Kinf; // nu x nx
tinyMatrix Pinf; // nx x nx
tinyMatrix Quu_inv; // nu x nu
tinyMatrix AmBKt; // nx x nx
tinyVector APf; // nx x 1
tinyVector BPf; // nu x 1
tinyMatrix C1; // From adaptive rho
tinyMatrix C2; // From adaptive rho
// Sensitivity matrices for adaptive rho
tinyMatrix dKinf_drho;
tinyMatrix dPinf_drho;
tinyMatrix dC1_drho;
tinyMatrix dC2_drho;
} TinyCache;
/**
* User settings
*/
typedef struct {
tinytype abs_pri_tol;
tinytype abs_dua_tol;
int max_iter;
int check_termination;
int en_state_bound;
int en_input_bound;
int en_state_soc;
int en_input_soc;
int en_state_linear;
int en_input_linear;
// Add adaptive rho parameters
int adaptive_rho; // Enable/disable adaptive rho (1/0)
tinytype adaptive_rho_min; // Minimum value for rho
tinytype adaptive_rho_max; // Maximum value for rho
int adaptive_rho_enable_clipping; // Enable/disable clipping of rho (1/0)
} TinySettings;
/**
* Problem variables
*/
typedef struct {
int nx; // Number of states
int nu; // Number of control inputs
int N; // Number of knotpoints in the horizon
// State and input
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
// Linear control cost terms
tinyMatrix q; // nx x N
tinyMatrix r; // nu x N-1
// Linear Riccati backward pass terms
tinyMatrix p; // nx x N
tinyMatrix d; // nu x N-1
// Bound constraint variables
// Slack variables
tinyMatrix v; // nx x N
tinyMatrix vnew; // nx x N
tinyMatrix z; // nu x N-1
tinyMatrix znew; // nu x N-1
// Dual variables
tinyMatrix g; // nx x N
tinyMatrix y; // nu x N-1
// State and input bounds
tinyMatrix x_min; // nx x N
tinyMatrix x_max; // nx x N
tinyMatrix u_min; // nu x N-1
tinyMatrix u_max; // nu x N-1
// Cone constraint variables
// Variables to keep track of general cone information
int numStateCones; // Number of cone constraints on states at each time step
int numInputCones; // Number of cone constraints on inputs at each time step
tinyVector cx; // One coefficient for each state cone
tinyVector cu; // One coefficient for each input cone
VectorXi Acx; // Start indices for each state cone
VectorXi Acu; // Start indices for each input cone
VectorXi qcx; // Dimension for each state cone
VectorXi qcu; // Dimension for each input cone
// Slack variables
tinyMatrix vc; // nx x N
tinyMatrix vcnew; // nx x N
tinyMatrix zc; // nu x N-1
tinyMatrix zcnew; // nu x N-1
// Dual variables
tinyMatrix gc; // nx x N
tinyMatrix yc; // nu x N-1
// Linear constraint variables
// Variables to keep track of general linear constraint information
int numStateLinear; // Number of linear constraints on states at each time step
int numInputLinear; // Number of linear constraints on inputs at each time step
// Constraint matrices and vectors
tinyMatrix Alin_x; // Normal vectors for state linear constraints (numStateLinear x nx)
tinyVector blin_x; // Offset values for state linear constraints (numStateLinear x 1)
tinyMatrix Alin_u; // Normal vectors for input linear constraints (numInputLinear x nu)
tinyVector blin_u; // Offset values for input linear constraints (numInputLinear x 1)
// Slack variables for linear constraints
tinyMatrix vl; // nx x N
tinyMatrix vlnew; // nx x N
tinyMatrix zl; // nu x N-1
tinyMatrix zlnew; // nu x N-1
// Dual variables for linear constraints
tinyMatrix gl; // nx x N
tinyMatrix yl; // nu x N-1
// Q, R, A, B, f given by user
tinyVector Q; // nx x 1
tinyVector R; // nu x 1
tinyMatrix Adyn; // nx x nx (state transition matrix)
tinyMatrix Bdyn; // nx x nu (control matrix)
tinyVector fdyn; // nx x 1 (affine vector)
// Reference trajectory to track for one horizon
tinyMatrix Xref; // nx x N
tinyMatrix Uref; // nu x N-1
// Temporaries
tinyVector Qu; // nu x 1
// Variables for keeping track of solve status
tinytype primal_residual_state;
tinytype primal_residual_input;
tinytype dual_residual_state;
tinytype dual_residual_input;
int status;
int iter;
} TinyWorkspace;
/**
* Main TinyMPC solver structure that holds all information.
*/
typedef struct {
TinySolution* solution; // Solution
TinySettings* settings; // Problem settings
TinyCache* cache; // Problem cache
TinyWorkspace* work; // Solver workspace
} TinySolver;
// Add at the top with other definitions
#define BENCH_NX 12
#define BENCH_NU 4
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,111 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <vector>
namespace wust_vision {
namespace auto_aim {
template<typename T>
concept HasStaticLerp = requires(const T& a, const T& b, double t) {
{
T::lerp(a, b, t)
} -> std::same_as<T>;
};
template<HasStaticLerp PointT>
class Trajectory {
public:
void reserve(size_t n) {
cp_vec.reserve(n);
dt_vec.reserve(n > 0 ? n - 1 : 0);
prefix_time.reserve(n);
}
void clear() {
cp_vec.clear();
dt_vec.clear();
prefix_time.clear();
total_duration_ = 0.0;
}
void push_back(const PointT& p, double dt = 0.0) {
if (cp_vec.empty()) {
cp_vec.push_back(p);
prefix_time.push_back(0.0);
total_duration_ = 0.0;
return;
}
assert(dt >= 0.0);
cp_vec.push_back(p);
dt_vec.push_back(dt);
total_duration_ += dt;
prefix_time.push_back(total_duration_);
}
void set(const std::vector<PointT>& c, const std::vector<double>& t) {
assert(!c.empty());
assert(c.size() == t.size() + 1);
cp_vec = c;
dt_vec = t;
prefix_time.resize(cp_vec.size());
prefix_time[0] = 0.0;
for (size_t i = 0; i < dt_vec.size(); ++i)
prefix_time[i + 1] = prefix_time[i] + dt_vec[i];
total_duration_ = prefix_time.back();
}
double getPrefixTimeAtIdx(int i) const {
return prefix_time[i];
}
PointT getStateAtIdx(int i) const {
return cp_vec[i];
}
PointT getStateAtTime(double t) const {
if (cp_vec.empty())
return PointT {};
if (t <= 0.0)
return cp_vec.front();
if (t >= total_duration_)
return cp_vec.back();
auto it = std::lower_bound(prefix_time.begin(), prefix_time.end(), t);
size_t i1 = std::distance(prefix_time.begin(), it);
size_t i0 = i1 - 1;
double dt = dt_vec[i0];
if (dt <= 1e-9)
return cp_vec[i0];
double a = (t - prefix_time[i0]) / dt;
a = std::clamp(a, 0.0, 1.0);
return PointT::lerp(cp_vec[i0], cp_vec[i1], a);
}
double getTotalDuration() const {
return total_duration_;
}
size_t size() const {
return cp_vec.size();
}
std::vector<PointT> cp_vec;
std::vector<double> dt_vec;
std::vector<double> prefix_time;
double total_duration_ { 0.0 };
};
} // namespace auto_aim
} // namespace wust_vision

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
#pragma once
#include <memory>
namespace wust_vl::common::utils {
class Parameter;
}
using wust_vlParamterPtr = std::shared_ptr<wust_vl::common::utils::Parameter>;
namespace wust_vision {
struct GimbalCmd;
}
namespace wust_vision::auto_aim {
enum class AutoAimFsm;
class Target;
class VeryAimer {
public:
using Ptr = std::shared_ptr<VeryAimer>;
VeryAimer(wust_vlParamterPtr auto_aim_config_parameter);
static Ptr create(wust_vlParamterPtr auto_aim_config_parameter) {
return std::make_shared<VeryAimer>(auto_aim_config_parameter);
};
~VeryAimer();
[[nodiscard]] GimbalCmd
veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm);
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,70 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_aim {
struct LightParams {
// width / height
double min_ratio;
double max_ratio;
// vertical angle
double max_angle;
// judge color
int color_diff_thresh;
double max_angle_diff;
int binary_thres;
void load(const YAML::Node& config) {
binary_thres = config["binary_thres"].as<int>();
min_ratio = config["min_ratio"].as<double>();
max_ratio = config["max_ratio"].as<double>();
max_angle = config["max_angle"].as<double>();
max_angle_diff = config["max_angle_diff"].as<double>();
color_diff_thresh = config["color_diff_thresh"].as<int>();
}
};
struct ArmorParams {
double min_light_ratio;
// light pairs distance
double min_small_center_distance;
double max_small_center_distance;
double min_large_center_distance;
double max_large_center_distance;
// horizontal angle
double max_angle;
void load(const YAML::Node& config) {
min_light_ratio = config["min_light_ratio"].as<double>();
min_small_center_distance = config["min_small_center_distance"].as<double>();
max_small_center_distance = config["max_small_center_distance"].as<double>();
min_large_center_distance = config["min_large_center_distance"].as<double>();
max_large_center_distance = config["max_large_center_distance"].as<double>();
max_angle = config["max_angle"].as<double>();
}
};
class ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorBase>;
virtual ~ArmorDetectorBase() = default;
virtual void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) = 0;
using DetectorCallback =
std::function<void(const std::vector<ArmorObject>&, const CommonFrame&)>;
virtual void setCallback(DetectorCallback cb) = 0;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,473 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
#include "tasks/auto_aim/armor_detect/number_classifier/factory.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/utils/timer.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorCommon::Impl {
public:
Impl(const YAML::Node& config) {
params_.load(config);
number_classifier_ = NumberClassifierFactory::createNumberClassifier(
params_.classify_backend,
params_.classify_model_path,
params_.classify_label_path
);
}
bool extractNetImage(const cv::Mat& src, ArmorObject& armor) const noexcept {
constexpr int light_length = 12;
constexpr int warp_height = 28;
constexpr int small_armor_width = 32;
constexpr int large_armor_width = 54;
const cv::Size roi_size(20, 28);
if (src.empty() || src.cols < 10 || src.rows < 10) {
std::cerr << "[extractNetImage] input src is empty or too small!" << std::endl;
return false;
}
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f& p0 = ordered[0];
const cv::Point2f& p1 = ordered[1];
const cv::Point2f& p2 = ordered[2];
const cv::Point2f& p3 = ordered[3];
const float l1_len = cv::norm(p1 - p0);
const float l2_len = cv::norm(p2 - p3);
const cv::Point2f c1 = (p0 + p1) * 0.5f;
const cv::Point2f c2 = (p2 + p3) * 0.5f;
const float avg_light_len = 0.5f * (l1_len + l2_len);
const float center_dist =
avg_light_len > 1e-3f ? cv::norm(c1 - c2) / avg_light_len : 0.f;
const bool is_large = center_dist > params_.armor_params.min_large_center_distance;
const cv::Rect bbox = cv::boundingRect(armor.pts);
if (bbox.width <= 0 || bbox.height <= 0)
return false;
if (bbox.width > src.cols || bbox.height > src.rows)
return false;
const int dw = static_cast<int>(bbox.width * (params_.expand_ratio_w - 1.f));
const int dh = static_cast<int>(bbox.height * (params_.expand_ratio_h - 1.f));
int new_x = bbox.x - (dw >> 1);
int new_y = bbox.y - (dh >> 1);
new_x = std::max(new_x, 0);
new_y = std::max(new_y, 0);
int new_w = std::min(bbox.width + dw, src.cols - new_x);
int new_h = std::min(bbox.height + dh, src.rows - new_y);
if (new_w <= 0 || new_h <= 0)
return false;
const cv::Rect expanded_rect(new_x, new_y, new_w, new_h);
cv::Mat litroi_color = src(expanded_rect);
if (litroi_color.empty())
return false;
cv::Mat litroi_gray;
try {
cv::cvtColor(litroi_color, litroi_gray, cv::COLOR_BGR2GRAY);
} catch (...) {
return false;
}
armor.whole_gray_img = litroi_gray;
if (params_.enable_cv) {
cv::Mat litroi_binary;
try {
cv::threshold(
litroi_gray,
litroi_binary,
params_.light_params.binary_thres,
255,
cv::THRESH_BINARY
);
armor.whole_binary_img = litroi_binary;
} catch (...) {
return false;
}
}
const cv::Point2f offset(static_cast<float>(new_x), static_cast<float>(new_y));
if (params_.enable_classify) {
cv::Point2f src_vertices[4] = { armor.pts[1] - offset,
armor.pts[0] - offset,
armor.pts[3] - offset,
armor.pts[2] - offset };
const int warp_width = is_large ? large_armor_width : small_armor_width;
const int top_light_y = (warp_height - light_length) / 2 - 1;
const int bottom_light_y = top_light_y + light_length;
if (warp_width <= 0 || warp_height <= 0)
return false;
cv::Point2f dst_vertices[4] = {
{ 0.f, static_cast<float>(bottom_light_y) },
{ 0.f, static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_light_y) }
};
const cv::Mat warp_mat = cv::getPerspectiveTransform(src_vertices, dst_vertices);
cv::Mat number_image;
cv::warpPerspective(
litroi_gray,
number_image,
warp_mat,
cv::Size(warp_width, warp_height),
cv::INTER_LINEAR,
cv::BORDER_CONSTANT,
0
);
const int roi_x = (warp_width - roi_size.width) >> 1;
const cv::Rect num_roi(roi_x, 0, roi_size.width, roi_size.height);
if ((num_roi & cv::Rect(0, 0, warp_width, warp_height)) != num_roi)
return false;
cv::Mat num_crop = number_image(num_roi);
cv::threshold(
num_crop,
armor.number_img,
0,
255,
cv::THRESH_BINARY | cv::THRESH_OTSU
);
}
armor.whole_rgb_img = litroi_color;
armor.local_offset = offset;
return true;
}
bool refineLightsFromArmorPts(ArmorObject& armor) const noexcept {
armor.center = (armor.pts[0] + armor.pts[1] + armor.pts[2] + armor.pts[3]) * 0.25f;
const int n_lights = static_cast<int>(armor.lights.size());
if (n_lights < 2)
return false;
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f ref_centers[2] = { (ordered[0] + ordered[1]) * 0.5f,
(ordered[2] + ordered[3]) * 0.5f };
int best0 = -1, best1 = -1;
float best0_d2 = std::numeric_limits<float>::max();
float best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
const cv::Point2f& c = armor.lights[i].center;
const cv::Point2f d0 = c - ref_centers[0];
const float dist0 = d0.dot(d0);
if (dist0 < best0_d2) {
best0_d2 = dist0;
best0 = i;
}
const cv::Point2f d1 = c - ref_centers[1];
const float dist1 = d1.dot(d1);
if (dist1 < best1_d2) {
best1_d2 = dist1;
best1 = i;
}
}
if (best0 == best1) {
best1 = -1;
best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
if (i == best0)
continue;
const cv::Point2f d = armor.lights[i].center - ref_centers[1];
const float dist = d.dot(d);
if (dist < best1_d2) {
best1_d2 = dist;
best1 = i;
}
}
}
if (best0 < 0 || best1 < 0)
return false;
const auto& l0 = armor.lights[best0];
const auto& l1 = armor.lights[best1];
if (l0.center.x < l1.center.x) {
armor.lights[0] = l0;
armor.lights[1] = l1;
} else {
armor.lights[0] = l1;
armor.lights[1] = l0;
}
return true;
}
std::vector<Light>
findLights(const cv::Mat& color_img, const cv::Mat& binary_img, ArmorObject& armor)
const noexcept {
std::vector<std::vector<cv::Point>> contours;
contours.reserve(64);
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
std::vector<Light> all_lights;
all_lights.reserve(contours.size());
for (const auto& contour: contours) {
const int n = static_cast<int>(contour.size());
if (n < 6)
continue;
Light light(contour);
if (!isLight(light))
continue;
int sum_r = 0;
int sum_b = 0;
for (const auto& pt: contour) {
const cv::Vec3b* row = color_img.ptr<cv::Vec3b>(pt.y);
const cv::Vec3b& pix = row[pt.x];
sum_r += pix[0];
sum_b += pix[2];
}
const int avg_diff = std::abs(sum_r - sum_b) / n;
if (avg_diff <= params_.light_params.color_diff_thresh)
continue;
light.color = (sum_r > sum_b) ? 0 : 1; // 0=红, 1=蓝
all_lights.emplace_back(std::move(light));
}
std::sort(all_lights.begin(), all_lights.end(), [](const Light& a, const Light& b) {
return a.center.x < b.center.x;
});
armor.lights = all_lights;
return all_lights;
}
bool isLight(const Light& light) const noexcept {
// width / length 比例
const float ratio = light.width / light.length;
if (ratio <= params_.light_params.min_ratio || ratio >= params_.light_params.max_ratio)
return false;
if (light.tilt_angle >= params_.light_params.max_angle)
return false;
return true;
}
bool isArmor(const Light& l1, const Light& l2) const noexcept {
const float len1 = l1.length;
const float len2 = l2.length;
if (len1 <= 1e-3f || len2 <= 1e-3f)
return false;
const float min_len = (len1 < len2) ? len1 : len2;
const float max_len = (len1 < len2) ? len2 : len1;
if (min_len / max_len <= params_.armor_params.min_light_ratio)
return false;
const cv::Point2f d = l1.center - l2.center;
const float dist2 = d.dot(d);
const float avg_len = 0.5f * (len1 + len2);
const float min_small = params_.armor_params.min_small_center_distance * avg_len;
const float max_small = params_.armor_params.max_small_center_distance * avg_len;
const float min_large = params_.armor_params.min_large_center_distance * avg_len;
const float max_large = params_.armor_params.max_large_center_distance * avg_len;
const float min_small2 = min_small * min_small;
const float max_small2 = max_small * max_small;
const float min_large2 = min_large * min_large;
const float max_large2 = max_large * max_large;
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
if (!(small_ok || large_ok))
return false;
static const float tan_max_angle =
std::tan(params_.armor_params.max_angle * CV_PI / 180.0f);
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
return false;
if (l1.color != l2.color)
return false;
return true;
}
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) const noexcept {
std::vector<ArmorObject> armors;
if (!src_img.data || src_img.empty()) {
std::cout << "img data nullptr or empty" << std::endl;
return armors;
}
if (objs_result.empty()) {
return armors;
}
for (auto& armor_in: objs_result) {
ArmorObject armor = armor_in;
if ((detect_color == 0 && armor.color == ArmorColor::BLUE)
|| (detect_color == 1 && armor.color == ArmorColor::RED))
{
continue;
}
if (params_.enable_classify || params_.enable_cv) {
bool ok = false;
ok = extractNetImage(src_img, armor);
if (!ok)
continue;
}
if (params_.enable_classify) {
number_classifier_->classifyNumber(armor);
if (armor.confidence < params_.classifier_threshold)
continue;
}
if (target_number.has_value()) {
if (!isSameTarget(target_number.value(), armor.number)) {
continue;
}
}
if (armor.color == ArmorColor::NONE || armor.color == ArmorColor::PURPLE) {
armor.is_ok = false;
armor.transform(transform_matrix);
armors.emplace_back(armor);
continue;
}
if (params_.enable_cv) {
findLights(armor.whole_rgb_img, armor.whole_binary_img, armor);
if (refineLightsFromArmorPts(armor)) {
if (isArmor(armor.lights[0], armor.lights[1])) {
armor.is_ok = true;
for (auto& light: armor.lights) {
light.addOffset(armor.local_offset);
}
}
}
if (armor.is_ok) {
armor.is_ok = armor.checkOkptsRight(params_.max_pts_error);
}
}
if (!armor.is_ok) {
auto ordered = armor.sortCorners(armor.pts);
Light l1, l2;
l1.length = cv::norm(ordered[1] - ordered[0]);
l1.center = (ordered[0] + ordered[1]) / 2.0;
l2.length = cv::norm(ordered[2] - ordered[3]);
l2.center = (ordered[2] + ordered[3]) / 2.0;
if (!isArmor(l1, l2)) {
continue;
}
}
armor.transform(transform_matrix);
armors.emplace_back(armor);
}
return armors;
}
std::unique_ptr<NumberClassifierBase> number_classifier_;
struct ArmorDetectCommonParams {
std::string classify_backend = "opencv";
std::string classify_model_path;
std::string classify_label_path;
double classifier_threshold = 0.5;
LightParams light_params;
ArmorParams armor_params;
float expand_ratio_w = 1.1f;
float expand_ratio_h = 1.1f;
double max_pts_error = 20.0;
bool enable_cv = false;
bool enable_classify = true;
void load(const YAML::Node& config) {
expand_ratio_w = config["cv"]["light"]["expand_ratio_w"].as<float>(1.1);
expand_ratio_h = config["cv"]["light"]["expand_ratio_h"].as<float>(1.1);
max_pts_error = config["cv"]["light"]["max_pts_error"].as<double>(20.0);
enable_cv = config["cv"]["enable"].as<bool>();
light_params.load(config["cv"]["light"]);
armor_params.load(config["cv"]["armor"]);
enable_classify = config["classify"]["enable"].as<bool>();
classify_model_path =
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
classify_label_path =
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
classify_backend = config["classify"]["backend"].as<std::string>();
classifier_threshold = config["classify"]["threshold"].as<double>();
}
} params_;
};
ArmorDetectorCommon::ArmorDetectorCommon(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
ArmorDetectorCommon::~ArmorDetectorCommon() {
_impl.reset();
}
std::vector<ArmorObject> ArmorDetectorCommon::detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) {
return _impl
->detectNet(src_img, objs_result, transform_matrix, detect_color, target_number);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,41 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorCommon {
public:
using Ptr = std::unique_ptr<ArmorDetectorCommon>;
ArmorDetectorCommon(const YAML::Node& config);
static Ptr create(const YAML::Node& config) {
return std::make_unique<ArmorDetectorCommon>(config);
}
~ArmorDetectorCommon();
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number = std::nullopt
);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,134 @@
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "armor_detector_base.hpp"
#include "tasks/utils/config.hpp"
#include <string>
#include <yaml-cpp/yaml.h>
#ifdef USE_OPENVINO
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
#endif
#ifdef USE_TRT
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
#endif
#ifdef USE_NCNN
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
#endif
#ifdef USE_ORT
#include "tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp"
#endif
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
namespace wust_vision {
namespace auto_aim {
class DetectorFactory {
public:
static ArmorDetectorBase::Ptr createArmorDetector(
const std::string& backend,
bool use_armor_detect_common,
std::string cv_config_path = OPENCV_CONFIG,
std::string ml_config_path = ML_CONFIG
) {
// 检查编译时是否支持
auto isBackendEnabled = [&backend]() -> bool {
#ifdef USE_OPENVINO
if (backend == "openvino")
return true;
#endif
#ifdef USE_TRT
if (backend == "tensorrt")
return true;
#endif
#ifdef USE_NCNN
if (backend == "ncnn")
return true;
#endif
#ifdef USE_ORT
if (backend == "onnxruntime")
return true;
#endif
if (backend == "opencv")
return true;
return false;
};
if (!isBackendEnabled()) {
std::cout << "Backend " << backend << " is not enabled at compile time."
<< std::endl;
throw std::runtime_error("Backend " + backend + " is not enabled at compile time.");
}
auto getConfigPath = [&](const std::string& backend) -> std::string {
if (backend == "opencv")
return cv_config_path;
else
return ml_config_path;
};
std::string config_path = getConfigPath(backend);
if (config_path.empty()) {
std::cout << "No config path for backend: " << backend << std::endl;
throw std::runtime_error("No config path for backend: " + backend);
}
YAML::Node armor_detect_config = YAML::LoadFile(config_path);
// 创建对应后端实例
#if defined(USE_OPENVINO)
if (backend == "openvino") {
return ArmorDetectorOpenVino::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_TRT)
if (backend == "tensorrt") {
return ArmorDetectorTrt::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_NCNN)
if (backend == "ncnn") {
return ArmorDetectorNCNN::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_ORT)
if (backend == "onnxruntime") {
return ArmorDetectorOnnxRuntime::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
if (backend == "opencv") {
return ArmorDetectorOpenCV::create(armor_detect_config["armor_detector"]);
}
std::cout << "Unsupported armor detector backend (or not compiled): " << backend
<< std::endl;
throw std::runtime_error(
"Unsupported armor detector backend (or not compiled): " + backend
);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,198 @@
#include "armor_infer.hpp"
namespace wust_vision::auto_aim::armor_infer {
struct GridAndStride {
int grid0;
int grid1;
int stride;
};
[[nodiscard]] static inline std::vector<GridAndStride>
generate_grids_and_stride(int target_w, int target_h, const std::vector<int>& strides) noexcept {
std::vector<GridAndStride> grid_strides;
for (int stride: strides) {
const int num_w = target_w / stride;
const int num_h = target_h / stride;
grid_strides.reserve(grid_strides.size() + num_w * num_h);
for (int gy = 0; gy < num_h; ++gy) {
for (int gx = 0; gx < num_w; ++gx) {
grid_strides.push_back(GridAndStride { gx, gy, stride });
}
}
}
return grid_strides;
}
std::vector<ArmorObject> ArmorInfer::postProcessTUP_impl(const cv::Mat& out) const {
static std::optional<std::vector<GridAndStride>> _grid_strides;
if (!_grid_strides) {
_grid_strides = generate_grids_and_stride(inputW(), inputH(), { 8, 16, 32 });
}
const auto& grid_strides = _grid_strides.value();
std::vector<ArmorObject> out_objs;
const int num_anchors =
static_cast<int>(std::min<size_t>(grid_strides.size(), static_cast<size_t>(out.rows)));
for (int a = 0; a < num_anchors; ++a) {
const float confidence = out.at<float>(a, 8);
if (confidence < conf_threshold_)
continue;
const auto& gs = grid_strides[a];
const int gx = gs.grid0, gy = gs.grid1, stride = gs.stride;
// color & class
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::TUP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::TUP>::NUM_CLASSES;
cv::Mat color_scores = out.row(a).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(a).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color, max_class;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class, nullptr, &class_id);
const float x1 = (out.at<float>(a, 0) + gx) * stride;
const float y1 = (out.at<float>(a, 1) + gy) * stride;
const float x2 = (out.at<float>(a, 2) + gx) * stride;
const float y2 = (out.at<float>(a, 3) + gy) * stride;
const float x3 = (out.at<float>(a, 4) + gx) * stride;
const float y3 = (out.at<float>(a, 5) + gy) * stride;
const float x4 = (out.at<float>(a, 6) + gx) * stride;
const float y4 = (out.at<float>(a, 7) + gy) * stride;
ArmorObject obj;
obj.pts = { cv::Point2f(x1, y1),
cv::Point2f(x2, y2),
cv::Point2f(x3, y3),
cv::Point2f(x4, y4) };
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessRP_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
const int rows = out.rows;
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::RP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::RP>::NUM_CLASSES;
for (int r = 0; r < rows; ++r) {
float conf_raw = out.at<float>(r, 8);
const float confidence = static_cast<float>(sigmoid(conf_raw));
if (confidence < conf_threshold_)
continue;
cv::Mat color_scores = out.row(r).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(r).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color_score, max_class_score;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color_score, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class_score, nullptr, &class_id);
ArmorObject obj;
obj.pts.resize(4);
for (int k = 0; k < 4; ++k) {
const float x = out.at<float>(r, 0 + k * 2);
const float y = out.at<float>(r, 1 + k * 2);
obj.pts[k] = cv::Point2f(x, y);
}
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessAT_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
constexpr int nkpt = ModelTraits<Mode::AT>::NUM_KPTS;
constexpr int nk = nkpt * 2; // keypoints flattened
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::AT>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
obj.pts.reserve(nkpt);
for (int k = 0; k < nkpt; ++k) {
float kx = row[6 + 2 * k];
float ky = row[6 + 2 * k + 1];
obj.pts.emplace_back(kx, ky);
}
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
std::vector<ArmorObject> ArmorInfer::postProcessBOX_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::BOX>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
std::vector<cv::Point2f> pts;
pts.resize(4);
pts[0] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y + obj.box.height); // 右下
pts[1] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y); // 右上
pts[2] = cv::Point2f(obj.box.x, obj.box.y); // 左上
pts[3] = cv::Point2f(obj.box.x, obj.box.y + obj.box.height); // 左下
obj.pts = std::move(pts);
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
} // namespace wust_vision::auto_aim::armor_infer

Some files were not shown because too many files have changed in this diff Show More