add wust typr mpc and mutipule x

This commit is contained in:
cyy_mac
2026-03-27 03:41:42 +08:00
parent 2c64655fae
commit 7dcb53bb77
192 changed files with 29571 additions and 9 deletions

View File

@@ -0,0 +1,43 @@
add_library(tinympcstatic STATIC
admm.cpp
tiny_api.cpp
codegen.cpp
rho_benchmark.cpp
)
set_property(TARGET tinympcstatic PROPERTY POSITION_INDEPENDENT_CODE ON)
# target_link_libraries(tinympcstatic PUBLIC Eigen)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include/Eigen)
if(USING_CODEGEN) # Defined in top-level CMakeLists.txt
# Files that are needed for embedded code generation
list( APPEND EMBEDDED_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/admm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/admm.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/types.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api_constants.hpp" )
foreach( f ${EMBEDDED_FILES} )
get_filename_component( fname ${f} NAME )
set( dest_file "${EMBEDDED_BUILD_TINYMPC_DIR}/${fname}" )
list( APPEND EMBEDDED_BUILD_TINYMPC_FILES "${dest_file}" )
add_custom_command(OUTPUT ${dest_file}
COMMAND ${CMAKE_COMMAND} -E copy "${f}" "${dest_file}"
DEPENDS ${f}
COMMENT "Copying ${fname}")
endforeach()
add_custom_target( copy_codegen_tinympc_files DEPENDS ${EMBEDDED_BUILD_TINYMPC_FILES} )
add_dependencies( copy_codegen_files copy_codegen_tinympc_files )
endif(USING_CODEGEN)

View File

@@ -0,0 +1,399 @@
#include <iostream>
#include "admm.hpp"
#include "rho_benchmark.hpp"
#define DEBUG_MODULE "TINYALG"
extern "C" {
/**
* Update linear terms from Riccati backward pass
*/
void backward_pass_grad(TinySolver* solver) {
for (int i = solver->work->N - 2; i >= 0; i--) {
(solver->work->d.col(i)).noalias() = solver->cache->Quu_inv
* (solver->work->Bdyn.transpose() * solver->work->p.col(i + 1) + solver->work->r.col(i)
+ solver->cache->BPf);
(solver->work->p.col(i)).noalias() = solver->work->q.col(i)
+ solver->cache->AmBKt.lazyProduct(solver->work->p.col(i + 1))
- (solver->cache->Kinf.transpose()).lazyProduct(solver->work->r.col(i))
+ solver->cache->APf;
}
}
/**
* Use LQR feedback policy to roll out trajectory
*/
void forward_pass(TinySolver* solver) {
for (int i = 0; i < solver->work->N - 1; i++) {
(solver->work->u.col(i)).noalias() =
-solver->cache->Kinf.lazyProduct(solver->work->x.col(i)) - solver->work->d.col(i);
(solver->work->x.col(i + 1)).noalias() =
solver->work->Adyn.lazyProduct(solver->work->x.col(i))
+ solver->work->Bdyn.lazyProduct(solver->work->u.col(i)) + solver->work->fdyn;
}
}
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu) {
tinytype u0 = s(Eigen::placeholders::last) * mu;
tinyVector u1 = s.head(s.rows() - 1);
float a = u1.norm();
tinyVector cone_origin(s.rows());
cone_origin.setZero();
if (a <= -u0) { // below cone
return cone_origin;
} else if (a <= u0) { // in cone
return s;
} else if (a >= abs(u0)) { // outside cone
Matrix<tinytype, 3, 1> u2(u1.size() + 1);
u2 << u1, a / mu;
return 0.5 * (1 + u0 / a) * u2;
} else {
return cone_origin;
}
}
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b) {
tinytype dist = (a.dot(z) - b) / a.squaredNorm();
return z - dist * a;
}
/**
* Project slack (auxiliary) variables into their feasible domain, defined by
* projection functions related to each constraint
* TODO: pass in meta information with each constraint assigning it to a
* projection function
*/
void update_slack(TinySolver* solver) {
// Update bound constraint slack variables for state
solver->work->vnew = solver->work->x + solver->work->g;
// Update bound constraint slack variables for input
solver->work->znew = solver->work->u + solver->work->y;
// Box constraints on state
if (solver->settings->en_state_bound) {
solver->work->vnew =
solver->work->x_max.cwiseMin(solver->work->x_min.cwiseMax(solver->work->vnew));
}
// Box constraints on input
if (solver->settings->en_input_bound) {
solver->work->znew =
solver->work->u_max.cwiseMin(solver->work->u_min.cwiseMax(solver->work->znew));
}
// Update second order cone slack variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x + solver->work->gc;
}
// Update second order cone slack variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u + solver->work->yc;
}
// Cone constraints on state
if (solver->settings->en_state_soc) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateCones; k++) {
int start = solver->work->Acx(k);
int num_xs = solver->work->qcx(k);
tinytype mu = solver->work->cx(k);
tinyVector col = solver->work->vcnew.block(start, i, num_xs, 1);
solver->work->vcnew.block(start, i, num_xs, 1) = project_soc(col, mu);
}
}
}
// Cone constraints on input
if (solver->settings->en_input_soc) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputCones; k++) {
int start = solver->work->Acu(k);
int num_us = solver->work->qcu(k);
tinytype mu = solver->work->cu(k);
tinyVector col = solver->work->zcnew.block(start, i, num_us, 1);
solver->work->zcnew.block(start, i, num_us, 1) = project_soc(col, mu);
}
}
}
// Update linear constraint slack variables for state
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x + solver->work->gl;
}
// Update linear constraint slack variables for input
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u + solver->work->yl;
}
// Linear constraints on state
if (solver->settings->en_state_linear) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateLinear; k++) {
tinyVector a = solver->work->Alin_x.row(k);
tinytype b = solver->work->blin_x(k);
tinytype constraint_value = a.dot(solver->work->vlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->vlnew.col(i) =
project_hyperplane(solver->work->vlnew.col(i), a, b);
}
}
}
}
// Linear constraints on input
if (solver->settings->en_input_linear) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputLinear; k++) {
tinyVector a = solver->work->Alin_u.row(k);
tinytype b = solver->work->blin_u(k);
tinytype constraint_value = a.dot(solver->work->zlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->zlnew.col(i) =
project_hyperplane(solver->work->zlnew.col(i), a, b);
}
}
}
}
}
/**
* Update next iteration of dual variables by performing the augmented
* lagrangian multiplier update
*/
void update_dual(TinySolver* solver) {
// Update bound constraint dual variables for state
solver->work->g = solver->work->g + solver->work->x - solver->work->vnew;
// Update bound constraint dual variables for input
solver->work->y = solver->work->y + solver->work->u - solver->work->znew;
// Update second order cone dual variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->gc = solver->work->gc + solver->work->x - solver->work->vcnew;
}
// Update second order cone dual variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->yc = solver->work->yc + solver->work->u - solver->work->zcnew;
}
// Update linear constraint dual variables for state
if (solver->settings->en_state_linear) {
solver->work->gl = solver->work->gl + solver->work->x - solver->work->vlnew;
}
// Update linear constraint dual variables for input
if (solver->settings->en_input_linear) {
solver->work->yl = solver->work->yl + solver->work->u - solver->work->zlnew;
}
}
/**
* Update linear control cost terms in the Riccati feedback using the changing
* slack and dual variables from ADMM
*/
void update_linear_cost(TinySolver* solver) {
// Update state cost terms
solver->work->q = -(solver->work->Xref.array().colwise() * solver->work->Q.array());
(solver->work->q).noalias() -= solver->cache->rho * (solver->work->vnew - solver->work->g);
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vcnew - solver->work->gc);
}
if (solver->settings->en_state_linear) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vlnew - solver->work->gl);
}
// Update input cost terms
solver->work->r = -(solver->work->Uref.array().colwise() * solver->work->R.array());
(solver->work->r).noalias() -= solver->cache->rho * (solver->work->znew - solver->work->y);
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zcnew - solver->work->yc);
}
if (solver->settings->en_input_linear) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zlnew - solver->work->yl);
}
// Update terminal cost
solver->work->p.col(solver->work->N - 1) =
-(solver->work->Xref.col(solver->work->N - 1).transpose().lazyProduct(solver->cache->Pinf));
(solver->work->p.col(solver->work->N - 1)).noalias() -= solver->cache->rho
* (solver->work->vnew.col(solver->work->N - 1) - solver->work->g.col(solver->work->N - 1));
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vcnew.col(solver->work->N - 1)
- solver->work->gc.col(solver->work->N - 1));
}
if (solver->settings->en_state_linear) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vlnew.col(solver->work->N - 1)
- solver->work->gl.col(solver->work->N - 1));
}
}
/**
* Check for termination condition by evaluating whether the largest absolute
* primal and dual residuals for states and inputs are below threhold.
*/
bool termination_condition(TinySolver* solver) {
if (solver->work->iter % solver->settings->check_termination == 0) {
solver->work->primal_residual_state =
(solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
solver->work->dual_residual_state =
((solver->work->v - solver->work->vnew).cwiseAbs().maxCoeff()) * solver->cache->rho;
solver->work->primal_residual_input =
(solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
solver->work->dual_residual_input =
((solver->work->z - solver->work->znew).cwiseAbs().maxCoeff()) * solver->cache->rho;
if (solver->work->primal_residual_state < solver->settings->abs_pri_tol
&& solver->work->primal_residual_input < solver->settings->abs_pri_tol
&& solver->work->dual_residual_state < solver->settings->abs_dua_tol
&& solver->work->dual_residual_input < solver->settings->abs_dua_tol)
{
return true;
}
}
return false;
}
int solve(TinySolver* solver) {
// Initialize variables
solver->solution->solved = 0;
solver->solution->iter = 0;
solver->work->status = 11; // TINY_UNSOLVED
solver->work->iter = 0;
// Setup for adaptive rho
RhoAdapter adapter;
adapter.rho_min = solver->settings->adaptive_rho_min;
adapter.rho_max = solver->settings->adaptive_rho_max;
adapter.clip = solver->settings->adaptive_rho_enable_clipping;
RhoBenchmarkResult rho_result;
// Store previous values for residuals
tinyMatrix v_prev = solver->work->vnew;
tinyMatrix z_prev = solver->work->znew;
// Initialize SOC slack variables if needed
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x;
}
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u;
}
// Initialize linear constraint slack variables if needed
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x;
}
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u;
}
for (int i = 0; i < solver->settings->max_iter; i++) {
// Solve linear system with Riccati and roll out to get new trajectory
backward_pass_grad(solver);
forward_pass(solver);
// Project slack variables into feasible domain
update_slack(solver);
// Compute next iteration of dual variables
update_dual(solver);
// Update linear control cost terms using reference trajectory, duals, and slack variables
update_linear_cost(solver);
solver->work->iter += 1;
// Handle adaptive rho if enabled
if (solver->settings->adaptive_rho) {
// Calculate residuals for adaptive rho
tinytype pri_res_input = (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
tinytype pri_res_state = (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
tinytype dua_res_input =
solver->cache->rho * (solver->work->znew - z_prev).cwiseAbs().maxCoeff();
tinytype dua_res_state =
solver->cache->rho * (solver->work->vnew - v_prev).cwiseAbs().maxCoeff();
// Update rho every 5 iterations
if (i > 0 && i % 5 == 0) {
benchmark_rho_adaptation(
&adapter,
solver->work->x,
solver->work->u,
solver->work->vnew,
solver->work->znew,
solver->work->g,
solver->work->y,
solver->cache,
solver->work,
solver->work->N,
&rho_result
);
// Update matrices using Taylor expansion
update_matrices_with_derivatives(solver->cache, rho_result.final_rho);
}
}
// Store previous values for next iteration
z_prev = solver->work->znew;
v_prev = solver->work->vnew;
// Check for whether cost is minimized by calculating residuals
if (termination_condition(solver)) {
solver->work->status = 1; // TINY_SOLVED
// Save solution
solver->solution->iter = solver->work->iter;
solver->solution->solved = 1;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
// std::cout << "Solver converged in " << solver->work->iter << " iterations" << std::endl;
return 0;
}
// Save previous slack variables
solver->work->v = solver->work->vnew;
solver->work->z = solver->work->znew;
}
solver->solution->iter = solver->work->iter;
solver->solution->solved = 0;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
return 1;
}
} /* extern "C" */

View File

@@ -0,0 +1,37 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int solve(TinySolver* solver);
void update_primal(TinySolver* solver);
void backward_pass_grad(TinySolver* solver);
void forward_pass(TinySolver* solver);
void update_slack(TinySolver* solver);
void update_dual(TinySolver* solver);
void update_linear_cost(TinySolver* solver);
bool termination_condition(TinySolver* solver);
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu);
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,466 @@
#include <ctype.h>
#include <dirent.h>
#include <stdio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>
//#include <error.h>
#include "error.hpp"
#include <Eigen/Dense>
#include <iostream>
// #include "types.hpp"
#include "codegen.hpp"
#ifdef __MINGW32__
#include <direct.h>
inline int mkdir(const char* pathname, int flags) {
return _mkdir(pathname);
}
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* Define the maximum allowed length of the path (directory + filename + extension) */
#define PATH_LENGTH 2048
using namespace Eigen;
static void print_matrix(FILE* f, MatrixXd mat, int num_elements) {
// Check if matrix is uninitialized or too small
if (mat.size() == 0 || mat.size() < num_elements) {
// Print zeros for all elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)0.0000000000000000");
if (i < num_elements - 1)
fprintf(f, ",");
}
return;
}
// Matrix is properly initialized and has enough elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)%.16f", mat.reshaped<RowMajor>()[i]);
if (i < num_elements - 1)
fprintf(f, ",");
}
}
static void create_directory(const char* dir, int verbose) {
// Attempt to create directory
if (mkdir(dir, S_IRWXU | S_IRWXG | S_IROTH)) {
if (errno == EEXIST) { // Skip if directory already exists
if (verbose)
std::cout << dir << " already exists, skipping." << std::endl;
} else {
ERROR_MSG(EXIT_FAILURE, "Failed to create directory %s", dir);
}
}
}
// TODO: Make this fail if tiny_setup has not already been called
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose) {
if (!solver) {
std::cout << "Error in tiny_codegen: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= codegen_create_directories(output_dir, verbose);
status |= codegen_data_header(output_dir, verbose);
status |= codegen_data_source(solver, output_dir, verbose);
status |= codegen_example(output_dir, verbose);
return status;
}
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
) {
if (!solver) {
std::cout << "Error in tiny_codegen_with_sensitivity: solver is nullptr" << std::endl;
return 1;
}
// Only store sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
// Store the sensitivity matrices in the solver's cache
solver->cache->dKinf_drho = *dK;
solver->cache->dPinf_drho = *dP;
solver->cache->dC1_drho = *dC1;
solver->cache->dC2_drho = *dC2;
}
// Call the regular codegen function which will now include the sensitivity matrices if adaptive_rho is enabled
return tiny_codegen(solver, output_dir, verbose);
}
// Create code generation folder structure in whichever directory the executable calling tiny_codegen was called
int codegen_create_directories(const char* output_dir, int verbose) {
// Create output folder (root folder for code generation)
create_directory(output_dir, verbose);
// Create src folder
char src_dir[PATH_LENGTH];
sprintf(src_dir, "%s/src/", output_dir);
create_directory(src_dir, verbose);
// Create tinympc folder
char tinympc_dir[PATH_LENGTH];
sprintf(tinympc_dir, "%s/tinympc/", output_dir);
create_directory(tinympc_dir, verbose);
// // Create include folder
// char inc_dir[PATH_LENGTH];
// sprintf(inc_dir, "%s/include/", output_dir);
// create_directory(inc_dir, verbose);
return EXIT_SUCCESS;
}
// Create inc/tiny_data.hpp file
int codegen_data_header(const char* output_dir, int verbose) {
char data_hpp_fname[PATH_LENGTH];
FILE* data_hpp_f;
sprintf(data_hpp_fname, "%s/tinympc/tiny_data.hpp", output_dir);
// Open data header file
data_hpp_f = fopen(data_hpp_fname, "w+");
if (data_hpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_hpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_hpp_f, "/*\n");
fprintf(data_hpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_hpp_f, " */\n\n");
fprintf(data_hpp_f, "#pragma once\n\n");
fprintf(data_hpp_f, "#include \"types.hpp\"\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "extern \"C\" {\n");
fprintf(data_hpp_f, "#endif\n\n");
fprintf(data_hpp_f, "extern TinySolver tiny_solver;\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "}\n");
fprintf(data_hpp_f, "#endif\n");
// Close codegen data header file
fclose(data_hpp_f);
if (verbose) {
printf("Data header generated in %s\n", data_hpp_fname);
}
return 0;
}
// Create src/tiny_data.cpp file
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose) {
char data_cpp_fname[PATH_LENGTH];
FILE* data_cpp_f;
int nx = solver->work->nx;
int nu = solver->work->nu;
int N = solver->work->N;
sprintf(data_cpp_fname, "%s/src/tiny_data.cpp", output_dir);
// Open data source file
data_cpp_f = fopen(data_cpp_fname, "w+");
if (data_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_cpp_f, "/*\n");
fprintf(data_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_cpp_f, " */\n\n");
// Open extern C
fprintf(data_cpp_f, "#include \"tinympc/tiny_data.hpp\"\n\n");
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "extern \"C\" {\n");
fprintf(data_cpp_f, "#endif\n\n");
// Solution
fprintf(data_cpp_f, "/* Solution */\n");
fprintf(data_cpp_f, "TinySolution solution = {\n");
fprintf(data_cpp_f, "\t%d,\t\t// iter\n", solver->solution->iter);
fprintf(data_cpp_f, "\t%d,\t\t// solved\n", solver->solution->solved);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x solution
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// x\n"); // u solution
fprintf(data_cpp_f, "};\n\n");
// Cache
fprintf(data_cpp_f, "/* Matrices that must be recomputed with changes in time step, rho */\n");
fprintf(data_cpp_f, "TinyCache cache = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// rho (step size/penalty)\n", solver->cache->rho);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->Kinf, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// Kinf\n"); // Kinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->Pinf, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Pinf\n"); // Pinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nu);
print_matrix(data_cpp_f, solver->cache->Quu_inv, nu * nu);
fprintf(data_cpp_f, ").finished(),\t// Quu_inv\n"); // Quu_inv
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->AmBKt, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// AmBKt\n"); // AmBKt
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C1, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// C1\n"); // C1
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C2, nx * nx);
fprintf(data_cpp_f, ").finished()"); // C2, no comma if no sensitivity matrices
// Only print sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
fprintf(data_cpp_f, ",\t// C2\n"); // Add comma and comment for C2 if we have more matrices
// Add sensitivity matrices within the struct initialization
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->dKinf_drho, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// dKinf_drho\n"); // dKinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dPinf_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dPinf_drho\n"); // dPinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC1_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dC1_drho\n"); // dC1_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC2_drho, nx * nx);
fprintf(data_cpp_f, ").finished()\t// dC2_drho\n"); // dC2_drho
} else {
fprintf(data_cpp_f, "\t// C2\n"); // Just add comment for C2
}
fprintf(data_cpp_f, "};\n\n");
// Settings
fprintf(data_cpp_f, "/* User settings */\n");
fprintf(data_cpp_f, "TinySettings settings = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// primal tolerance\n", solver->settings->abs_pri_tol);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// dual tolerance\n", solver->settings->abs_dua_tol);
fprintf(data_cpp_f, "\t%d,\t\t// max iterations\n", solver->settings->max_iter);
fprintf(
data_cpp_f,
"\t%d,\t\t// iterations per termination check\n",
solver->settings->check_termination
);
fprintf(data_cpp_f, "\t%d,\t\t// enable state constraints\n", solver->settings->en_state_bound);
fprintf(data_cpp_f, "\t%d\t\t// enable input constraints\n", solver->settings->en_input_bound);
fprintf(data_cpp_f, "};\n\n");
// Workspace
fprintf(data_cpp_f, "/* Problem variables */\n");
fprintf(data_cpp_f, "TinyWorkspace work = {\n");
fprintf(data_cpp_f, "\t%d,\t// Number of states\n", nx);
fprintf(data_cpp_f, "\t%d,\t// Number of control inputs\n", nu);
fprintf(data_cpp_f, "\t%d,\t// Number of knotpoints in the horizon\n", N);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u\n"); // u
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// q\n"); // q
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// r\n"); // r
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// p\n"); // p
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// d\n"); // d
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// v\n"); // v
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// vnew\n"); // vnew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// z\n"); // z
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// znew\n"); // znew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// g\n"); // g
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// y\n"); // y
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nx);
print_matrix(data_cpp_f, solver->work->Q, nx);
fprintf(data_cpp_f, ").finished(),\t// Q\n"); // Q
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, solver->work->R, nu);
fprintf(data_cpp_f, ").finished(),\t// R\n"); // R
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->work->Adyn, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Adyn\n"); // Adyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nu);
print_matrix(data_cpp_f, solver->work->Bdyn, nx * nu);
fprintf(data_cpp_f, ").finished(),\t// Bdyn\n"); // Bdyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_min, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_min\n"); // x_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_max, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_max\n"); // x_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_min, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_min\n"); // u_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_max, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_max\n"); // u_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// Xref\n"); // Xref
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// Uref\n"); // Uref
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, 1), nu);
fprintf(data_cpp_f, ").finished(),\t// Qu\n"); // Qu
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state dual residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input dual residual\n", 0.0);
fprintf(data_cpp_f, "\t%d,\t// solve status\n", 0);
fprintf(data_cpp_f, "\t%d,\t// solve iteration\n", 0);
fprintf(data_cpp_f, "};\n\n");
// Write solver struct definition to workspace file
fprintf(data_cpp_f, "TinySolver tiny_solver = {&solution, &settings, &cache, &work};\n\n");
// Close extern C
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "}\n");
fprintf(data_cpp_f, "#endif\n\n");
// Close codegen data file
fclose(data_cpp_f);
if (verbose) {
printf("Data generated in %s\n", data_cpp_fname);
}
return 0;
}
int codegen_example(const char* output_dir, int verbose) {
char example_cpp_fname[PATH_LENGTH];
FILE* example_cpp_f;
sprintf(example_cpp_fname, "%s/src/tiny_main.cpp", output_dir);
// Open example file
example_cpp_f = fopen(example_cpp_fname, "w+");
if (example_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", example_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(example_cpp_f, "/*\n");
fprintf(example_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(example_cpp_f, " */\n\n");
fprintf(example_cpp_f, "#include <iostream>\n\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_api.hpp>\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_data.hpp>\n\n");
fprintf(example_cpp_f, "using namespace Eigen;\n");
fprintf(example_cpp_f, "IOFormat TinyFmt(4, 0, \", \", \"\\n\", \"[\", \"]\");\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "extern \"C\" {\n");
fprintf(example_cpp_f, "#endif\n\n");
fprintf(example_cpp_f, "int main()\n");
fprintf(example_cpp_f, "{\n");
fprintf(example_cpp_f, "\tint exitflag = 1;\n");
fprintf(example_cpp_f, "\t// Double check some data\n");
fprintf(example_cpp_f, "\tstd::cout << \"rho: \" << tiny_solver.cache->rho << std::endl;\n");
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nmax iters: \" << tiny_solver.settings->max_iter << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nState transition matrix:\\n\" << tiny_solver.work->Adyn.format(TinyFmt) << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nInput/control matrix:\\n\" << tiny_solver.work->Bdyn.format(TinyFmt) << std::endl;\n\n"
);
fprintf(
example_cpp_f,
"\t// Visit https://tinympc.org/ to see how to set the initial condition and update the reference trajectory.\n\n"
);
fprintf(example_cpp_f, "\tstd::cout << \"\\nSolving...\\n\" << std::endl;\n\n");
fprintf(example_cpp_f, "\texitflag = tiny_solve(&tiny_solver);\n\n");
fprintf(example_cpp_f, "\tif (exitflag == 0) printf(\"Hooray! Solved with no error!\\n\");\n");
fprintf(example_cpp_f, "\telse printf(\"Oops! Something went wrong!\\n\");\n");
fprintf(example_cpp_f, "\treturn 0;\n");
fprintf(example_cpp_f, "}\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "} /* extern \"C\" */\n");
fprintf(example_cpp_f, "#endif\n");
// Close codegen example main file
fclose(example_cpp_f);
if (verbose) {
printf("Example tinympc main generated in %s\n", example_cpp_fname);
}
return 0;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,28 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose);
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
);
int codegen_create_directories(const char* output_dir, int verbose);
int codegen_data_header(const char* output_dir, int verbose);
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose);
int codegen_example(const char* output_dir, int verbose);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,29 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// #if defined(__linux__) || defined(__unix__)// Check if Linux
// #include <error.h>
// #define ERROR_MSG(exit_code, format, ...) error(exit_code, errno, format, ##__VA_ARGS__)
// #elif defined(__APPLE__) || defined(__MACH__) // Check if macOS
#define ERROR_MSG(exit_code, format, ...) \
{ \
fprintf(stderr, format ": %s\n", ##__VA_ARGS__, strerror(errno)); \
exit(exit_code); \
}
// #else
// #error "Unsupported operating system"
// #endif
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,252 @@
#include "rho_benchmark.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#ifdef ARDUINO
#include <Arduino.h>
#else
// For non-Arduino platforms
uint32_t micros() {
return 0; // Replace with appropriate timing function
}
#endif
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N) {
// Calculate dimensions
int x_decision_size = nx * N + nu * (N - 1);
int constraint_rows = (nx + nu) * (N - 1);
// Pre-allocate matrices
adapter->A_matrix = tinyMatrix::Zero(constraint_rows, x_decision_size);
adapter->z_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->y_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->x_decision = tinyMatrix::Zero(x_decision_size, 1);
// Pre-compute P matrix structure
adapter->P_matrix = tinyMatrix::Zero(x_decision_size, x_decision_size);
adapter->q_vector = tinyMatrix::Zero(x_decision_size, 1);
// Pre-allocate residual computation matrices
adapter->Ax_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_prim_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_dual_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->Px_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->ATy_vector = tinyMatrix::Zero(x_decision_size, 1);
// Store dimensions
adapter->format_nx = nx;
adapter->format_nu = nu;
adapter->format_N = N;
adapter->matrices_initialized = true;
}
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
) {
if (!adapter->matrices_initialized) {
initialize_format_matrices(adapter, x_prev.rows(), u_prev.rows(), N);
}
int nx = adapter->format_nx;
int nu = adapter->format_nu;
// Fill x_decision
int x_idx = 0;
for (int i = 0; i < N; i++) {
adapter->x_decision.block(x_idx, 0, nx, 1) = x_prev.col(i);
x_idx += nx;
if (i < N - 1) {
adapter->x_decision.block(x_idx, 0, nu, 1) = u_prev.col(i);
x_idx += nu;
}
}
// Clear A matrix for reuse
adapter->A_matrix.setZero();
// Fill A matrix with dynamics and input constraints
for (int i = 0; i < N - 1; i++) {
// Input constraints
int row_start = i * nu;
int col_start = i * (nx + nu) + nx;
adapter->A_matrix.block(row_start, col_start, nu, nu) = tinyMatrix::Identity(nu, nu);
// Dynamics constraints
row_start = (N - 1) * nu + i * nx;
col_start = i * (nx + nu);
adapter->A_matrix.block(row_start, col_start, nx, nx) = work->Adyn;
adapter->A_matrix.block(row_start, col_start + nx, nx, nu) = work->Bdyn;
int next_state_idx = col_start + nx + nu;
if (next_state_idx < adapter->A_matrix.cols()) {
adapter->A_matrix.block(row_start, next_state_idx, nx, nx) =
-tinyMatrix::Identity(nx, nx);
}
}
// Fill z and y vectors
for (int i = 0; i < N - 1; i++) {
adapter->z_vector.block(i * nu, 0, nu, 1) = z_prev.col(i);
adapter->z_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = v_prev.col(i + 1);
adapter->y_vector.block(i * nu, 0, nu, 1) = y_prev.col(i);
adapter->y_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = g_prev.col(i + 1);
}
// Build P matrix (cost matrix)
adapter->P_matrix.setZero();
// Fill diagonal blocks
x_idx = 0;
for (int i = 0; i < N; i++) {
// State cost
if (i == N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = cache->Pinf;
} else {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = work->Q.asDiagonal();
}
x_idx += nx;
// Input cost
if (i < N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nu, nu) = work->R.asDiagonal();
x_idx += nu;
}
}
// Create q vector (linear cost vector)
x_idx = 0;
for (int i = 0; i < N; i++) {
// For simplicity, we'll use zero reference for now
// In a real implementation, you'd use your reference trajectory
tinyMatrix x_ref = tinyMatrix::Zero(nx, 1);
tinyMatrix delta_x = x_prev.col(i) - x_ref;
adapter->q_vector.block(x_idx, 0, nx, 1) = work->Q.asDiagonal() * delta_x;
x_idx += nx;
if (i < N - 1) {
// For simplicity, we'll use zero reference for now
tinyMatrix u_ref = tinyMatrix::Zero(nu, 1);
tinyMatrix delta_u = u_prev.col(i) - u_ref;
adapter->q_vector.block(x_idx, 0, nu, 1) = work->R.asDiagonal() * delta_u;
x_idx += nu;
}
}
}
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
) {
// Compute Ax
adapter->Ax_vector = adapter->A_matrix * adapter->x_decision;
// Compute primal residual
adapter->r_prim_vector = adapter->Ax_vector - adapter->z_vector;
*pri_res = adapter->r_prim_vector.cwiseAbs().maxCoeff();
*pri_norm =
std::max(adapter->Ax_vector.cwiseAbs().maxCoeff(), adapter->z_vector.cwiseAbs().maxCoeff());
// Compute dual residual components
adapter->Px_vector = adapter->P_matrix * adapter->x_decision;
adapter->ATy_vector = adapter->A_matrix.transpose() * adapter->y_vector;
// Compute full dual residual
adapter->r_dual_vector = adapter->Px_vector + adapter->q_vector + adapter->ATy_vector;
*dual_res = adapter->r_dual_vector.cwiseAbs().maxCoeff();
// Compute normalization
*dual_norm = std::max(
std::max(
adapter->Px_vector.cwiseAbs().maxCoeff(),
adapter->ATy_vector.cwiseAbs().maxCoeff()
),
adapter->q_vector.cwiseAbs().maxCoeff()
);
}
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
) {
const tinytype eps = 1e-10;
tinytype normalized_pri = pri_res / (pri_norm + eps);
tinytype normalized_dual = dual_res / (dual_norm + eps);
tinytype ratio = normalized_pri / (normalized_dual + eps);
tinytype new_rho = current_rho * std::sqrt(ratio);
if (adapter->clip) {
new_rho = std::min(std::max(new_rho, adapter->rho_min), adapter->rho_max);
}
return new_rho;
}
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho) {
tinytype delta_rho = new_rho - cache->rho;
cache->Kinf = cache->Kinf + delta_rho * cache->dKinf_drho;
cache->Pinf = cache->Pinf + delta_rho * cache->dPinf_drho;
cache->C1 = cache->C1 + delta_rho * cache->dC1_drho;
cache->C2 = cache->C2 + delta_rho * cache->dC2_drho;
cache->rho = new_rho;
}
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
) {
uint32_t start_time = micros();
// Format matrices
format_matrices(adapter, x_prev, u_prev, v_prev, z_prev, g_prev, y_prev, cache, work, N);
// Compute residuals
tinytype pri_res, dual_res, pri_norm, dual_norm;
compute_residuals(adapter, &pri_res, &dual_res, &pri_norm, &dual_norm);
// Predict new rho
tinytype new_rho = predict_rho(adapter, pri_res, dual_res, pri_norm, dual_norm, cache->rho);
// Update matrices
update_matrices_with_derivatives(cache, new_rho);
// Store results
result->time_us = micros() - start_time;
result->initial_rho = cache->rho;
result->final_rho = new_rho;
result->pri_res = pri_res;
result->dual_res = dual_res;
result->pri_norm = pri_norm;
result->dual_norm = dual_norm;
}

View File

@@ -0,0 +1,94 @@
#pragma once
#include "types.hpp"
#include <cstdint>
struct RhoAdapter {
tinytype rho_min;
tinytype rho_max;
bool clip;
bool matrices_initialized;
// Pre-allocated matrices for formatting
tinyMatrix A_matrix;
tinyMatrix z_vector;
tinyMatrix y_vector;
tinyMatrix x_decision;
tinyMatrix P_matrix;
tinyMatrix q_vector;
// Pre-allocated matrices for residual computation
tinyMatrix Ax_vector;
tinyMatrix r_prim_vector;
tinyMatrix r_dual_vector;
tinyMatrix Px_vector;
tinyMatrix ATy_vector;
// Dimensions
int format_nx;
int format_nu;
int format_N;
};
struct RhoBenchmarkResult {
uint32_t time_us;
tinytype initial_rho;
tinytype final_rho;
tinytype pri_res;
tinytype dual_res;
tinytype pri_norm;
tinytype dual_norm;
};
// Initialize matrices for formatting
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N);
// Format matrices for residual computation
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
);
// Compute residuals
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
);
// Predict new rho value
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
);
// Update matrices using derivatives
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho);
// Main benchmark function
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
);

View File

@@ -0,0 +1,876 @@
#include "tiny_api.hpp"
#include "tiny_api_constants.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
using namespace Eigen;
IOFormat TinyApiFmt(4, 0, ", ", "\n", "[", "]");
static int
check_dimension(std::string matrix_name, std::string rows_or_columns, int actual, int expected) {
if (actual != expected) {
std::cout << matrix_name << " has " << actual << " " << rows_or_columns << ". Expected "
<< expected << "." << std::endl;
return 1;
}
return 0;
}
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
) {
TinySolution* solution = new TinySolution();
TinyCache* cache = new TinyCache();
TinySettings* settings = new TinySettings();
TinyWorkspace* work = new TinyWorkspace();
TinySolver* solver = new TinySolver();
solver->solution = solution;
solver->cache = cache;
solver->settings = settings;
solver->work = work;
*solverp = solver;
// Initialize solution
solution->iter = 0;
solution->solved = 0;
solution->x = tinyMatrix::Zero(nx, N);
solution->u = tinyMatrix::Zero(nu, N - 1);
// Initialize settings
tiny_set_default_settings(settings);
// Initialize workspace
work->nx = nx;
work->nu = nu;
work->N = N;
// Make sure arguments are the correct shapes
int status = 0;
status |= check_dimension("State transition matrix (A)", "rows", Adyn.rows(), nx);
status |= check_dimension("State transition matrix (A)", "columns", Adyn.cols(), nx);
status |= check_dimension("Input matrix (B)", "rows", Bdyn.rows(), nx);
status |= check_dimension("Input matrix (B)", "columns", Bdyn.cols(), nu);
status |= check_dimension("Affine vector (f)", "rows", fdyn.rows(), nx);
status |= check_dimension("Affine vector (f)", "columns", fdyn.cols(), 1);
status |= check_dimension("State stage cost (Q)", "rows", Q.rows(), nx);
status |= check_dimension("State stage cost (Q)", "columns", Q.cols(), nx);
status |= check_dimension("State input cost (R)", "rows", R.rows(), nu);
status |= check_dimension("State input cost (R)", "columns", R.cols(), nu);
if (status) {
return status;
}
work->x = tinyMatrix::Zero(nx, N);
work->u = tinyMatrix::Zero(nu, N - 1);
work->q = tinyMatrix::Zero(nx, N);
work->r = tinyMatrix::Zero(nu, N - 1);
work->p = tinyMatrix::Zero(nx, N);
work->d = tinyMatrix::Zero(nu, N - 1);
// Bound constraint slack variables
work->v = tinyMatrix::Zero(nx, N);
work->vnew = tinyMatrix::Zero(nx, N);
work->z = tinyMatrix::Zero(nu, N - 1);
work->znew = tinyMatrix::Zero(nu, N - 1);
// Bound constraint dual variables
work->g = tinyMatrix::Zero(nx, N);
work->y = tinyMatrix::Zero(nu, N - 1);
// Cone constraint slack variables
work->vc = tinyMatrix::Zero(nx, N);
work->vcnew = tinyMatrix::Zero(nx, N);
work->zc = tinyMatrix::Zero(nu, N - 1);
work->zcnew = tinyMatrix::Zero(nu, N - 1);
// Cone constraint dual variables
work->gc = tinyMatrix::Zero(nx, N);
work->yc = tinyMatrix::Zero(nu, N - 1);
// Linear constraint slack variables
work->vl = tinyMatrix::Zero(nx, N);
work->vlnew = tinyMatrix::Zero(nx, N);
work->zl = tinyMatrix::Zero(nu, N - 1);
work->zlnew = tinyMatrix::Zero(nu, N - 1);
// Linear constraint dual variables
work->gl = tinyMatrix::Zero(nx, N);
work->yl = tinyMatrix::Zero(nu, N - 1);
work->Q = (Q + rho * tinyMatrix::Identity(nx, nx)).diagonal();
work->R = (R + rho * tinyMatrix::Identity(nu, nu)).diagonal();
work->Adyn = Adyn; // State transition matrix
work->Bdyn = Bdyn; // Input matrix
work->fdyn = fdyn; // Affine offset vector
work->Xref = tinyMatrix::Zero(nx, N);
work->Uref = tinyMatrix::Zero(nu, N - 1);
work->Qu = tinyVector::Zero(nu);
work->primal_residual_state = 0;
work->primal_residual_input = 0;
work->dual_residual_state = 0;
work->dual_residual_input = 0;
work->status = 0;
work->iter = 0;
// Initialize cache
status = tiny_precompute_and_set_cache(
cache,
Adyn,
Bdyn,
fdyn,
work->Q.asDiagonal(),
work->R.asDiagonal(),
nx,
nu,
rho,
verbose
);
if (status) {
return status;
}
// Initialize sensitivity matrices for adaptive rho
if (solver->settings->adaptive_rho) {
tiny_initialize_sensitivity_matrices(solver);
}
return 0;
}
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
) {
if (!solver) {
std::cout << "Error in tiny_set_bound_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all bound constraint matrix sizes are self-consistent
int status = 0;
status |= check_dimension("Lower state bounds (x_min)", "rows", x_min.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_min)", "cols", x_min.cols(), solver->work->N);
status |= check_dimension("Lower state bounds (x_max)", "rows", x_max.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_max)", "cols", x_max.cols(), solver->work->N);
status |= check_dimension("Lower input bounds (u_min)", "rows", u_min.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_min)", "cols", u_min.cols(), solver->work->N - 1);
status |= check_dimension("Lower input bounds (u_max)", "rows", u_max.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_max)", "cols", u_max.cols(), solver->work->N - 1);
solver->work->x_min = x_min;
solver->work->x_max = x_max;
solver->work->u_min = u_min;
solver->work->u_max = u_max;
return 0;
}
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acx,
VectorXi qcx,
tinyVector cx,
VectorXi Acu,
VectorXi qcu,
tinyVector cu
) {
if (!solver) {
std::cout << "Error in tiny_set_cone_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all cone constraint vector sizes are self-consistent
int num_state_cones = Acx.rows();
int num_input_cones = Acu.rows();
int status = 0;
status |= check_dimension("Cone state size (qcx)", "rows", qcx.rows(), num_state_cones);
status |= check_dimension("Cone mu value for state (cx)", "rows", cx.rows(), num_state_cones);
status |= check_dimension("Cone input size (qcu)", "rows", qcu.rows(), num_input_cones);
status |= check_dimension("Cone mu value for input (cu)", "rows", cu.rows(), num_input_cones);
if (status) {
return status;
}
solver->work->numStateCones = num_state_cones;
solver->work->numInputCones = num_input_cones;
solver->work->Acx = Acx;
solver->work->qcx = qcx;
solver->work->cx = cx;
solver->work->Acu = Acu;
solver->work->qcu = qcu;
solver->work->cu = cu;
return 0;
}
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
) {
if (!solver) {
std::cout << "Error in tiny_set_linear_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all linear constraint matrix sizes are self-consistent
int num_state_linear = Alin_x.rows();
int num_input_linear = Alin_u.rows();
int status = 0;
// Check state constraint dimensions
if (num_state_linear > 0) {
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"rows",
Alin_x.rows(),
num_state_linear
);
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"columns",
Alin_x.cols(),
solver->work->nx
);
status |= check_dimension(
"State linear constraint vector (blin_x)",
"rows",
blin_x.rows(),
num_state_linear
);
status |=
check_dimension("State linear constraint vector (blin_x)", "columns", blin_x.cols(), 1);
}
// Check input constraint dimensions
if (num_input_linear > 0) {
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"rows",
Alin_u.rows(),
num_input_linear
);
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"columns",
Alin_u.cols(),
solver->work->nu
);
status |= check_dimension(
"Input linear constraint vector (blin_u)",
"rows",
blin_u.rows(),
num_input_linear
);
status |=
check_dimension("Input linear constraint vector (blin_u)", "columns", blin_u.cols(), 1);
}
if (status) {
return status;
}
solver->work->numStateLinear = num_state_linear;
solver->work->numInputLinear = num_input_linear;
solver->work->Alin_x = Alin_x;
solver->work->blin_x = blin_x;
solver->work->Alin_u = Alin_u;
solver->work->blin_u = blin_u;
return 0;
}
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
) {
if (!cache) {
std::cout << "Error in tiny_precompute_and_set_cache: cache is nullptr" << std::endl;
return 1;
}
// Update by adding rho * identity matrix to Q, R
tinyMatrix Q1 = Q + rho * tinyMatrix::Identity(nx, nx);
tinyMatrix R1 = R + rho * tinyMatrix::Identity(nu, nu);
// Printing
if (verbose) {
std::cout << "A = " << Adyn.format(TinyApiFmt) << std::endl;
std::cout << "B = " << Bdyn.format(TinyApiFmt) << std::endl;
std::cout << "Q = " << Q1.format(TinyApiFmt) << std::endl;
std::cout << "R = " << R1.format(TinyApiFmt) << std::endl;
std::cout << "rho = " << rho << std::endl;
}
// Riccati recursion to get Kinf, Pinf
tinyMatrix Ktp1 = tinyMatrix::Zero(nu, nx);
tinyMatrix Ptp1 = rho * tinyMatrix::Ones(nx, 1).array().matrix().asDiagonal();
tinyMatrix Kinf = tinyMatrix::Zero(nu, nx);
tinyMatrix Pinf = tinyMatrix::Zero(nx, nx);
for (int i = 0; i < 1000; i++) {
Kinf = (R1 + Bdyn.transpose() * Ptp1 * Bdyn).inverse() * Bdyn.transpose() * Ptp1 * Adyn;
Pinf = Q1 + Adyn.transpose() * Ptp1 * (Adyn - Bdyn * Kinf);
// if Kinf converges, break
if ((Kinf - Ktp1).cwiseAbs().maxCoeff() < 1e-5) {
if (verbose) {
std::cout << "Kinf converged after " << i + 1 << " iterations" << std::endl;
}
break;
}
Ktp1 = Kinf;
Ptp1 = Pinf;
}
// Compute cached matrices
tinyMatrix Quu_inv = (R1 + Bdyn.transpose() * Pinf * Bdyn).inverse();
tinyMatrix AmBKt = (Adyn - Bdyn * Kinf).transpose();
// Precomputation for affine term
tinyVector APf = AmBKt * Pinf * fdyn;
tinyVector BPf = Bdyn.transpose() * Pinf * fdyn;
if (verbose) {
std::cout << "Kinf = " << Kinf.format(TinyApiFmt) << std::endl;
std::cout << "Pinf = " << Pinf.format(TinyApiFmt) << std::endl;
std::cout << "Quu_inv = " << Quu_inv.format(TinyApiFmt) << std::endl;
std::cout << "AmBKt = " << AmBKt.format(TinyApiFmt) << std::endl;
std::cout << "APf = " << APf.format(TinyApiFmt) << std::endl;
std::cout << "BPf = " << BPf.format(TinyApiFmt) << std::endl;
std::cout << "\nPrecomputation finished!\n" << std::endl;
}
cache->rho = rho;
cache->Kinf = Kinf;
cache->Pinf = Pinf;
cache->Quu_inv = Quu_inv;
cache->AmBKt = AmBKt;
cache->C1 = Quu_inv;
cache->C2 = AmBKt;
cache->APf = APf;
cache->BPf = BPf;
return 0; // return success
}
int tiny_solve(TinySolver* solver) {
return solve(solver);
}
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
) {
if (!settings) {
std::cout << "Error in tiny_update_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = abs_pri_tol;
settings->abs_dua_tol = abs_dua_tol;
settings->max_iter = max_iter;
settings->check_termination = check_termination;
settings->en_state_bound = en_state_bound;
settings->en_input_bound = en_input_bound;
settings->en_state_soc = en_state_soc;
settings->en_input_soc = en_input_soc;
settings->en_state_linear = en_state_linear;
settings->en_input_linear = en_input_linear;
return 0;
}
int tiny_set_default_settings(TinySettings* settings) {
if (!settings) {
std::cout << "Error in tiny_set_default_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = TINY_DEFAULT_ABS_PRI_TOL;
settings->abs_dua_tol = TINY_DEFAULT_ABS_DUA_TOL;
settings->max_iter = TINY_DEFAULT_MAX_ITER;
settings->check_termination = TINY_DEFAULT_CHECK_TERMINATION;
// Turn off constraints until they are set by tiny_set_bound_constraints or tiny_set_cone_constraints
settings->en_state_bound = TINY_DEFAULT_EN_STATE_BOUND;
settings->en_input_bound = TINY_DEFAULT_EN_INPUT_BOUND;
settings->en_state_soc = TINY_DEFAULT_EN_STATE_SOC;
settings->en_input_soc = TINY_DEFAULT_EN_INPUT_SOC;
settings->en_state_linear = TINY_DEFAULT_EN_STATE_LINEAR;
settings->en_input_linear = TINY_DEFAULT_EN_INPUT_LINEAR;
// Initialize adaptive rho settings
// NOTE : Adaptive rho currently supports only quadrotor system
settings->adaptive_rho = 0; // Disabled by default
settings->adaptive_rho_min = 1.0;
settings->adaptive_rho_max = 100.0;
settings->adaptive_rho_enable_clipping = 1;
return 0;
}
int tiny_set_x0(TinySolver* solver, tinyVector x0) {
if (!solver) {
std::cout << "Error in tiny_set_x0: solver is nullptr" << std::endl;
return 1;
}
if (x0.rows() != solver->work->nx) {
perror("Error in tiny_set_x0: x0 is not the correct length");
}
solver->work->x.col(0) = x0;
return 0;
}
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref) {
if (!solver) {
std::cout << "Error in tiny_set_x_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"State reference trajectory (x_ref)",
"rows",
x_ref.rows(),
solver->work->nx
);
status |= check_dimension(
"State reference trajectory (x_ref)",
"columns",
x_ref.cols(),
solver->work->N
);
solver->work->Xref = x_ref;
return 0;
}
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref) {
if (!solver) {
std::cout << "Error in tiny_set_u_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"rows",
u_ref.rows(),
solver->work->nu
);
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"columns",
u_ref.cols(),
solver->work->N - 1
);
solver->work->Uref = u_ref;
return 0;
}
void tiny_initialize_sensitivity_matrices(TinySolver* solver) {
int nu = solver->work->nu;
int nx = solver->work->nx;
// Initialize matrices with zeros
solver->cache->dKinf_drho = tinyMatrix::Zero(nu, nx);
solver->cache->dPinf_drho = tinyMatrix::Zero(nx, nx);
solver->cache->dC1_drho = tinyMatrix::Zero(nu, nu);
solver->cache->dC2_drho = tinyMatrix::Zero(nx, nx);
const float dKinf_drho[4][12] = { { 0.0001,
-0.0001,
-0.0025,
0.0003,
0.0007,
0.0050,
0.0001,
-0.0001,
-0.0008,
0.0000,
0.0001,
0.0008 },
{ -0.0001,
-0.0000,
-0.0025,
-0.0001,
-0.0006,
-0.0050,
-0.0001,
0.0000,
-0.0008,
-0.0000,
-0.0001,
-0.0008 },
{ 0.0000,
0.0000,
-0.0025,
0.0001,
0.0004,
0.0050,
0.0000,
0.0000,
-0.0008,
0.0000,
0.0000,
0.0008 },
{ -0.0000,
0.0001,
-0.0025,
-0.0003,
-0.0004,
-0.0050,
-0.0000,
0.0001,
-0.0008,
-0.0000,
-0.0000,
-0.0008 } };
const float dPinf_drho[12][12] = { { 0.0494,
-0.0045,
-0.0000,
0.0110,
0.1300,
-0.0283,
0.0280,
-0.0026,
-0.0000,
0.0004,
0.0070,
-0.0094 },
{ -0.0045,
0.0491,
0.0000,
-0.1320,
-0.0111,
0.0114,
-0.0026,
0.0279,
0.0000,
-0.0076,
-0.0004,
0.0038 },
{ -0.0000,
0.0000,
2.4450,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
1.2593,
0.0000,
0.0000,
0.0000 },
{ 0.0110,
-0.1320,
0.0000,
0.3913,
0.0592,
0.3108,
0.0080,
-0.0776,
0.0000,
0.0254,
0.0068,
0.0750 },
{ 0.1300,
-0.0111,
-0.0000,
0.0592,
0.4420,
0.7771,
0.0797,
-0.0081,
-0.0000,
0.0068,
0.0350,
0.1875 },
{ -0.0283,
0.0114,
-0.0000,
0.3108,
0.7771,
10.0441,
0.0272,
-0.0109,
0.0000,
0.0655,
0.1639,
2.6362 },
{ 0.0280,
-0.0026,
-0.0000,
0.0080,
0.0797,
0.0272,
0.0163,
-0.0016,
-0.0000,
0.0005,
0.0047,
0.0032 },
{ -0.0026,
0.0279,
0.0000,
-0.0776,
-0.0081,
-0.0109,
-0.0016,
0.0161,
0.0000,
-0.0046,
-0.0005,
-0.0013 },
{ -0.0000,
0.0000,
1.2593,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.9232,
0.0000,
0.0000,
0.0000 },
{ 0.0004,
-0.0076,
0.0000,
0.0254,
0.0068,
0.0655,
0.0005,
-0.0046,
0.0000,
0.0022,
0.0017,
0.0244 },
{ 0.0070,
-0.0004,
0.0000,
0.0068,
0.0350,
0.1639,
0.0047,
-0.0005,
0.0000,
0.0017,
0.0054,
0.0610 },
{ -0.0094,
0.0038,
0.0000,
0.0750,
0.1875,
2.6362,
0.0032,
-0.0013,
0.0000,
0.0244,
0.0610,
0.9869 } };
const float dC1_drho[4][4] = { { -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 },
{ -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 } };
const float dC2_drho[12][12] = { { 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0001,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0001,
0.0000,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000 },
{ 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0021,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0006,
0.0000,
-0.0000,
-0.0000 },
{ 0.0002,
-0.0027,
-0.0000,
0.0068,
0.0005,
-0.0005,
0.0001,
-0.0015,
-0.0000,
0.0004,
0.0000,
-0.0001 },
{ 0.0027,
-0.0002,
0.0000,
0.0005,
0.0066,
-0.0011,
0.0015,
-0.0001,
0.0000,
0.0000,
0.0004,
-0.0002 },
{ -0.0001,
0.0001,
0.0000,
-0.0000,
0.0000,
0.0041,
-0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0006 } };
// Map arrays to Eigen matrices
solver->cache->dKinf_drho = Map<const Matrix<float, 4, 12>>(dKinf_drho[0]).cast<tinytype>();
solver->cache->dPinf_drho = Map<const Matrix<float, 12, 12>>(dPinf_drho[0]).cast<tinytype>();
solver->cache->dC1_drho = Map<const Matrix<float, 4, 4>>(dC1_drho[0]).cast<tinytype>();
solver->cache->dC2_drho = Map<const Matrix<float, 12, 12>>(dC2_drho[0]).cast<tinytype>();
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,118 @@
#pragma once
#include "admm.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
);
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
);
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acu,
VectorXi qcu,
tinyVector cu,
VectorXi Acx,
VectorXi qcx,
tinyVector cx
);
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
);
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
void compute_sensitivity_matrices(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
int tiny_update_matrices_with_derivatives(TinyCache* cache, tinytype delta_rho);
int tiny_solve(TinySolver* solver);
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
);
int tiny_set_default_settings(TinySettings* settings);
int tiny_set_x0(TinySolver* solver, tinyVector x0);
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref);
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref);
/**
* Initialize sensitivity matrices for adaptive rho
*
* @param solver Pointer to solver
*/
void tiny_initialize_sensitivity_matrices(TinySolver* solver);
int tiny_setup_state_soc_constraints(
TinySolver* solver,
tinyVector Acx,
tinyVector qcx,
tinyVector cx,
int numStateCones
);
int tiny_setup_input_soc_constraints(
TinySolver* solver,
tinyVector Acu,
tinyVector qcu,
tinyVector cu,
int numInputCones
);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,13 @@
#pragma once
// Default settings
#define TINY_DEFAULT_ABS_PRI_TOL (1e-03)
#define TINY_DEFAULT_ABS_DUA_TOL (1e-03)
#define TINY_DEFAULT_MAX_ITER (1000)
#define TINY_DEFAULT_CHECK_TERMINATION (1)
#define TINY_DEFAULT_EN_STATE_BOUND (1)
#define TINY_DEFAULT_EN_INPUT_BOUND (1)
#define TINY_DEFAULT_EN_STATE_SOC (0)
#define TINY_DEFAULT_EN_INPUT_SOC (0)
#define TINY_DEFAULT_EN_STATE_LINEAR (0)
#define TINY_DEFAULT_EN_INPUT_LINEAR (0)

View File

@@ -0,0 +1,197 @@
#pragma once
#include <Eigen/Eigen>
// #include <Eigen/Core>
// #include <Eigen/LU>
using namespace Eigen;
#ifdef __cplusplus
extern "C" {
#endif
typedef double tinytype; // should be double if you want to generate code
typedef Matrix<tinytype, Dynamic, Dynamic> tinyMatrix;
typedef Matrix<tinytype, Dynamic, 1> tinyVector;
// typedef Matrix<tinytype, NSTATES, 1> tiny_VectorNx;
// typedef Matrix<tinytype, NINPUTS, 1> tiny_VectorNu;
// typedef Matrix<tinytype, NSTATES, NSTATES> tiny_MatrixNxNx;
// typedef Matrix<tinytype, NSTATES, NINPUTS> tiny_MatrixNxNu;
// typedef Matrix<tinytype, NINPUTS, NSTATES> tiny_MatrixNuNx;
// typedef Matrix<tinytype, NINPUTS, NINPUTS> tiny_MatrixNuNu;
// typedef Matrix<tinytype, NSTATES, NHORIZON> tiny_MatrixNxNh; // Nu x Nh
// typedef Matrix<tinytype, NINPUTS, NHORIZON - 1> tiny_MatrixNuNhm1; // Nu x Nh-1
/**
* Solution
*/
typedef struct {
int iter;
int solved;
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
} TinySolution;
/**
* Matrices that must be recomputed with changes in time step, rho
*/
typedef struct {
tinytype rho;
tinyMatrix Kinf; // nu x nx
tinyMatrix Pinf; // nx x nx
tinyMatrix Quu_inv; // nu x nu
tinyMatrix AmBKt; // nx x nx
tinyVector APf; // nx x 1
tinyVector BPf; // nu x 1
tinyMatrix C1; // From adaptive rho
tinyMatrix C2; // From adaptive rho
// Sensitivity matrices for adaptive rho
tinyMatrix dKinf_drho;
tinyMatrix dPinf_drho;
tinyMatrix dC1_drho;
tinyMatrix dC2_drho;
} TinyCache;
/**
* User settings
*/
typedef struct {
tinytype abs_pri_tol;
tinytype abs_dua_tol;
int max_iter;
int check_termination;
int en_state_bound;
int en_input_bound;
int en_state_soc;
int en_input_soc;
int en_state_linear;
int en_input_linear;
// Add adaptive rho parameters
int adaptive_rho; // Enable/disable adaptive rho (1/0)
tinytype adaptive_rho_min; // Minimum value for rho
tinytype adaptive_rho_max; // Maximum value for rho
int adaptive_rho_enable_clipping; // Enable/disable clipping of rho (1/0)
} TinySettings;
/**
* Problem variables
*/
typedef struct {
int nx; // Number of states
int nu; // Number of control inputs
int N; // Number of knotpoints in the horizon
// State and input
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
// Linear control cost terms
tinyMatrix q; // nx x N
tinyMatrix r; // nu x N-1
// Linear Riccati backward pass terms
tinyMatrix p; // nx x N
tinyMatrix d; // nu x N-1
// Bound constraint variables
// Slack variables
tinyMatrix v; // nx x N
tinyMatrix vnew; // nx x N
tinyMatrix z; // nu x N-1
tinyMatrix znew; // nu x N-1
// Dual variables
tinyMatrix g; // nx x N
tinyMatrix y; // nu x N-1
// State and input bounds
tinyMatrix x_min; // nx x N
tinyMatrix x_max; // nx x N
tinyMatrix u_min; // nu x N-1
tinyMatrix u_max; // nu x N-1
// Cone constraint variables
// Variables to keep track of general cone information
int numStateCones; // Number of cone constraints on states at each time step
int numInputCones; // Number of cone constraints on inputs at each time step
tinyVector cx; // One coefficient for each state cone
tinyVector cu; // One coefficient for each input cone
VectorXi Acx; // Start indices for each state cone
VectorXi Acu; // Start indices for each input cone
VectorXi qcx; // Dimension for each state cone
VectorXi qcu; // Dimension for each input cone
// Slack variables
tinyMatrix vc; // nx x N
tinyMatrix vcnew; // nx x N
tinyMatrix zc; // nu x N-1
tinyMatrix zcnew; // nu x N-1
// Dual variables
tinyMatrix gc; // nx x N
tinyMatrix yc; // nu x N-1
// Linear constraint variables
// Variables to keep track of general linear constraint information
int numStateLinear; // Number of linear constraints on states at each time step
int numInputLinear; // Number of linear constraints on inputs at each time step
// Constraint matrices and vectors
tinyMatrix Alin_x; // Normal vectors for state linear constraints (numStateLinear x nx)
tinyVector blin_x; // Offset values for state linear constraints (numStateLinear x 1)
tinyMatrix Alin_u; // Normal vectors for input linear constraints (numInputLinear x nu)
tinyVector blin_u; // Offset values for input linear constraints (numInputLinear x 1)
// Slack variables for linear constraints
tinyMatrix vl; // nx x N
tinyMatrix vlnew; // nx x N
tinyMatrix zl; // nu x N-1
tinyMatrix zlnew; // nu x N-1
// Dual variables for linear constraints
tinyMatrix gl; // nx x N
tinyMatrix yl; // nu x N-1
// Q, R, A, B, f given by user
tinyVector Q; // nx x 1
tinyVector R; // nu x 1
tinyMatrix Adyn; // nx x nx (state transition matrix)
tinyMatrix Bdyn; // nx x nu (control matrix)
tinyVector fdyn; // nx x 1 (affine vector)
// Reference trajectory to track for one horizon
tinyMatrix Xref; // nx x N
tinyMatrix Uref; // nu x N-1
// Temporaries
tinyVector Qu; // nu x 1
// Variables for keeping track of solve status
tinytype primal_residual_state;
tinytype primal_residual_input;
tinytype dual_residual_state;
tinytype dual_residual_input;
int status;
int iter;
} TinyWorkspace;
/**
* Main TinyMPC solver structure that holds all information.
*/
typedef struct {
TinySolution* solution; // Solution
TinySettings* settings; // Problem settings
TinyCache* cache; // Problem cache
TinyWorkspace* work; // Solver workspace
} TinySolver;
// Add at the top with other definitions
#define BENCH_NX 12
#define BENCH_NU 4
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,111 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <vector>
namespace wust_vision {
namespace auto_aim {
template<typename T>
concept HasStaticLerp = requires(const T& a, const T& b, double t) {
{
T::lerp(a, b, t)
} -> std::same_as<T>;
};
template<HasStaticLerp PointT>
class Trajectory {
public:
void reserve(size_t n) {
cp_vec.reserve(n);
dt_vec.reserve(n > 0 ? n - 1 : 0);
prefix_time.reserve(n);
}
void clear() {
cp_vec.clear();
dt_vec.clear();
prefix_time.clear();
total_duration_ = 0.0;
}
void push_back(const PointT& p, double dt = 0.0) {
if (cp_vec.empty()) {
cp_vec.push_back(p);
prefix_time.push_back(0.0);
total_duration_ = 0.0;
return;
}
assert(dt >= 0.0);
cp_vec.push_back(p);
dt_vec.push_back(dt);
total_duration_ += dt;
prefix_time.push_back(total_duration_);
}
void set(const std::vector<PointT>& c, const std::vector<double>& t) {
assert(!c.empty());
assert(c.size() == t.size() + 1);
cp_vec = c;
dt_vec = t;
prefix_time.resize(cp_vec.size());
prefix_time[0] = 0.0;
for (size_t i = 0; i < dt_vec.size(); ++i)
prefix_time[i + 1] = prefix_time[i] + dt_vec[i];
total_duration_ = prefix_time.back();
}
double getPrefixTimeAtIdx(int i) const {
return prefix_time[i];
}
PointT getStateAtIdx(int i) const {
return cp_vec[i];
}
PointT getStateAtTime(double t) const {
if (cp_vec.empty())
return PointT {};
if (t <= 0.0)
return cp_vec.front();
if (t >= total_duration_)
return cp_vec.back();
auto it = std::lower_bound(prefix_time.begin(), prefix_time.end(), t);
size_t i1 = std::distance(prefix_time.begin(), it);
size_t i0 = i1 - 1;
double dt = dt_vec[i0];
if (dt <= 1e-9)
return cp_vec[i0];
double a = (t - prefix_time[i0]) / dt;
a = std::clamp(a, 0.0, 1.0);
return PointT::lerp(cp_vec[i0], cp_vec[i1], a);
}
double getTotalDuration() const {
return total_duration_;
}
size_t size() const {
return cp_vec.size();
}
std::vector<PointT> cp_vec;
std::vector<double> dt_vec;
std::vector<double> prefix_time;
double total_duration_ { 0.0 };
};
} // namespace auto_aim
} // namespace wust_vision

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
#pragma once
#include <memory>
namespace wust_vl::common::utils {
class Parameter;
}
using wust_vlParamterPtr = std::shared_ptr<wust_vl::common::utils::Parameter>;
namespace wust_vision {
struct GimbalCmd;
}
namespace wust_vision::auto_aim {
enum class AutoAimFsm;
class Target;
class VeryAimer {
public:
using Ptr = std::shared_ptr<VeryAimer>;
VeryAimer(wust_vlParamterPtr auto_aim_config_parameter);
static Ptr create(wust_vlParamterPtr auto_aim_config_parameter) {
return std::make_shared<VeryAimer>(auto_aim_config_parameter);
};
~VeryAimer();
[[nodiscard]] GimbalCmd
veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm);
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,70 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_aim {
struct LightParams {
// width / height
double min_ratio;
double max_ratio;
// vertical angle
double max_angle;
// judge color
int color_diff_thresh;
double max_angle_diff;
int binary_thres;
void load(const YAML::Node& config) {
binary_thres = config["binary_thres"].as<int>();
min_ratio = config["min_ratio"].as<double>();
max_ratio = config["max_ratio"].as<double>();
max_angle = config["max_angle"].as<double>();
max_angle_diff = config["max_angle_diff"].as<double>();
color_diff_thresh = config["color_diff_thresh"].as<int>();
}
};
struct ArmorParams {
double min_light_ratio;
// light pairs distance
double min_small_center_distance;
double max_small_center_distance;
double min_large_center_distance;
double max_large_center_distance;
// horizontal angle
double max_angle;
void load(const YAML::Node& config) {
min_light_ratio = config["min_light_ratio"].as<double>();
min_small_center_distance = config["min_small_center_distance"].as<double>();
max_small_center_distance = config["max_small_center_distance"].as<double>();
min_large_center_distance = config["min_large_center_distance"].as<double>();
max_large_center_distance = config["max_large_center_distance"].as<double>();
max_angle = config["max_angle"].as<double>();
}
};
class ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorBase>;
virtual ~ArmorDetectorBase() = default;
virtual void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) = 0;
using DetectorCallback =
std::function<void(const std::vector<ArmorObject>&, const CommonFrame&)>;
virtual void setCallback(DetectorCallback cb) = 0;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,473 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
#include "tasks/auto_aim/armor_detect/number_classifier/factory.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/utils/timer.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorCommon::Impl {
public:
Impl(const YAML::Node& config) {
params_.load(config);
number_classifier_ = NumberClassifierFactory::createNumberClassifier(
params_.classify_backend,
params_.classify_model_path,
params_.classify_label_path
);
}
bool extractNetImage(const cv::Mat& src, ArmorObject& armor) const noexcept {
constexpr int light_length = 12;
constexpr int warp_height = 28;
constexpr int small_armor_width = 32;
constexpr int large_armor_width = 54;
const cv::Size roi_size(20, 28);
if (src.empty() || src.cols < 10 || src.rows < 10) {
std::cerr << "[extractNetImage] input src is empty or too small!" << std::endl;
return false;
}
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f& p0 = ordered[0];
const cv::Point2f& p1 = ordered[1];
const cv::Point2f& p2 = ordered[2];
const cv::Point2f& p3 = ordered[3];
const float l1_len = cv::norm(p1 - p0);
const float l2_len = cv::norm(p2 - p3);
const cv::Point2f c1 = (p0 + p1) * 0.5f;
const cv::Point2f c2 = (p2 + p3) * 0.5f;
const float avg_light_len = 0.5f * (l1_len + l2_len);
const float center_dist =
avg_light_len > 1e-3f ? cv::norm(c1 - c2) / avg_light_len : 0.f;
const bool is_large = center_dist > params_.armor_params.min_large_center_distance;
const cv::Rect bbox = cv::boundingRect(armor.pts);
if (bbox.width <= 0 || bbox.height <= 0)
return false;
if (bbox.width > src.cols || bbox.height > src.rows)
return false;
const int dw = static_cast<int>(bbox.width * (params_.expand_ratio_w - 1.f));
const int dh = static_cast<int>(bbox.height * (params_.expand_ratio_h - 1.f));
int new_x = bbox.x - (dw >> 1);
int new_y = bbox.y - (dh >> 1);
new_x = std::max(new_x, 0);
new_y = std::max(new_y, 0);
int new_w = std::min(bbox.width + dw, src.cols - new_x);
int new_h = std::min(bbox.height + dh, src.rows - new_y);
if (new_w <= 0 || new_h <= 0)
return false;
const cv::Rect expanded_rect(new_x, new_y, new_w, new_h);
cv::Mat litroi_color = src(expanded_rect);
if (litroi_color.empty())
return false;
cv::Mat litroi_gray;
try {
cv::cvtColor(litroi_color, litroi_gray, cv::COLOR_BGR2GRAY);
} catch (...) {
return false;
}
armor.whole_gray_img = litroi_gray;
if (params_.enable_cv) {
cv::Mat litroi_binary;
try {
cv::threshold(
litroi_gray,
litroi_binary,
params_.light_params.binary_thres,
255,
cv::THRESH_BINARY
);
armor.whole_binary_img = litroi_binary;
} catch (...) {
return false;
}
}
const cv::Point2f offset(static_cast<float>(new_x), static_cast<float>(new_y));
if (params_.enable_classify) {
cv::Point2f src_vertices[4] = { armor.pts[1] - offset,
armor.pts[0] - offset,
armor.pts[3] - offset,
armor.pts[2] - offset };
const int warp_width = is_large ? large_armor_width : small_armor_width;
const int top_light_y = (warp_height - light_length) / 2 - 1;
const int bottom_light_y = top_light_y + light_length;
if (warp_width <= 0 || warp_height <= 0)
return false;
cv::Point2f dst_vertices[4] = {
{ 0.f, static_cast<float>(bottom_light_y) },
{ 0.f, static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_light_y) }
};
const cv::Mat warp_mat = cv::getPerspectiveTransform(src_vertices, dst_vertices);
cv::Mat number_image;
cv::warpPerspective(
litroi_gray,
number_image,
warp_mat,
cv::Size(warp_width, warp_height),
cv::INTER_LINEAR,
cv::BORDER_CONSTANT,
0
);
const int roi_x = (warp_width - roi_size.width) >> 1;
const cv::Rect num_roi(roi_x, 0, roi_size.width, roi_size.height);
if ((num_roi & cv::Rect(0, 0, warp_width, warp_height)) != num_roi)
return false;
cv::Mat num_crop = number_image(num_roi);
cv::threshold(
num_crop,
armor.number_img,
0,
255,
cv::THRESH_BINARY | cv::THRESH_OTSU
);
}
armor.whole_rgb_img = litroi_color;
armor.local_offset = offset;
return true;
}
bool refineLightsFromArmorPts(ArmorObject& armor) const noexcept {
armor.center = (armor.pts[0] + armor.pts[1] + armor.pts[2] + armor.pts[3]) * 0.25f;
const int n_lights = static_cast<int>(armor.lights.size());
if (n_lights < 2)
return false;
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f ref_centers[2] = { (ordered[0] + ordered[1]) * 0.5f,
(ordered[2] + ordered[3]) * 0.5f };
int best0 = -1, best1 = -1;
float best0_d2 = std::numeric_limits<float>::max();
float best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
const cv::Point2f& c = armor.lights[i].center;
const cv::Point2f d0 = c - ref_centers[0];
const float dist0 = d0.dot(d0);
if (dist0 < best0_d2) {
best0_d2 = dist0;
best0 = i;
}
const cv::Point2f d1 = c - ref_centers[1];
const float dist1 = d1.dot(d1);
if (dist1 < best1_d2) {
best1_d2 = dist1;
best1 = i;
}
}
if (best0 == best1) {
best1 = -1;
best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
if (i == best0)
continue;
const cv::Point2f d = armor.lights[i].center - ref_centers[1];
const float dist = d.dot(d);
if (dist < best1_d2) {
best1_d2 = dist;
best1 = i;
}
}
}
if (best0 < 0 || best1 < 0)
return false;
const auto& l0 = armor.lights[best0];
const auto& l1 = armor.lights[best1];
if (l0.center.x < l1.center.x) {
armor.lights[0] = l0;
armor.lights[1] = l1;
} else {
armor.lights[0] = l1;
armor.lights[1] = l0;
}
return true;
}
std::vector<Light>
findLights(const cv::Mat& color_img, const cv::Mat& binary_img, ArmorObject& armor)
const noexcept {
std::vector<std::vector<cv::Point>> contours;
contours.reserve(64);
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
std::vector<Light> all_lights;
all_lights.reserve(contours.size());
for (const auto& contour: contours) {
const int n = static_cast<int>(contour.size());
if (n < 6)
continue;
Light light(contour);
if (!isLight(light))
continue;
int sum_r = 0;
int sum_b = 0;
for (const auto& pt: contour) {
const cv::Vec3b* row = color_img.ptr<cv::Vec3b>(pt.y);
const cv::Vec3b& pix = row[pt.x];
sum_r += pix[0];
sum_b += pix[2];
}
const int avg_diff = std::abs(sum_r - sum_b) / n;
if (avg_diff <= params_.light_params.color_diff_thresh)
continue;
light.color = (sum_r > sum_b) ? 0 : 1; // 0=红, 1=蓝
all_lights.emplace_back(std::move(light));
}
std::sort(all_lights.begin(), all_lights.end(), [](const Light& a, const Light& b) {
return a.center.x < b.center.x;
});
armor.lights = all_lights;
return all_lights;
}
bool isLight(const Light& light) const noexcept {
// width / length 比例
const float ratio = light.width / light.length;
if (ratio <= params_.light_params.min_ratio || ratio >= params_.light_params.max_ratio)
return false;
if (light.tilt_angle >= params_.light_params.max_angle)
return false;
return true;
}
bool isArmor(const Light& l1, const Light& l2) const noexcept {
const float len1 = l1.length;
const float len2 = l2.length;
if (len1 <= 1e-3f || len2 <= 1e-3f)
return false;
const float min_len = (len1 < len2) ? len1 : len2;
const float max_len = (len1 < len2) ? len2 : len1;
if (min_len / max_len <= params_.armor_params.min_light_ratio)
return false;
const cv::Point2f d = l1.center - l2.center;
const float dist2 = d.dot(d);
const float avg_len = 0.5f * (len1 + len2);
const float min_small = params_.armor_params.min_small_center_distance * avg_len;
const float max_small = params_.armor_params.max_small_center_distance * avg_len;
const float min_large = params_.armor_params.min_large_center_distance * avg_len;
const float max_large = params_.armor_params.max_large_center_distance * avg_len;
const float min_small2 = min_small * min_small;
const float max_small2 = max_small * max_small;
const float min_large2 = min_large * min_large;
const float max_large2 = max_large * max_large;
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
if (!(small_ok || large_ok))
return false;
static const float tan_max_angle =
std::tan(params_.armor_params.max_angle * CV_PI / 180.0f);
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
return false;
if (l1.color != l2.color)
return false;
return true;
}
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) const noexcept {
std::vector<ArmorObject> armors;
if (!src_img.data || src_img.empty()) {
std::cout << "img data nullptr or empty" << std::endl;
return armors;
}
if (objs_result.empty()) {
return armors;
}
for (auto& armor_in: objs_result) {
ArmorObject armor = armor_in;
if ((detect_color == 0 && armor.color == ArmorColor::BLUE)
|| (detect_color == 1 && armor.color == ArmorColor::RED))
{
continue;
}
if (params_.enable_classify || params_.enable_cv) {
bool ok = false;
ok = extractNetImage(src_img, armor);
if (!ok)
continue;
}
if (params_.enable_classify) {
number_classifier_->classifyNumber(armor);
if (armor.confidence < params_.classifier_threshold)
continue;
}
if (target_number.has_value()) {
if (!isSameTarget(target_number.value(), armor.number)) {
continue;
}
}
if (armor.color == ArmorColor::NONE || armor.color == ArmorColor::PURPLE) {
armor.is_ok = false;
armor.transform(transform_matrix);
armors.emplace_back(armor);
continue;
}
if (params_.enable_cv) {
findLights(armor.whole_rgb_img, armor.whole_binary_img, armor);
if (refineLightsFromArmorPts(armor)) {
if (isArmor(armor.lights[0], armor.lights[1])) {
armor.is_ok = true;
for (auto& light: armor.lights) {
light.addOffset(armor.local_offset);
}
}
}
if (armor.is_ok) {
armor.is_ok = armor.checkOkptsRight(params_.max_pts_error);
}
}
if (!armor.is_ok) {
auto ordered = armor.sortCorners(armor.pts);
Light l1, l2;
l1.length = cv::norm(ordered[1] - ordered[0]);
l1.center = (ordered[0] + ordered[1]) / 2.0;
l2.length = cv::norm(ordered[2] - ordered[3]);
l2.center = (ordered[2] + ordered[3]) / 2.0;
if (!isArmor(l1, l2)) {
continue;
}
}
armor.transform(transform_matrix);
armors.emplace_back(armor);
}
return armors;
}
std::unique_ptr<NumberClassifierBase> number_classifier_;
struct ArmorDetectCommonParams {
std::string classify_backend = "opencv";
std::string classify_model_path;
std::string classify_label_path;
double classifier_threshold = 0.5;
LightParams light_params;
ArmorParams armor_params;
float expand_ratio_w = 1.1f;
float expand_ratio_h = 1.1f;
double max_pts_error = 20.0;
bool enable_cv = false;
bool enable_classify = true;
void load(const YAML::Node& config) {
expand_ratio_w = config["cv"]["light"]["expand_ratio_w"].as<float>(1.1);
expand_ratio_h = config["cv"]["light"]["expand_ratio_h"].as<float>(1.1);
max_pts_error = config["cv"]["light"]["max_pts_error"].as<double>(20.0);
enable_cv = config["cv"]["enable"].as<bool>();
light_params.load(config["cv"]["light"]);
armor_params.load(config["cv"]["armor"]);
enable_classify = config["classify"]["enable"].as<bool>();
classify_model_path =
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
classify_label_path =
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
classify_backend = config["classify"]["backend"].as<std::string>();
classifier_threshold = config["classify"]["threshold"].as<double>();
}
} params_;
};
ArmorDetectorCommon::ArmorDetectorCommon(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
ArmorDetectorCommon::~ArmorDetectorCommon() {
_impl.reset();
}
std::vector<ArmorObject> ArmorDetectorCommon::detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) {
return _impl
->detectNet(src_img, objs_result, transform_matrix, detect_color, target_number);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,41 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorCommon {
public:
using Ptr = std::unique_ptr<ArmorDetectorCommon>;
ArmorDetectorCommon(const YAML::Node& config);
static Ptr create(const YAML::Node& config) {
return std::make_unique<ArmorDetectorCommon>(config);
}
~ArmorDetectorCommon();
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number = std::nullopt
);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,134 @@
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "armor_detector_base.hpp"
#include "tasks/utils/config.hpp"
#include <string>
#include <yaml-cpp/yaml.h>
#ifdef USE_OPENVINO
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
#endif
#ifdef USE_TRT
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
#endif
#ifdef USE_NCNN
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
#endif
#ifdef USE_ORT
#include "tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp"
#endif
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
namespace wust_vision {
namespace auto_aim {
class DetectorFactory {
public:
static ArmorDetectorBase::Ptr createArmorDetector(
const std::string& backend,
bool use_armor_detect_common,
std::string cv_config_path = OPENCV_CONFIG,
std::string ml_config_path = ML_CONFIG
) {
// 检查编译时是否支持
auto isBackendEnabled = [&backend]() -> bool {
#ifdef USE_OPENVINO
if (backend == "openvino")
return true;
#endif
#ifdef USE_TRT
if (backend == "tensorrt")
return true;
#endif
#ifdef USE_NCNN
if (backend == "ncnn")
return true;
#endif
#ifdef USE_ORT
if (backend == "onnxruntime")
return true;
#endif
if (backend == "opencv")
return true;
return false;
};
if (!isBackendEnabled()) {
std::cout << "Backend " << backend << " is not enabled at compile time."
<< std::endl;
throw std::runtime_error("Backend " + backend + " is not enabled at compile time.");
}
auto getConfigPath = [&](const std::string& backend) -> std::string {
if (backend == "opencv")
return cv_config_path;
else
return ml_config_path;
};
std::string config_path = getConfigPath(backend);
if (config_path.empty()) {
std::cout << "No config path for backend: " << backend << std::endl;
throw std::runtime_error("No config path for backend: " + backend);
}
YAML::Node armor_detect_config = YAML::LoadFile(config_path);
// 创建对应后端实例
#if defined(USE_OPENVINO)
if (backend == "openvino") {
return ArmorDetectorOpenVino::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_TRT)
if (backend == "tensorrt") {
return ArmorDetectorTrt::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_NCNN)
if (backend == "ncnn") {
return ArmorDetectorNCNN::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_ORT)
if (backend == "onnxruntime") {
return ArmorDetectorOnnxRuntime::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
if (backend == "opencv") {
return ArmorDetectorOpenCV::create(armor_detect_config["armor_detector"]);
}
std::cout << "Unsupported armor detector backend (or not compiled): " << backend
<< std::endl;
throw std::runtime_error(
"Unsupported armor detector backend (or not compiled): " + backend
);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,198 @@
#include "armor_infer.hpp"
namespace wust_vision::auto_aim::armor_infer {
struct GridAndStride {
int grid0;
int grid1;
int stride;
};
[[nodiscard]] static inline std::vector<GridAndStride>
generate_grids_and_stride(int target_w, int target_h, const std::vector<int>& strides) noexcept {
std::vector<GridAndStride> grid_strides;
for (int stride: strides) {
const int num_w = target_w / stride;
const int num_h = target_h / stride;
grid_strides.reserve(grid_strides.size() + num_w * num_h);
for (int gy = 0; gy < num_h; ++gy) {
for (int gx = 0; gx < num_w; ++gx) {
grid_strides.push_back(GridAndStride { gx, gy, stride });
}
}
}
return grid_strides;
}
std::vector<ArmorObject> ArmorInfer::postProcessTUP_impl(const cv::Mat& out) const {
static std::optional<std::vector<GridAndStride>> _grid_strides;
if (!_grid_strides) {
_grid_strides = generate_grids_and_stride(inputW(), inputH(), { 8, 16, 32 });
}
const auto& grid_strides = _grid_strides.value();
std::vector<ArmorObject> out_objs;
const int num_anchors =
static_cast<int>(std::min<size_t>(grid_strides.size(), static_cast<size_t>(out.rows)));
for (int a = 0; a < num_anchors; ++a) {
const float confidence = out.at<float>(a, 8);
if (confidence < conf_threshold_)
continue;
const auto& gs = grid_strides[a];
const int gx = gs.grid0, gy = gs.grid1, stride = gs.stride;
// color & class
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::TUP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::TUP>::NUM_CLASSES;
cv::Mat color_scores = out.row(a).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(a).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color, max_class;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class, nullptr, &class_id);
const float x1 = (out.at<float>(a, 0) + gx) * stride;
const float y1 = (out.at<float>(a, 1) + gy) * stride;
const float x2 = (out.at<float>(a, 2) + gx) * stride;
const float y2 = (out.at<float>(a, 3) + gy) * stride;
const float x3 = (out.at<float>(a, 4) + gx) * stride;
const float y3 = (out.at<float>(a, 5) + gy) * stride;
const float x4 = (out.at<float>(a, 6) + gx) * stride;
const float y4 = (out.at<float>(a, 7) + gy) * stride;
ArmorObject obj;
obj.pts = { cv::Point2f(x1, y1),
cv::Point2f(x2, y2),
cv::Point2f(x3, y3),
cv::Point2f(x4, y4) };
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessRP_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
const int rows = out.rows;
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::RP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::RP>::NUM_CLASSES;
for (int r = 0; r < rows; ++r) {
float conf_raw = out.at<float>(r, 8);
const float confidence = static_cast<float>(sigmoid(conf_raw));
if (confidence < conf_threshold_)
continue;
cv::Mat color_scores = out.row(r).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(r).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color_score, max_class_score;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color_score, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class_score, nullptr, &class_id);
ArmorObject obj;
obj.pts.resize(4);
for (int k = 0; k < 4; ++k) {
const float x = out.at<float>(r, 0 + k * 2);
const float y = out.at<float>(r, 1 + k * 2);
obj.pts[k] = cv::Point2f(x, y);
}
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessAT_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
constexpr int nkpt = ModelTraits<Mode::AT>::NUM_KPTS;
constexpr int nk = nkpt * 2; // keypoints flattened
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::AT>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
obj.pts.reserve(nkpt);
for (int k = 0; k < nkpt; ++k) {
float kx = row[6 + 2 * k];
float ky = row[6 + 2 * k + 1];
obj.pts.emplace_back(kx, ky);
}
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
std::vector<ArmorObject> ArmorInfer::postProcessBOX_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::BOX>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
std::vector<cv::Point2f> pts;
pts.resize(4);
pts[0] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y + obj.box.height); // 右下
pts[1] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y); // 右上
pts[2] = cv::Point2f(obj.box.x, obj.box.y); // 左上
pts[3] = cv::Point2f(obj.box.x, obj.box.y + obj.box.height); // 左下
obj.pts = std::move(pts);
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
} // namespace wust_vision::auto_aim::armor_infer

View File

@@ -0,0 +1,324 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision::auto_aim::armor_infer {
static constexpr float MERGE_CONF_ERROR = 0.15f;
static constexpr float MERGE_MIN_IOU = 0.9f;
enum class Mode { TUP, RP, AT, BOX416, BOX320, BOX };
inline Mode modeFromString(const std::string& m) {
if (m == "tup" || m == "TUP")
return Mode::TUP;
if (m == "rp" || m == "RP")
return Mode::RP;
if (m == "at" || m == "AT")
return Mode::AT;
if (m == "box416" || m == "BOX416")
return Mode::BOX416;
if (m == "box320" || m == "BOX320")
return Mode::BOX320;
return Mode::TUP;
}
// ------------------------- model traits -------------------------
template<Mode M>
struct ModelTraits; // declare
// TUP
template<>
struct ModelTraits<Mode::TUP> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
static constexpr int NUM_CLASSES = 8;
static constexpr int NUM_COLORS = 4;
static constexpr bool USE_NORM = false;
static constexpr bool INPUT_RGB = true;
};
// RP
template<>
struct ModelTraits<Mode::RP> {
static constexpr int INPUT_W = 640;
static constexpr int INPUT_H = 640;
static constexpr int NUM_CLASSES = 9;
static constexpr int NUM_COLORS = 4;
static constexpr bool USE_NORM = true;
static constexpr bool INPUT_RGB = false;
};
template<>
struct ModelTraits<Mode::AT> {
static constexpr int INPUT_W = 640;
static constexpr int INPUT_H = 640;
static constexpr int NUM_KPTS = 4;
static constexpr bool USE_NORM = true;
static constexpr bool INPUT_RGB = false;
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 64> CLASSES = { {
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
} };
};
template<>
struct ModelTraits<Mode::BOX416> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
};
template<>
struct ModelTraits<Mode::BOX320> {
static constexpr int INPUT_W = 320;
static constexpr int INPUT_H = 320;
};
template<>
struct ModelTraits<Mode::BOX> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
static constexpr bool INPUT_RGB = false;
static constexpr bool USE_NORM = true;
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 12> CLASSES = { {
{ ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 },
{ ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 },
{ ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::SENTRY },
{ ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 },
{ ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 },
{ ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::SENTRY },
} };
};
[[nodiscard]] inline double sigmoid(double x) noexcept {
return x >= 0 ? 1.0 / (1.0 + std::exp(-x)) : std::exp(x) / (1.0 + std::exp(x));
}
[[nodiscard]] inline float rectIoU(const cv::Rect2f& a, const cv::Rect2f& b) noexcept {
const cv::Rect2f inter = a & b;
const float inter_area = inter.area();
const float union_area = a.area() + b.area() - inter_area;
if (union_area <= 0.f || std::isnan(union_area))
return 0.f;
return inter_area / union_area;
}
// Merge / NMS helpers that mimic original intent but clearer
inline void nms_merge_sorted_bboxes(
std::vector<ArmorObject>& objs,
std::vector<int>& out_indices,
float nms_threshold
) {
out_indices.clear();
const size_t n = objs.size();
std::vector<float> areas(n);
for (size_t i = 0; i < n; ++i)
areas[i] = objs[i].box.area();
for (size_t i = 0; i < n; ++i) {
ArmorObject& a = objs[i];
bool keep = true;
for (int idx: out_indices) {
ArmorObject& b = objs[idx];
const float iou = rectIoU(a.box, b.box);
if (std::isnan(iou) || iou > nms_threshold) {
keep = false;
if (a.number == b.number && a.color == b.color && iou > MERGE_MIN_IOU
&& std::abs(a.confidence - b.confidence) < MERGE_CONF_ERROR)
{
// accumulate points for later averaging
for (const auto& pt: a.pts)
b.pts.push_back(pt);
}
break;
}
}
if (keep)
out_indices.push_back(static_cast<int>(i));
}
}
inline std::vector<ArmorObject>
topKAndNms(std::vector<ArmorObject>& objs, int top_k, float nms_threshold) {
std::sort(objs.begin(), objs.end(), [](const ArmorObject& a, const ArmorObject& b) {
return a.confidence > b.confidence;
});
if (static_cast<int>(objs.size()) > top_k)
objs.resize(static_cast<size_t>(top_k));
std::vector<int> indices;
nms_merge_sorted_bboxes(objs, indices, nms_threshold);
std::vector<ArmorObject> result;
result.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
result.push_back(std::move(objs[indices[i]]));
// average merged extra points if any
auto& ro = result.back();
if (ro.pts.size() >= 8) {
const size_t npts = ro.pts.size();
std::array<cv::Point2f, 4> accum { { { 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 } } };
for (size_t j = 0; j < npts; ++j)
accum[j % 4] += ro.pts[j];
ro.pts.resize(4);
for (int k = 0; k < 4; ++k) {
float denom = static_cast<float>(npts) / 4.0f;
ro.pts[k].x = accum[k].x / denom;
ro.pts[k].y = accum[k].y / denom;
}
}
}
return result;
}
// However providing full modern unified class below that delegates using templated helpers.
class ArmorInfer {
public:
ArmorInfer(
Mode mode = Mode::TUP,
float conf_threshold = 0.25f,
float nms_threshold = 0.45f,
int top_k = 100
) noexcept:
mode_(mode),
conf_threshold_(conf_threshold),
nms_threshold_(nms_threshold),
top_k_(top_k) {
setMode(mode_);
}
void setMode(Mode m) noexcept {
mode_ = m;
switch (mode_) {
case Mode::TUP: {
input_w_ = ModelTraits<Mode::TUP>::INPUT_W;
input_h_ = ModelTraits<Mode::TUP>::INPUT_H;
use_norm_ = ModelTraits<Mode::TUP>::USE_NORM;
input_rgb_ = ModelTraits<Mode::TUP>::INPUT_RGB;
break;
}
case Mode::RP: {
input_w_ = ModelTraits<Mode::RP>::INPUT_W;
input_h_ = ModelTraits<Mode::RP>::INPUT_H;
use_norm_ = ModelTraits<Mode::RP>::USE_NORM;
input_rgb_ = ModelTraits<Mode::RP>::INPUT_RGB;
break;
}
case Mode::AT: {
input_w_ = ModelTraits<Mode::AT>::INPUT_W;
input_h_ = ModelTraits<Mode::AT>::INPUT_H;
use_norm_ = ModelTraits<Mode::AT>::USE_NORM;
input_rgb_ = ModelTraits<Mode::AT>::INPUT_RGB;
break;
}
case Mode::BOX416: {
input_w_ = ModelTraits<Mode::BOX416>::INPUT_W;
input_h_ = ModelTraits<Mode::BOX416>::INPUT_H;
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
break;
}
case Mode::BOX320: {
input_w_ = ModelTraits<Mode::BOX320>::INPUT_W;
input_h_ = ModelTraits<Mode::BOX320>::INPUT_H;
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
break;
} break;
}
}
void setConfThreshold(float t) noexcept {
conf_threshold_ = t;
}
void setNmsThreshold(float t) noexcept {
nms_threshold_ = t;
}
void setTopK(int k) noexcept {
top_k_ = k;
}
int inputW() const noexcept {
return input_w_;
}
int inputH() const noexcept {
return input_h_;
}
bool useNorm() const noexcept {
return use_norm_;
}
bool inputRGB() const noexcept {
return input_rgb_;
}
// main dispatching entry (keeps original signature)
[[nodiscard]] std::vector<ArmorObject> postProcess(const cv::Mat& output_buffer) const {
switch (mode_) {
case Mode::TUP:
return postProcessTUP_impl(output_buffer);
case Mode::RP:
return postProcessRP_impl(output_buffer);
case Mode::AT:
return postProcessAT_impl(output_buffer);
case Mode::BOX416:
return postProcessBOX_impl(output_buffer);
case Mode::BOX320:
return postProcessBOX_impl(output_buffer);
}
return {};
}
private:
std::vector<ArmorObject> postProcessTUP_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessRP_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessAT_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessBOX_impl(const cv::Mat& out) const;
private:
Mode mode_;
int input_w_ { 0 };
int input_h_ { 0 };
float conf_threshold_ { 0.25f };
float nms_threshold_ { 0.45f };
int top_k_ { 100 };
bool use_norm_ { false };
bool input_rgb_ { false };
};
} // namespace wust_vision::auto_aim::armor_infer

View File

@@ -0,0 +1,222 @@
#ifdef USE_NCNN
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "wust_vl/ml_net/ncnn/ncnn_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorNCNN::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = ArmorDetectorCommon::create(config);
}
std::string model_type = config["ncnn"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["ncnn"]["conf_threshold"].as<float>();
int top_k = config["ncnn"]["top_k"].as<int>();
float nms_threshold = config["ncnn"]["nms_threshold"].as<float>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string model_path_param =
utils::expandEnv(config["ncnn"]["model_path_param"].as<std::string>());
std::string model_path_bin =
utils::expandEnv(config["ncnn"]["model_path_bin"].as<std::string>());
bool use_gpu = config["ncnn"]["use_gpu"].as<bool>();
int cpu_threads = config["ncnn"]["cpu_threads"].as<int>();
bool use_lightmode = config["ncnn"]["use_lightmode"].as<bool>();
auto input_name = config["ncnn"]["input_name"].as<std::string>();
auto output_name = config["ncnn"]["output_name"].as<std::string>();
int device_id = config["ncnn"]["device_id"].as<int>();
wust_vl::ml_net::NCNNNet::Params params;
params.model_path_param = model_path_param;
params.model_path_bin = model_path_bin;
params.input_name = input_name;
params.output_name = output_name;
params.use_vulkan = use_gpu;
params.device_id = device_id;
params.use_light_mode = use_lightmode;
params.cpu_threads = cpu_threads;
ncnn_net_ = std::make_unique<wust_vl::ml_net::NCNNNet>();
ncnn_net_->init(params);
}
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
}
~Impl() {
armor_detect_common_.reset();
ncnn_net_.reset();
}
cv::Mat ncnnMatToCvMat(const ncnn::Mat& m) {
cv::Mat img(m.h, m.w, CV_8UC3);
m.to_pixels(img.data, ncnn::Mat::PIXEL_RGB2BGR);
return img;
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
static ncnn::Mat letterbox_to_ncnn(
const cv::Mat& img,
Eigen::Matrix3f& transform_matrix,
int out_w,
int out_h,
float norm,
bool swap_rb = true
) {
const int img_w = img.cols;
const int img_h = img.rows;
float scale = std::min(out_w * 1.0f / img_w, out_h * 1.0f / img_h);
int resize_w = static_cast<int>(round(img_w * scale));
int resize_h = static_cast<int>(round(img_h * scale));
int pad_w = out_w - resize_w;
int pad_h = out_h - resize_h;
int pad_left = static_cast<int>(round(pad_w / 2.0f - 0.1f));
int pad_top = static_cast<int>(round(pad_h / 2.0f - 0.1f));
transform_matrix << 1.0f / scale, 0, -pad_left / scale, 0, 1.0f / scale,
-pad_top / scale, 0, 0, 1;
ncnn::Mat out;
if (swap_rb) {
out = ncnn::Mat::from_pixels_resize(
img.data,
ncnn::Mat::PIXEL_BGR2RGB,
img_w,
img_h,
resize_w,
resize_h
);
} else {
out = ncnn::Mat::from_pixels_resize(
img.data,
ncnn::Mat::PIXEL_RGB,
img_w,
img_h,
resize_w,
resize_h
);
}
int pad_right = out_w - resize_w - pad_left;
int pad_bottom = out_h - resize_h - pad_top;
ncnn::Mat padded;
ncnn::copy_make_border(
out,
padded,
pad_top,
pad_bottom,
pad_left,
pad_right,
ncnn::BORDER_CONSTANT,
114.f
);
std::array<float, 3> mean_vals;
std::array<float, 3> norm_vals;
mean_vals = { 0.f, 0.f, 0.f };
norm_vals = { norm, norm, norm };
padded.substract_mean_normalize(mean_vals.data(), norm_vals.data());
return padded;
}
void
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
// Eigen::Matrix3f transform_matrix;
// cv::Mat resized_img = letterbox(frame.src_img, transform_matrix);
// ncnn::Mat in =
// ncnn::Mat::from_pixels(resized_img.data, ncnn::Mat::PIXEL_BGR2RGB, INPUT_W, INPUT_H);
Eigen::Matrix3f transform_matrix;
auto roi = frame.img_frame.src_img(frame.expanded);
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
ncnn::Mat in = letterbox_to_ncnn(
roi.clone(),
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH(),
scale,
swap_rb
);
cv::Mat resized_img = ncnnMatToCvMat(in);
auto out = ncnn_net_->infer(in);
cv::Mat output_buffer(out.h, out.w, CV_32F, out.data);
// Parse YOLO output
auto objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
processCallback(frame, target_number);
}
private:
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::NCNNNet> ncnn_net_;
};
ArmorDetectorNCNN::ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorNCNN::~ArmorDetectorNCNN() {
_impl.reset();
}
void ArmorDetectorNCNN::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorNCNN::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,36 @@
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorNCNN: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorNCNN>;
explicit ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
}
~ArmorDetectorNCNN();
void setCallback(DetectorCallback callback) override;
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,12 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifierBase {
public:
virtual ~NumberClassifierBase() = default;
virtual bool classifyNumber(ArmorObject& armor) = 0;
virtual void initNumberClassifier() = 0;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,34 @@
#pragma once
#include "number_classifier.hpp"
#ifdef USE_TRT
#include "number_classifier_trt.hpp"
#endif
namespace wust_vision {
namespace auto_aim {
class NumberClassifierFactory {
public:
static std::unique_ptr<NumberClassifierBase> createNumberClassifier(
const std::string& backend,
const std::string& classify_model_path,
const std::string& classify_label_path
) {
#if defined(USE_TRT)
if (backend == "tensorrt") {
return std::make_unique<NumberClassifierTRT>(
classify_model_path,
classify_label_path
);
}
#endif
if (backend == "opencv") {
return std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
}
throw std::runtime_error(
"Unsupported number classifier backend (or not compiled): " + backend
);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,116 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "number_classifier.hpp"
#include <wust_vl/common/utils/logger.hpp>
namespace wust_vision {
namespace auto_aim {
NumberClassifier::NumberClassifier(
const std::string& classify_model_path,
const std::string& classify_label_path
):
classify_model_path_(classify_model_path),
classify_label_path_(classify_label_path) {
initNumberClassifier();
}
void NumberClassifier::initNumberClassifier() {
const std::string model_path = classify_model_path_;
std::unique_ptr<cv::dnn::Net> number_net_ =
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(model_path));
if (number_net_->empty()) {
WUST_ERROR("number_classifier")
<< "Failed to load number classifier model from " << model_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier")
<< "Successfully loaded number classifier model from " << model_path;
}
const std::string label_path = classify_label_path_;
std::ifstream label_file(label_path);
std::string line;
class_names_.clear();
while (std::getline(label_file, line)) {
class_names_.push_back(line);
}
if (class_names_.empty()) {
WUST_ERROR("number_classifier") << "Failed to load labels from " << label_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier")
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
}
number_net_.reset();
}
bool NumberClassifier::classifyNumber(ArmorObject& armor) {
static thread_local std::unique_ptr<cv::dnn::Net> thread_net;
if (armor.number_img.empty()) {
return false;
}
if (!thread_net) {
thread_net =
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(classify_model_path_));
WUST_DEBUG("number_classifier") << "Loaded number classifier model for this thread";
if (thread_net->empty()) {
WUST_ERROR("number_classifier")
<< "Failed to load thread-local number classifier model.";
return false;
}
}
const cv::Mat image = armor.number_img;
cv::Mat blob;
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0);
thread_net->setInput(blob);
cv::Mat outputs = thread_net->forward();
double max_val;
cv::minMaxLoc(outputs, nullptr, &max_val);
cv::Mat prob;
cv::exp(outputs - max_val, prob);
prob /= cv::sum(prob)[0];
double confidence;
cv::Point class_id;
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
const int label_id = class_id.x;
const double raw_conf = armor.confidence;
armor.confidence = confidence;
static const std::map<int, ArmorNumber> label_to_armor_number = {
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
};
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
armor.number = label_to_armor_number.at(label_id);
return true;
} else {
armor.confidence = raw_conf;
return false;
}
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,35 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base.hpp"
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifier: public NumberClassifierBase {
public:
NumberClassifier(
const std::string& classify_model_path,
const std::string& classify_label_path
);
void initNumberClassifier() override;
bool classifyNumber(ArmorObject& armor) override;
private:
std::vector<std::string> class_names_;
std::string classify_model_path_;
std::string classify_label_path_;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,118 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_TRT
#include "number_classifier_trt.hpp"
#include <wust_vl/common/utils/logger.hpp>
namespace wust_vision {
namespace auto_aim {
NumberClassifierTRT::NumberClassifierTRT(
const std::string& classify_model_path,
const std::string& classify_label_path
):
classify_model_path_(classify_model_path),
classify_label_path_(classify_label_path) {
initNumberClassifier();
}
void NumberClassifierTRT::initNumberClassifier() {
const std::string model_path = classify_model_path_;
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
wust_vl::ml_net::TensorRTNet::Params trt_params;
trt_params.model_path = model_path;
trt_params.input_dims = nvinfer1::Dims4 { 1, 1, 20, 28 };
trt_net_->init(trt_params);
const auto input_output_dims = trt_net_->getInputOutputDims();
input_dims_ = std::get<0>(input_output_dims);
output_dims_ = std::get<1>(input_output_dims);
const std::string label_path = classify_label_path_;
std::ifstream label_file(label_path);
std::string line;
class_names_.clear();
while (std::getline(label_file, line)) {
class_names_.push_back(line);
}
if (class_names_.empty()) {
WUST_ERROR("number_classifier_trt") << "Failed to load labels from " << label_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier_trt")
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
}
}
bool NumberClassifierTRT::classifyNumber(ArmorObject& armor) {
static thread_local std::unique_ptr<nvinfer1::IExecutionContext> ctx;
if (armor.number_img.empty()) {
return false;
}
if (!ctx) {
auto c = trt_net_->getAContext();
ctx = std::unique_ptr<nvinfer1::IExecutionContext>(c);
WUST_DEBUG("number_classifier_trt") << "Loaded number classifier model for this thread";
if (!ctx) {
WUST_ERROR("number_classifier_trt")
<< "Failed to load thread-local number classifier model.";
return false;
}
}
const cv::Mat image = armor.number_img;
cv::Mat blob;
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0, cv::Size(28, 20));
trt_net_->input2Device(blob.ptr<float>());
void* input_tensor_ptr = trt_net_->getInputTensorPtr();
trt_net_->infer(input_tensor_ptr, ctx.get());
const float* out = static_cast<float*>(trt_net_->output2Host());
cv::Mat outputs(1, 9, CV_32F);
std::memcpy(outputs.data, out, 9 * sizeof(float));
double max_val;
cv::minMaxLoc(outputs, nullptr, &max_val);
cv::Mat prob;
cv::exp(outputs - max_val, prob);
prob /= cv::sum(prob)[0];
double confidence;
cv::Point class_id;
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
const int label_id = class_id.x;
const double raw_conf = armor.confidence;
armor.confidence = confidence;
static const std::map<int, ArmorNumber> label_to_armor_number = {
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
};
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
armor.number = label_to_armor_number.at(label_id);
return true;
} else {
armor.confidence = raw_conf;
return false;
}
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,39 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifierTRT: public NumberClassifierBase {
public:
NumberClassifierTRT(
const std::string& classify_model_path,
const std::string& classify_label_path
);
void initNumberClassifier() override;
bool classifyNumber(ArmorObject& armor) override;
private:
std::vector<std::string> class_names_;
std::string classify_model_path_;
std::string classify_label_path_;
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
nvinfer1::Dims input_dims_;
nvinfer1::Dims output_dims_;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,141 @@
#ifdef USE_ORT
#include "armor_detector_onnxruntime.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/ml_net/onnxruntime/onnxruntime_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOnnxRuntime::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
std::string model_type = config["onnxruntime"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["onnxruntime"]["conf_threshold"].as<float>();
int top_k = config["onnxruntime"]["top_k"].as<int>();
float nms_threshold = config["onnxruntime"]["nms_threshold"].as<float>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string provider = config["onnxruntime"]["provider"].as<std::string>("CPU");
provider_ = wust_vl::ml_net::string2OrtProvider(provider);
onnxruntime_net_ = std::make_unique<wust_vl::ml_net::OnnxRuntimeNet>();
wust_vl::ml_net::OnnxRuntimeNet::Params params;
std::string model_path =
utils::expandEnv(config["onnxruntime"]["model_path"].as<std::string>());
params.model_path = model_path;
params.provider = provider_;
onnxruntime_net_->init(params);
}
~Impl() {
onnxruntime_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
void
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
Eigen::Matrix3f transform_matrix;
auto roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
cv::Mat blob = cv::dnn::blobFromImage(
resized_img,
scale,
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
cv::Scalar(0, 0, 0),
swap_rb
);
auto output_data = onnxruntime_net_->infer(blob.ptr<float>(), blob.total());
auto output_shape = onnxruntime_net_->getOutputShape();
int rows = static_cast<int>(output_shape[1]);
int cols = static_cast<int>(output_shape[2]);
cv::Mat output_buffer(rows, cols, CV_32F, output_data);
// Parsed variable
std::vector<ArmorObject> objs_result;
objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
processCallback(frame, target_number);
}
wust_vl::ml_net::OrtProvider provider_ = wust_vl::ml_net::OrtProvider::CPU;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::OnnxRuntimeNet> onnxruntime_net_;
};
ArmorDetectorOnnxRuntime::ArmorDetectorOnnxRuntime(
const YAML::Node& config,
bool use_armor_detect_common
) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorOnnxRuntime::~ArmorDetectorOnnxRuntime() {
_impl.reset();
}
void ArmorDetectorOnnxRuntime::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOnnxRuntime::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,36 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOnnxRuntime: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOnnxRuntime>;
explicit ArmorDetectorOnnxRuntime(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorOnnxRuntime>(config, use_armor_detect_common);
}
~ArmorDetectorOnnxRuntime();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,382 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
//
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
// Apache License 2.0.
//
// Copyright (C) FYT Vision Group. All rights reserved.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
#include "tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOpenCV::Impl {
public:
Impl(const YAML::Node& config) {
auto classify_model_path =
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
auto classify_label_path =
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
double classify_threshold = config["classify"]["threshold"].as<double>();
number_classifier_ =
std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
light_params_.load(config["light"]);
armor_params_.load(config["armor"]);
}
std::vector<ArmorObject> detect(
const cv::Mat& input,
int detect_color,
const std::optional<ArmorNumber>& target_number
) noexcept {
if (input.empty())
return {};
std::vector<Light> lights_;
cv::Mat binary_img, gray_img;
std::tie(binary_img, gray_img) = preprocessImage(input);
lights_ = findLights(input, binary_img);
std::vector<ArmorObject> armors = matchLights(lights_, detect_color);
std::vector<ArmorObject> valid_armors;
for (auto& armor: armors) {
try {
armor.number_img = extractNumber(gray_img, armor);
if (armor.number_img.empty())
continue;
if (!number_classifier_->classifyNumber(armor))
continue;
if (target_number.has_value()) {
if (!isSameTarget(target_number.value(), armor.number)) {
continue;
}
}
if (armor.confidence < classifier_threshold_)
continue;
if (armor.number != ArmorNumber::NO1 && armor.number != ArmorNumber::BASE
&& armor.type == ArmorType::LARGE)
{
continue;
}
valid_armors.push_back(armor);
} catch (const std::exception& e) {
std::cerr << "[detect] Exception: " << e.what() << std::endl;
}
}
return valid_armors;
}
std::tuple<cv::Mat, cv::Mat> preprocessImage(const cv::Mat& img) noexcept {
cv::Mat gray_img;
cv::Mat binary_img;
if (img.empty()) {
return { binary_img, gray_img }; // 空图直接返回空
}
if (img.channels() == 3) {
cv::cvtColor(img, gray_img, cv::COLOR_RGB2GRAY);
} else if (img.channels() == 1) {
cv::cvtColor(img, gray_img, cv::COLOR_BayerRG2GRAY);
} else {
return { binary_img, gray_img };
}
cv::threshold(gray_img, binary_img, light_params_.binary_thres, 255, cv::THRESH_BINARY);
return { binary_img, gray_img };
}
std::vector<Light> findLights(const cv::Mat& img, const cv::Mat& binary_img) noexcept {
std::vector<std::vector<cv::Point>> contours;
contours.reserve(64);
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
cv::Mat color_img;
if (img.channels() == 3) {
color_img = img;
} else if (img.channels() == 1) {
cv::cvtColor(img, color_img, cv::COLOR_BayerRG2BGR);
} else {
return {};
}
std::vector<Light> lights;
lights.reserve(contours.size());
for (const auto& contour: contours) {
const int n = static_cast<int>(contour.size());
if (n < 6)
continue;
Light light(contour);
if (!isLight(light))
continue;
int sum_r = 0;
int sum_b = 0;
for (const auto& pt: contour) {
const cv::Vec3b& pix = color_img.at<cv::Vec3b>(pt);
sum_r += pix[0];
sum_b += pix[2];
}
const int avg_diff = std::abs(sum_r - sum_b) / n;
if (avg_diff <= light_params_.color_diff_thresh)
continue;
light.color = (sum_r > sum_b) ? 0 : 1;
lights.emplace_back(std::move(light));
}
std::sort(lights.begin(), lights.end(), [](const Light& a, const Light& b) {
return a.center.x < b.center.x;
});
return lights;
}
bool isLight(const Light& light) noexcept {
// width / length 比例
const float ratio = light.width / light.length;
if (ratio <= light_params_.min_ratio || ratio >= light_params_.max_ratio)
return false;
if (light.tilt_angle >= light_params_.max_angle)
return false;
return true;
}
std::vector<ArmorObject>
matchLights(const std::vector<Light>& lights, int detect_color) noexcept {
const int n = static_cast<int>(lights.size());
std::vector<ArmorObject> armors;
armors.reserve(n);
for (int i = 0; i < n; ++i) {
const Light& l1 = lights[i];
if (l1.color != detect_color)
continue;
const float max_dx = l1.length * armor_params_.max_large_center_distance;
for (int j = i + 1; j < n; ++j) {
const Light& l2 = lights[j];
if (l2.color != detect_color)
continue;
const float dx = l2.center.x - l1.center.x;
if (dx > max_dx)
break;
ArmorType type = isArmor(l1, l2);
if (type == ArmorType::INVALID)
continue;
if (containLight(i, j, lights))
continue;
ArmorObject armor(l1, l2);
armor.type = type;
armor.color = (detect_color == 0) ? ArmorColor::RED : ArmorColor::BLUE;
armors.emplace_back(std::move(armor));
}
}
return armors;
}
// Check if there is another light in the boundingRect formed by the 2 lights
bool containLight(const int i, const int j, const std::vector<Light>& lights) noexcept {
const Light& l1 = lights[i];
const Light& l2 = lights[j];
float min_x = std::min({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
float max_x = std::max({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
float min_y = std::min({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
float max_y = std::max({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
const float avg_len = 0.5f * (l1.length + l2.length);
const float avg_wid = 0.5f * (l1.width + l2.width);
for (int k = i + 1; k < j; ++k) {
const Light& t = lights[k];
if (t.width > 2.0f * avg_wid)
continue;
if (t.length < 0.5f * avg_len)
continue;
const cv::Point2f& c = t.center;
if (c.x >= min_x && c.x <= max_x && c.y >= min_y && c.y <= max_y) {
return true;
}
}
return false;
}
ArmorType isArmor(const Light& l1, const Light& l2) noexcept {
const float len1 = l1.length;
const float len2 = l2.length;
if (len1 <= 1e-3f || len2 <= 1e-3f)
return ArmorType::INVALID;
const float min_len = (len1 < len2) ? len1 : len2;
const float max_len = (len1 < len2) ? len2 : len1;
if (min_len / max_len <= armor_params_.min_light_ratio)
return ArmorType::INVALID;
const cv::Point2f d = l1.center - l2.center;
const float dist2 = d.dot(d);
const float avg_len = 0.5f * (len1 + len2);
const float min_small = armor_params_.min_small_center_distance * avg_len;
const float max_small = armor_params_.max_small_center_distance * avg_len;
const float min_large = armor_params_.min_large_center_distance * avg_len;
const float max_large = armor_params_.max_large_center_distance * avg_len;
const float min_small2 = min_small * min_small;
const float max_small2 = max_small * max_small;
const float min_large2 = min_large * min_large;
const float max_large2 = max_large * max_large;
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
if (!(small_ok || large_ok))
return ArmorType::INVALID;
static const float tan_max_angle = std::tan(armor_params_.max_angle * CV_PI / 180.0f);
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
return ArmorType::INVALID;
float delta_angle = std::fabs(l1.angle - l2.angle);
if (delta_angle > 90.0f)
delta_angle = 180.0f - delta_angle;
if (delta_angle >= light_params_.max_angle_diff)
return ArmorType::INVALID;
return large_ok ? ArmorType::LARGE : ArmorType::SMALL;
}
cv::Mat extractNumber(const cv::Mat& src, const ArmorObject& armor) const noexcept {
constexpr int light_length = 12;
constexpr int warp_height = 28;
constexpr int small_armor_width = 32;
constexpr int large_armor_width = 54;
const cv::Size roi_size(20, 28);
cv::Point2f src_pts[4] = { armor.lights[0].bottom,
armor.lights[0].top,
armor.lights[1].top,
armor.lights[1].bottom };
const int warp_width =
(armor.type == ArmorType::SMALL) ? small_armor_width : large_armor_width;
const int top_y = (warp_height - light_length) / 2 - 1;
const int bottom_y = top_y + light_length;
cv::Point2f dst_pts[4] = {
{ 0.f, static_cast<float>(bottom_y) },
{ 0.f, static_cast<float>(top_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(top_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_y) }
};
cv::Mat warp_mat = cv::getPerspectiveTransform(src_pts, dst_pts);
cv::Mat warped;
cv::warpPerspective(
src,
warped,
warp_mat,
cv::Size(warp_width, warp_height),
cv::INTER_LINEAR,
cv::BORDER_CONSTANT,
0
);
const int roi_x = (warp_width - roi_size.width) >> 1;
if (roi_x < 0 || roi_x + roi_size.width > warp_width)
return cv::Mat();
cv::Mat number = warped(cv::Rect(roi_x, 0, roi_size.width, roi_size.height));
cv::threshold(number, number, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);
return number;
}
void setCallback(DetectorCallback callback) {
this->infer_callback_ = callback;
}
void toPts(ArmorObject& armor) {
if (armor.lights.size() != 2) {
armor.is_ok = false;
return;
}
armor.pts[0] = armor.lights[0].top;
armor.pts[1] = armor.lights[0].bottom;
armor.pts[2] = armor.lights[1].bottom;
armor.pts[3] = armor.lights[1].top;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
std::vector<ArmorObject> objs_result;
auto roi = frame.img_frame.src_img(frame.expanded);
objs_result = detect(roi, frame.detect_color, target_number);
if (this->infer_callback_) {
this->infer_callback_(objs_result, frame);
return;
}
return;
}
LightParams light_params_;
ArmorParams armor_params_;
double classifier_threshold_ = 0.5;
std::unique_ptr<NumberClassifier> number_classifier_;
DetectorCallback infer_callback_;
int current_id_ = 0;
};
ArmorDetectorOpenCV::ArmorDetectorOpenCV(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
ArmorDetectorOpenCV::~ArmorDetectorOpenCV() {
_impl.reset();
}
void ArmorDetectorOpenCV::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOpenCV::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,42 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
//
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
// Apache License 2.0.
//
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOpenCV: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOpenCV>;
explicit ArmorDetectorOpenCV(const YAML::Node& config);
static Ptr create(const YAML::Node& config) {
return std::make_unique<ArmorDetectorOpenCV>(config);
}
~ArmorDetectorOpenCV();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,206 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_OPENVINO
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/ml_net/openvino/openvino_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOpenVino::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
std::string model_type = config["openvino"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["openvino"]["conf_threshold"].as<float>();
int top_k = config["openvino"]["top_k"].as<int>();
float nms_threshold = config["openvino"]["nms_threshold"].as<float>();
bool use_throughputmode = config["openvino"]["use_throughputmode"].as<bool>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string model_path =
utils::expandEnv(config["openvino"]["model_path"].as<std::string>());
auto device_name = config["openvino"]["device_name"].as<std::string>();
ov_params_.model_path = model_path;
ov_params_.device_name = device_name;
ov_params_.mode = use_throughputmode ? ov::hint::PerformanceMode::THROUGHPUT
: ov::hint::PerformanceMode::LATENCY;
initOpenVINO();
}
void
initOpenVINO(wust_vl::video::PixelFormat pixel_format = wust_vl::video::PixelFormat::BGR) {
openvino_net_.reset();
openvino_net_ = std::make_unique<wust_vl::ml_net::OpenvinoNet>();
const auto ppp_init_fun = [this, pixel_format](ov::preprocess::PrePostProcessor& ppp) {
if (pixel_format == wust_vl::video::PixelFormat::RGB) {
ppp.input()
.tensor()
.set_element_type(ov::element::u8)
.set_layout("NHWC")
.set_color_format(ov::preprocess::ColorFormat::RGB);
} else {
ppp.input()
.tensor()
.set_element_type(ov::element::u8)
.set_layout("NHWC")
.set_color_format(ov::preprocess::ColorFormat::BGR);
}
pixel_format_ = pixel_format;
const bool RGB = armor_infer_->inputRGB();
const float scale = armor_infer_->useNorm() ? 255.0f : 1.0f;
if (RGB) {
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::RGB)
.scale(scale);
} else {
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::BGR)
.scale(scale);
}
ppp.input().model().set_layout("NCHW");
ppp.output().tensor().set_element_type(ov::element::f32);
};
openvino_net_->init(ov_params_, ppp_init_fun);
}
~Impl() {
openvino_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
bool processCallback(
const CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) const {
const auto start = std::chrono::steady_clock::now();
Eigen::Matrix3f transform_matrix;
const auto roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const auto input_info = openvino_net_->getInputInfo();
const auto input_tensor =
ov::Tensor(input_info.first, input_info.second, resized_img.data);
const auto output = openvino_net_->infer_thread_local(input_tensor);
// Process output data
const auto output_shape = output.get_shape();
const float* ptr = output.data<const float>();
cv::Mat
output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast<float*>(ptr));
// Parsed variable
auto objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return true;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return true;
}
}
return false;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
if (resetting_) {
return;
}
frame.id = current_id_++;
if (frame.img_frame.pixel_format != pixel_format_) {
resetting_ = true;
initOpenVINO(frame.img_frame.pixel_format);
resetting_ = false;
}
processCallback(frame, target_number);
}
private:
wust_vl::video::PixelFormat pixel_format_ = wust_vl::video::PixelFormat::BGR;
std::unique_ptr<wust_vl::ml_net::OpenvinoNet> openvino_net_;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
wust_vl::ml_net::OpenvinoNet::Params ov_params_;
bool resetting_ = false;
};
ArmorDetectorOpenVino::ArmorDetectorOpenVino(
const YAML::Node& config,
bool use_armor_detect_common
) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorOpenVino::~ArmorDetectorOpenVino() {
_impl.reset();
}
void ArmorDetectorOpenVino::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOpenVino::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,38 @@
// Copyright 2023 Yunlong Feng
// Copyright 2025 Lihan Chen
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOpenVino: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOpenVino>;
explicit ArmorDetectorOpenVino(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorOpenVino>(config, use_armor_detect_common);
}
~ArmorDetectorOpenVino();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,320 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_TRT
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
#include "cuda_infer/armor_infer.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/concurrency/adaptive_resource_pool.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include "wust_vl/common/utils/timer.hpp"
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
namespace wust_vision {
namespace auto_aim {
static constexpr int MAX_SRC_IMG_W = 1920;
static constexpr int MAX_SRC_IMG_H = 1440;
struct ArmorDetectorTrt::Impl {
public:
struct Infer {
std::unique_ptr<nvinfer1::IExecutionContext> context;
std::unique_ptr<armor_cuda_infer::CudaInfer> cuda_infer;
};
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
const double conf_threshold = config["tensorrt"]["conf_threshold"].as<float>();
const double nms_threshold = config["tensorrt"]["nms_threshold"].as<float>();
const int top_k = config["tensorrt"]["top_k"].as<int>();
const int max_infer_running = config["tensorrt"]["max_infer_running"].as<int>();
const double min_free_mem_ratio = config["tensorrt"]["min_free_mem_ratio"].as<double>();
use_cuda_pre_ = config["tensorrt"]["use_cuda_pre"].as<bool>();
log_time_ = config["tensorrt"]["log_time"].as<bool>();
const std::string model_type = config["tensorrt"]["model_type"].as<std::string>();
const std::string model_path =
utils::expandEnv(config["tensorrt"]["model_path"].as<std::string>());
const int device_id = config["tensorrt"]["device_id"].as<int>();
cudaSetDevice(device_id);
const auto model = armor_infer::modeFromString(model_type);
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
wust_vl::ml_net::TensorRTNet::Params trt_params;
trt_params.model_path = model_path;
trt_params.input_dims =
nvinfer1::Dims4 { 1, 3, armor_infer_->inputH(), armor_infer_->inputW() };
trt_net_->init(trt_params);
const auto input_output_dims = trt_net_->getInputOutputDims();
input_dims_ = std::get<0>(input_output_dims);
output_dims_ = std::get<1>(input_output_dims);
wust_vl::common::concurrency::AdaptiveResourcePool<Infer>::Params pool_params;
pool_params.resource_initializer = [&]() {
std::vector<std::unique_ptr<Infer>> infers;
for (int i = 0; i < max_infer_running; ++i) {
auto infer = std::make_unique<Infer>();
auto ctx = trt_net_->getAContext();
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
if (use_cuda_pre_) {
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
infer->cuda_infer->init(
MAX_SRC_IMG_W,
MAX_SRC_IMG_H,
armor_infer_->inputW(),
armor_infer_->inputH()
);
}
if (!infer->context) {
WUST_ERROR("TRT") << "create infer failed, missing context"
<< " index:" << i;
continue;
}
if (use_cuda_pre_ && !infer->cuda_infer) {
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer"
<< " index:" << i;
continue;
}
size_t free_mem, total_mem;
cudaMemGetInfo(&free_mem, &total_mem);
WUST_DEBUG("TRT") << "Free GPU memory:" << free_mem / 1024.0 / 1024.0 << "MB"
<< "Total GPU memory:" << total_mem / 1024.0 / 1024.0 << "MB";
double free_mem_ratio =
static_cast<double>(free_mem) / static_cast<double>(total_mem);
if (free_mem_ratio < min_free_mem_ratio && i > 0) {
WUST_WARN("TRT") << "GPU memory is not enough!"
<< "Free GPU memory:" << free_mem_ratio * 100 << "%";
WUST_INFO("TRT") << "Cut remaining infer";
break;
}
infers.emplace_back(std::move(infer));
WUST_INFO("TRT") << "create execution context success"
<< "index:" << i;
}
return infers;
};
auto release_func = [&](std::unique_ptr<Infer>& resource) {
if (resource) {
if (resource->cuda_infer) {
resource->cuda_infer.reset();
}
}
};
auto restore_func = [&](size_t idx) -> std::unique_ptr<Infer> {
auto infer = std::make_unique<Infer>();
auto ctx = trt_net_->getAContext();
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
if (use_cuda_pre_) {
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
infer->cuda_infer->init(
MAX_SRC_IMG_W,
MAX_SRC_IMG_H,
armor_infer_->inputW(),
armor_infer_->inputH()
);
}
if (!infer->context) {
WUST_ERROR("TRT") << "create infer failed, missing context";
return nullptr;
}
if ((use_cuda_pre_) && !infer->cuda_infer) {
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer";
return nullptr;
}
return infer;
};
pool_params.restore_func = restore_func;
pool_params.release_func = release_func;
pool_params.can_restore = [&](size_t active_count) { return false; };
pool_params.should_release = [&](size_t active_count) { return false; };
pool_params.logger = [](const std::string& msg) {
WUST_INFO("ArmorDetectorTrt:infer pool") << msg;
};
infer_pool_ =
std::make_unique<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>>(
pool_params
);
}
~Impl() {
if (infer_pool_) {
infer_pool_.reset();
}
trt_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
struct Tag {};
void processCallback(
const CommonFrame& frame,
Infer* infer,
const std::optional<ArmorNumber>& target_number
) const {
std::vector<ArmorObject> armors;
const auto t0 = wust_vl::common::utils::time_utils::now();
Eigen::Matrix3f transform_matrix;
std::vector<ArmorObject> objs_result;
void* input_tensor_ptr;
const cv::Mat roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img;
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
if (infer->cuda_infer && use_cuda_pre_) {
input_tensor_ptr =
infer->cuda_infer
->preprocess_pitched( //支持不连续内存,无需拷贝后输入可直接传roi的指针
roi.data,
roi.cols,
roi.rows,
roi.step,
scale,
swap_rb,
transform_matrix,
trt_net_->getStream()
);
resized_img = infer->cuda_infer->tensorToMat( //nchw_float_to_hwc_uchar
static_cast<float*>(input_tensor_ptr),
armor_infer_->inputW(),
armor_infer_->inputH(),
scale,
trt_net_->getStream()
);
} else {
resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const cv::Mat blob = cv::dnn::blobFromImage(
resized_img,
scale,
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
cv::Scalar(0, 0, 0),
swap_rb
);
trt_net_->input2Device(blob.ptr<float>());
input_tensor_ptr = trt_net_->getInputTensorPtr();
}
const auto t1 = wust_vl::common::utils::time_utils::now();
if (infer->context && input_tensor_ptr) {
trt_net_->infer(input_tensor_ptr, infer->context.get());
}
const auto t2 = wust_vl::common::utils::time_utils::now();
const cv::Mat
output_mat(output_dims_.d[1], output_dims_.d[2], CV_32F, trt_net_->output2Host());
cudaStreamSynchronize(trt_net_->getStream());
objs_result = armor_infer_->postProcess(output_mat);
const auto t3 = wust_vl::common::utils::time_utils::now();
if (log_time_) {
WUST_INFO("TRT") << std::fixed << std::setprecision(3) << "pre "
<< wust_vl::common::utils::time_utils::durationMs(t0, t1) << " "
<< "infer "
<< wust_vl::common::utils::time_utils::durationMs(t1, t2) << " "
<< "post "
<< wust_vl::common::utils::time_utils::durationMs(t2, t3) << " "
<< "total "
<< wust_vl::common::utils::time_utils::durationMs(t0, t3);
}
infer_pool_->release(infer);
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
if (infer_pool_) {
auto infer_ptr = infer_pool_->acquire();
if (infer_ptr != nullptr) {
frame.id = current_id_++;
this->processCallback(frame, infer_ptr, target_number);
}
}
}
private:
bool use_cuda_pre_;
bool log_time_;
nvinfer1::Dims input_dims_;
nvinfer1::Dims output_dims_;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>> infer_pool_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
};
ArmorDetectorTrt::ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorTrt::~ArmorDetectorTrt() {
_impl.reset();
}
void ArmorDetectorTrt::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorTrt::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,40 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorTrt: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorTrt>;
explicit ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorTrt>(config, use_armor_detect_common);
}
~ArmorDetectorTrt();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,442 @@
#include "armor_omni.hpp"
#include "3rdparty/angles.h"
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/auto_aim_fsm.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/common/concurrency/ThreadPool.h"
#include "wust_vl/common/utils/timer.hpp"
#include "wust_vl/video/camera.hpp"
// clang-format off
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
// clang-format on
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
#include "tasks/auto_aim/armor_where/armor_where.hpp"
namespace wust_vision::auto_aim {
struct ArmorOmni::Impl {
struct One {
using Ptr = std::shared_ptr<One>;
One(int id) {
self_id = id;
total_score = 0;
}
static Ptr create(int id) {
return std::make_shared<One>(id);
}
void load(
const YAML::Node& config,
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter
) noexcept {
auto yaw_in_big_yaw_deg = config["yaw_in_big_yaw_deg"].as<double>();
yaw_in_big_yaw = yaw_in_big_yaw_deg / 180.0 * M_PI;
camera = std::make_shared<wust_vl::video::Camera>();
camera->init(config);
std::string camera_info_path =
utils::expandEnv(config["camera_info_path"].as<std::string>());
YAML::Node config_camera_info = YAML::LoadFile(camera_info_path);
std::vector<double> camera_k =
config_camera_info["camera_matrix"]["data"].as<std::vector<double>>();
std::vector<double> camera_d =
config_camera_info["distortion_coefficients"]["data"].as<std::vector<double>>();
assert(camera_k.size() == 9);
assert(camera_d.size() == 5);
cv::Mat K(3, 3, CV_64F);
std::memcpy(K.data, camera_k.data(), 9 * sizeof(double));
cv::Mat D(1, 5, CV_64F);
std::memcpy(D.data, camera_d.data(), 5 * sizeof(double));
camera_info = std::make_pair(K.clone(), D.clone());
auto gobal_config = auto_aim_config_parameter->getConfig();
armor_where = ArmorWhere::create(gobal_config["armor_where"], camera_info);
tracker = Tracker::create(auto_aim_config_parameter);
}
void start() noexcept {
if (camera)
camera->start();
}
int self_id;
double total_score;
std::shared_ptr<wust_vl::video::Camera> camera;
ArmorWhere::Ptr armor_where;
Tracker::Ptr tracker;
std::pair<cv::Mat, cv::Mat> camera_info;
double yaw_in_big_yaw;
Target target;
};
struct Obj {
ArmorObject armor;
double score = 0;
One::Ptr one;
std::chrono::steady_clock::time_point timestamp;
Obj(const ArmorObject& a,
double s,
const One::Ptr& o,
std::chrono::steady_clock::time_point ts):
armor(a),
score(s),
one(o),
timestamp(ts) {
if (one)
one->total_score += score;
}
~Obj() {
if (one)
one->total_score -= score;
}
};
static constexpr const char* _ML_CONFIG = "config/omni/detect_ml.yaml";
static constexpr const char* _OPENCV_CONFIG = "config/omni/detect_opencv.yaml";
~Impl() {
run_flag_ = false;
}
Impl(bool detect_color_init, const Ctx& ctx) {
ctx_ = ctx;
detect_color_ = detect_color_init;
config_ = YAML::LoadFile(OMNI_CONFIG);
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
auto_aim_config_parameter_->loadFromFile(OMNI_CONFIG);
auto cameras = config_["cameras"].as<std::vector<std::string>>();
for (size_t i = 0; i < cameras.size(); ++i) {
auto real_path = utils::expandEnv(cameras[i]);
One::Ptr one = One::create(i);
one->load(YAML::LoadFile(real_path), auto_aim_config_parameter_);
ones_.emplace_back(one);
}
auto_aim_config_parameter_->reloadFromOldPath();
fps_ = config_["fps"].as<int>(30);
active_time_ = config_["active_time"].as<double>(0.5);
max_infer_running_ = config_["max_infer_running"].as<int>(0);
min_score_ = config_["min_score"].as<double>();
const std::string armor_detect_backend =
config_["armor_detect_backend"].as<std::string>("");
armor_detector_ = DetectorFactory::createArmorDetector(
armor_detect_backend,
false,
_OPENCV_CONFIG,
_ML_CONFIG
);
armor_detector_->setCallback(std::bind(
&ArmorOmni::Impl::ArmorDetectCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
thread_pool_ =
std::make_unique<wust_vl::common::concurrency::ThreadPool>(max_infer_running_);
timer_ = std::make_unique<wust_vl::common::utils::Timer>("omni");
latency_averager_ = std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void start() noexcept {
run_flag_ = true;
for (auto& one: ones_) {
one->start();
}
if (timer_) {
const auto timercallback =
std::bind(&ArmorOmni::Impl::timerCallback, this, std::placeholders::_1);
const double rate_hz = fps_;
timer_->start(rate_hz, timercallback);
}
}
int getOneId() const {
static int one_id = 0;
int id = one_id;
one_id = (one_id + 1) % ones_.size();
return id;
}
void timerCallback(double dt_ms) noexcept {
if (!run_flag_ || main_tracking_)
return;
int one_id = getOneId();
auto& one = ones_[one_id];
auto frame = one->camera->readImage();
if (frame.src_img.empty())
return;
CommonFrame common_frame;
common_frame.img_frame = frame;
common_frame.id = one_id;
common_frame.detect_color = detect_color_;
common_frame.expanded = cv::Rect(0, 0, frame.src_img.cols, frame.src_img.rows);
common_frame.offset = cv::Point2f(0, 0);
common_frame.any_ctx = one;
detect(common_frame);
}
void detect(CommonFrame& common_frame) {
if (infer_running_count_ >= max_infer_running_ || !thread_pool_ || !run_flag_) {
return;
}
infer_running_count_++;
if (common_frame.img_frame.src_img.empty()) {
infer_running_count_--;
return;
}
if (armor_detector_) {
armor_detector_->pushInput(common_frame, std::nullopt);
}
infer_running_count_--;
}
void setDetectColor(bool flag) noexcept {
detect_color_ = flag;
}
void updateMainTracking(bool flag) noexcept {
main_tracking_ = flag;
}
int getBestTarget() noexcept {
update();
return best_target_;
}
void
ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) noexcept {
auto one = std::any_cast<One::Ptr>(frame.any_ctx);
std::vector<ArmorObject> sorted_objs;
for (const auto& obj: objs) {
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
continue;
}
sorted_objs.push_back(obj);
std::lock_guard<std::mutex> lock(active_results_mutex_);
active_results_.emplace_back(obj, obj.confidence, one, frame.img_frame.timestamp);
}
update();
Armors armors;
armors.timestamp = frame.img_frame.timestamp;
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
auto& car_b = ctx_.car_motion_buffer;
auto& big_yaw_b = ctx_.big_yaw_motion_buffer;
if (car_b && big_yaw_b) {
const auto t_query = armors.timestamp;
auto apply_motion = [&](const auto& att, const auto& att2) {
R_gimbal2odom =
Eigen::AngleAxisd(
angles::normalize_angle(att2.data.big_yaw + one->yaw_in_big_yaw),
Eigen::Vector3d::UnitZ()
)
* Eigen::AngleAxisd(0.0, Eigen::Vector3d::UnitY())
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
};
auto car_past_att = car_b->get_interpolated(t_query);
auto big_yaw_past_att = big_yaw_b->get_interpolated(t_query);
if (car_past_att && big_yaw_past_att) {
apply_motion(*car_past_att, *big_yaw_past_att);
} else {
auto last_att = car_b->get_last();
auto last_big_yaw = big_yaw_b->get_last();
if (last_att && last_big_yaw) {
apply_motion(*last_att, *last_big_yaw);
}
}
}
Eigen::Matrix3d R_camera2gimbal;
R_camera2gimbal << 0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0;
Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform(
R_gimbal2odom,
R_camera2gimbal,
Eigen::Vector3d::Zero()
);
armors.armors = one->armor_where->where(sorted_objs, T_camera_to_odom);
for (auto& armor: armors.armors) {
armor.timestamp = armors.timestamp;
}
auto& target = one->target;
target = one->tracker->track(armors);
const auto now = std::chrono::steady_clock::now();
const auto latency_ms =
wust_vl::common::utils::time_utils::durationMs(frame.img_frame.timestamp, now);
latency_averager_->add(latency_ms);
latency_ms_ = latency_averager_->average();
detect_count_++;
printStats();
}
void update() noexcept {
std::lock_guard<std::mutex> lock(active_results_mutex_);
while (!active_results_.empty()) {
auto& obj = active_results_.front();
if (std::abs(wust_vl::common::utils::time_utils::durationSec(
obj.timestamp,
wust_vl::common::utils::time_utils::now()
))
> active_time_)
{
active_results_.pop_front();
} else {
break;
}
}
if (active_results_.empty()) {
best_target_ = -1;
return;
}
double max_score = min_score_;
best_target_ = -1;
for (size_t i = 0; i < ones_.size(); ++i) {
if (ones_[i]->total_score > max_score) {
max_score = ones_[i]->total_score;
best_target_ = ones_[i]->self_id;
}
}
}
GimbalCmd solve(double bullet_speed) {
GimbalCmd gimbal_cmd;
std::optional<Target> target;
int best_target = getBestTarget();
if (best_target < 0) {
target = std::nullopt;
} else {
target = ones_[best_target]->target;
}
auto& very_aimer = ctx_.very_aimer;
if (!very_aimer) {
return gimbal_cmd;
}
if (target.has_value() && target->checkTargetAppear()) {
try {
gimbal_cmd = very_aimer->veryAim(
target.value(),
bullet_speed,
AutoAimFsm::AIM_WHOLE_CAR_CENTER
);
gimbal_cmd.enable_pitch_diff = 0.0;
gimbal_cmd.enable_yaw_diff = 0.0;
gimbal_cmd.fire_advice = false;
} catch (...) {
WUST_ERROR("omni") << "VeryAim error";
}
} else {
gimbal_cmd.appear = false;
}
return gimbal_cmd;
}
void printStats() {
utils::XSecOnce(
[&] {
WUST_INFO("armor_omni") << "det: " << detect_count_ << " best: " << best_target_
<< " lat: " << latency_ms_;
detect_count_ = 0;
},
1.0
);
}
int fps_;
int max_infer_running_ = 0;
std::atomic<int> infer_running_count_ { 0 };
bool detect_color_;
bool main_tracking_ = false;
bool run_flag_ = false;
double active_time_ = 0;
std::deque<Obj> active_results_;
mutable std::mutex active_results_mutex_;
std::vector<One::Ptr> ones_;
YAML::Node config_;
std::unique_ptr<wust_vl::common::concurrency::ThreadPool> thread_pool_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
ArmorDetectorBase::Ptr armor_detector_;
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
int best_target_ = -1;
int detect_count_ = 0;
double latency_ms_;
double min_score_ = 10.0;
Ctx ctx_;
};
ArmorOmni::ArmorOmni(bool detect_color_init, const Ctx& ctx):
_impl(std::make_unique<Impl>(detect_color_init, ctx)) {}
ArmorOmni::~ArmorOmni() {
_impl.reset();
}
void ArmorOmni::start() {
_impl->start();
}
void ArmorOmni::setDetectColor(bool flag) {
_impl->setDetectColor(flag);
}
void ArmorOmni::updateMainTracking(bool flag) {
_impl->updateMainTracking(flag);
}
int ArmorOmni::getBestTarget() {
return _impl->getBestTarget();
}
GimbalCmd ArmorOmni::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,36 @@
#pragma once
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/type_common.hpp"
#include <memory>
#include <wust_vl/common/utils/motion_buffer.hpp>
#include <yaml-cpp/node/node.h>
namespace wust_vision {
namespace auto_aim {
class ArmorOmni {
public:
struct Ctx {
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<CarMotion, 1024>>
car_motion_buffer;
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<BigYaw, 1024>>
big_yaw_motion_buffer;
VeryAimer::Ptr very_aimer;
};
static constexpr const char* OMNI_CONFIG = "config/omni/omni.yaml";
using Ptr = std::unique_ptr<ArmorOmni>;
ArmorOmni(bool detect_color_init, const Ctx& ctx);
static Ptr create(bool detect_color_init, const Ctx& ctx) {
return std::make_unique<ArmorOmni>(detect_color_init, ctx);
}
~ArmorOmni();
void start();
void setDetectColor(bool flag);
void updateMainTracking(bool flag);
int getBestTarget();
GimbalCmd solve(double bullet_speed);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,286 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// ceres
#include <ceres/ceres.h>
#include <opencv2/calib3d.hpp>
#include <opencv2/core/mat.hpp>
// project
#include "KalmanHyLib/kalman_hybird_lib.hpp"
namespace crazy {
enum class MotionModel {
CONSTANT_VELOCITY = 0, // Constant velocity
CONSTANT_ROTATION = 1, // Constant rotation velocity
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
};
constexpr int X_N = 11, Z_N = 8;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
using VecX = Eigen::Matrix<double, X_N, 1>;
enum class Mean : uint8_t {
PLBX = 0,
PLBY = 1,
PLTX = 2,
PLTY = 3,
PRTX = 4,
PRTY = 5,
PRBX = 6,
PRBY = 7,
// IMUY = 4,
// IMUP = 5,
// IMUR = 6,
Z_N = 8
};
enum class State : uint8_t {
CX = 0,
VCX = 1,
CY = 2,
VCY = 3,
CZ = 4,
VCZ = 5,
YAW = 6,
VYAW = 7,
R = 8,
L = 9,
H = 10,
outpost01DZ = 9,
outpost02DZ = 10,
X_N = 11
};
struct Predict {
Predict() = default;
explicit Predict(
double dt,
MotionModel model = MotionModel::CONSTANT_VEL_ROT,
double vrx = 0.0,
double vry = 0.0,
double vrz = 0.0
):
dt(dt),
model(model),
vrx(vrx),
vry(vry),
vrz(vrz) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) const {
for (int i = 0; i < X_N; i++) {
x1[i] = x0[i];
}
// v_xyz
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
// linear velocity
x1[(int)State::CX] += x0[(int)State::VCX] * T(dt);
x1[(int)State::CY] += x0[(int)State::VCY] * T(dt);
x1[(int)State::CZ] += x0[(int)State::VCZ] * T(dt);
} else {
// no velocity
x1[(int)State::VCX] *= T(0.);
x1[(int)State::VCY] *= T(0.);
x1[(int)State::VCZ] *= T(0.);
}
x1[(int)State::CX] -= T(vrx) * T(dt);
x1[(int)State::CY] -= T(vry) * T(dt);
x1[(int)State::CZ] -= T(vrz) * T(dt);
x1[(int)State::VCZ] *= T(0.);
// v_yaw
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
// angular velocity
x1[(int)State::YAW] += x0[(int)State::VYAW] * T(dt);
} else {
// no rotation
x1[(int)State::VYAW] *= T(0.);
}
}
double dt;
MotionModel model;
double vrx, vry, vrz;
};
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
template<typename T>
T normalize_angle_t(T angle) {
T two_pi = T(2.0 * M_PI);
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
}
template<typename T>
Eigen::Quaternion<T>
eulerToQuat(const Eigen::Vector<T, 3>& euler, int axis0, int axis1, int axis2, bool extrinsic) {
T rz = euler[0];
T ry = euler[1];
T rx = euler[2];
Eigen::Quaternion<T> qx(Eigen::AngleAxis<T>(rx, Eigen::Vector3<T>::UnitX()));
Eigen::Quaternion<T> qy(Eigen::AngleAxis<T>(ry, Eigen::Vector3<T>::UnitY()));
Eigen::Quaternion<T> qz(Eigen::AngleAxis<T>(rz, Eigen::Vector3<T>::UnitZ()));
if (!extrinsic)
std::swap(axis0, axis2);
Eigen::Quaternion<T> q;
if (axis0 == 0 && axis1 == 1 && axis2 == 2)
q = qx * qy * qz;
else if (axis0 == 0 && axis1 == 2 && axis2 == 1)
q = qx * qz * qy;
else if (axis0 == 1 && axis1 == 0 && axis2 == 2)
q = qy * qx * qz;
else if (axis0 == 1 && axis1 == 2 && axis2 == 0)
q = qy * qz * qx;
else if (axis0 == 2 && axis1 == 0 && axis2 == 1)
q = qz * qx * qy;
else if (axis0 == 2 && axis1 == 1 && axis2 == 0)
q = qz * qy * qx;
else
throw std::invalid_argument("Unsupported axis order");
return q;
}
template<typename PointType, typename U>
std::vector<PointType> buildObjectPoints(const U& w, const U& h) noexcept {
using T = U;
T half2 = T(2.0);
return { PointType(T(0), w / half2, -h / half2),
PointType(T(0), w / half2, h / half2),
PointType(T(0), -w / half2, h / half2),
PointType(T(0), -w / half2, -h / half2) };
}
struct Measure {
struct MeasureCtx {
int armor_num = 4;
int id = 0;
Eigen::Matrix4d T_odom_to_camera_d;
cv::Mat camera_intrinsic;
cv::Mat camera_distortion;
bool is_big;
} ctx;
Measure() = default;
explicit Measure(const MeasureCtx& c): ctx(c) {}
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
T id_t = T(ctx.id);
T num_t = T(ctx.armor_num);
T two = T(2.0);
T angle = normalize_angle_t(x[(int)State::YAW] + id_t * two * T(M_PI) / num_t);
bool outpost = (ctx.armor_num == 3);
bool use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
T r = use_l_h ? x[(int)State::R] + x[(int)State::L] : x[(int)State::R];
T armor_x = x[(int)State::CX] - ceres::cos(angle) * r;
T armor_y = x[(int)State::CY] - ceres::sin(angle) * r;
T armor_z = outpost ? getoutpost_armor_z(x)
: use_l_h ? x[(int)State::CZ] + x[(int)State::H]
: x[(int)State::CZ];
Eigen::Vector3<T> euler_odom;
euler_odom[0] = angle; //yaw
euler_odom[1] = outpost ? T(-FIFTTEN_DEGREE_RAD) : T(FIFTTEN_DEGREE_RAD); //pitch
euler_odom[2] = T(M_PI / 2.0); //roll
Eigen::Quaternion<T> q_odom = eulerToQuat(euler_odom, 2, 1, 0, true);
Eigen::Matrix4<T> T_odom_to_camera = ctx.T_odom_to_camera_d.cast<T>();
Eigen::Vector4<T> pos_odom4(armor_x, armor_y, armor_z, T(1.0));
Eigen::Vector4<T> pos_camera4 = T_odom_to_camera * pos_odom4;
Eigen::Vector3<T> pos_camera = pos_camera4.template head<3>();
Eigen::Matrix3<T> R_odom_to_camera = T_odom_to_camera.block(0, 0, 3, 3).template cast<T>();
Eigen::Matrix3<T> R_ori_odom = q_odom.normalized().toRotationMatrix();
Eigen::Matrix3<T> R_camera = R_odom_to_camera * R_ori_odom;
Eigen::Quaternion<T> q_camera(R_camera);
q_camera.normalize();
T w3 = ctx.is_big ? T(LARGE_ARMOR_WIDTH) : T(SMALL_ARMOR_WIDTH);
T h3 = ctx.is_big ? T(LARGE_ARMOR_HEIGHT) : T(SMALL_ARMOR_HEIGHT);
auto objPts = buildObjectPoints<Eigen::Matrix<T, 3, 1>>(w3, h3);
Eigen::Matrix3<T> R = q_camera.toRotationMatrix();
Eigen::Matrix<T, 3, 1> t = pos_camera;
std::vector<Eigen::Matrix<T, 3, 1>> Pc;
Pc.reserve(objPts.size());
for (const auto& p: objPts) {
Eigen::Matrix<T, 3, 1> v = p;
Pc.push_back(R * v + t);
}
const cv::Mat& K = ctx.camera_intrinsic;
T fx = T(K.at<double>(0, 0));
T fy = T(K.at<double>(1, 1));
T cx = T(K.at<double>(0, 2));
T cy = T(K.at<double>(1, 2));
std::array<T, 4> u, v;
for (int i = 0; i < 4; i++) {
T Xc = Pc[i][0];
T Yc = Pc[i][1];
T Zc = Pc[i][2];
u[i] = fx * (Xc / Zc) + cx;
v[i] = fy * (Yc / Zc) + cy;
}
z[0] = u[0];
z[1] = v[0];
z[2] = u[1];
z[3] = v[1];
z[4] = u[2];
z[5] = v[2];
z[6] = u[3];
z[7] = v[3];
}
template<typename T>
T getoutpost_armor_z(const T x[X_N]) const {
if (ctx.id == 0)
return x[(int)State::CZ];
if (ctx.id == 1)
return x[(int)State::CZ] + x[(int)State::outpost01DZ];
if (ctx.id == 2)
return x[(int)State::CZ] + x[(int)State::outpost02DZ];
return x[(int)State::CZ];
}
using VecX = Eigen::Matrix<double, X_N, 1>;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
void h(const VecX& x, VecZ& z) const {
operator()(x.data(), z.data());
}
};
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace crazy

View File

@@ -0,0 +1,226 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// ceres
#include <algorithm>
#include <ceres/ceres.h>
// project
#include "KalmanHyLib/kalman_hybird_lib.hpp"
namespace ypdv2armor_motion_model {
enum class MotionModel {
CONSTANT_VELOCITY = 0, // Constant velocity
CONSTANT_ROTATION = 1, // Constant rotation velocity
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
};
// X_N: state dimension, Z_N: measurement dimension
constexpr int X_N = 11, Z_N = 4;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
using VecX = Eigen::Matrix<double, X_N, 1>;
enum class MeasureID : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, Z_N = 4 };
enum class StateID : uint8_t {
CX = 0,
VCX = 1,
CY = 2,
VCY = 3,
CZ = 4,
VCZ = 5,
YAW = 6,
VYAW = 7,
R = 8,
L = 9,
H = 10,
outpost01DZ = 9,
outpost02DZ = 10,
X_N = 11
};
struct State {
VecX x;
[[nodiscard]] inline double cx() const noexcept {
return x((int)StateID::CX);
}
[[nodiscard]] inline double cy() const noexcept {
return x((int)StateID::CY);
}
[[nodiscard]] inline double cz() const noexcept {
return x((int)StateID::CZ);
}
[[nodiscard]] inline Eigen::Vector3d pos() const noexcept {
return Eigen::Vector3d(cx(), cy(), cz());
}
[[nodiscard]] inline double vcx() const noexcept {
return x((int)StateID::VCX);
}
[[nodiscard]] inline double vcy() const noexcept {
return x((int)StateID::VCY);
}
[[nodiscard]] inline double vcz() const noexcept {
return x((int)StateID::VCZ);
}
[[nodiscard]] inline Eigen::Vector3d vel() const noexcept {
return Eigen::Vector3d(vcx(), vcy(), vcz());
}
[[nodiscard]] inline double vyaw() const noexcept {
return x((int)StateID::VYAW);
}
[[nodiscard]] inline double yaw() const noexcept {
return x((int)StateID::YAW);
}
[[nodiscard]] inline double r() const noexcept {
return x((int)StateID::R);
}
[[nodiscard]] inline double l() const noexcept {
return x((int)StateID::L);
}
[[nodiscard]] inline double h() const noexcept {
return x((int)StateID::H);
}
[[nodiscard]] inline double outpost01DZ() const noexcept {
return x((int)StateID::outpost01DZ);
}
[[nodiscard]] inline double outpost02DZ() const noexcept {
return x((int)StateID::outpost02DZ);
}
};
struct Predict {
Predict() = default;
explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT):
dt(dt),
model(model) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) const {
for (int i = 0; i < X_N; i++) {
x1[i] = x0[i];
}
// v_xyz
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
// linear velocity
x1[(int)StateID::CX] += x0[(int)StateID::VCX] * T(dt);
x1[(int)StateID::CY] += x0[(int)StateID::VCY] * T(dt);
x1[(int)StateID::CZ] += x0[(int)StateID::VCZ] * T(dt);
} else {
// no velocity
x1[(int)StateID::VCX] *= T(0.);
x1[(int)StateID::VCY] *= T(0.);
x1[(int)StateID::VCZ] *= T(0.);
}
// v_yaw
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
// angular velocity
x1[(int)StateID::YAW] += x0[(int)StateID::VYAW] * T(dt);
} else {
// no rotation
x1[(int)StateID::VYAW] *= T(0.);
}
clampState(x1);
}
template<typename T>
void clampState(T x1[X_N]) const {
auto& r = x1[(int)StateID::R];
auto& l = x1[(int)StateID::L];
r = std::clamp(r, T(0.1), T(0.5));
if (r < T(0.1) || r > T(0.5)) {
r = T(0.25);
l = T(0);
}
T sum = r + l;
if (sum < T(0.1) || sum > T(0.5)) {
r = T(0.25);
l = T(0);
}
auto& h = x1[(int)StateID::H];
h = std::clamp(h, T(-0.5), T(0.5));
}
void f(const VecX& x0, VecX& x1) const {
assert(x0.size() == X_N);
assert(x1.size() == X_N);
operator()(x0.data(), x1.data());
}
double dt;
MotionModel model;
};
template<typename T>
T normalize_angle_t(T angle) {
T two_pi = T(2.0 * M_PI);
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
}
struct Measure {
struct MeasureCtx {
MeasureCtx() = default;
MeasureCtx(int id, int armor_num): armor_num(armor_num), id(id) {}
int armor_num = 4;
int id = 0;
} ctx;
Measure() = default;
explicit Measure(MeasureCtx c): ctx(c) {}
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
// Compute armor position
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x);
T xy_dist = ceres::sqrt(armor_x * armor_x + armor_y * armor_y);
T dist = ceres::sqrt(xy_dist * xy_dist + armor_z * armor_z);
// Observation model
z[(int)MeasureID::YPD_Y] = ceres::atan2(armor_y, armor_x); // yaw
z[(int)MeasureID::YPD_P] = ceres::atan2(armor_z, xy_dist); // pitch
z[(int)MeasureID::YPD_D] = dist; // distance
z[(int)MeasureID::ORI_YAW] = angle; // orientation_yaw
}
template<typename T>
T get_angle(const T x[X_N]) const {
return normalize_angle_t(x[(int)StateID::YAW] + ctx.id * 2 * M_PI / ctx.armor_num);
}
template<typename T>
std::tuple<T, T, T, T> h_armor_xyza(const T x[X_N]) const {
T angle = get_angle(x);
auto outpost = ctx.armor_num == 3;
auto use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
T r = (use_l_h) ? x[(int)StateID::R] + x[(int)StateID::L] : x[(int)StateID::R];
T armor_x = x[(int)StateID::CX] - ceres::cos(angle) * r;
T armor_y = x[(int)StateID::CY] - ceres::sin(angle) * r;
T armor_z = (outpost) ? getoutpost_armor_z(x)
: (use_l_h) ? x[(int)StateID::CZ] + x[(int)StateID::H]
: x[(int)StateID::CZ];
return { armor_x, armor_y, armor_z, angle };
}
Eigen::Vector4d h_armor_xyza(const VecX& x) const {
assert(x.size() == X_N);
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x.data());
return { armor_x, armor_y, armor_z, angle };
}
template<typename T>
T getoutpost_armor_z(const T x[X_N]) const {
return (ctx.id == 0) ? x[(int)StateID::CZ]
: (ctx.id == 1) ? x[(int)StateID::CZ] + x[(int)StateID::outpost01DZ]
: (ctx.id == 2) ? x[(int)StateID::CZ] + x[(int)StateID::outpost02DZ]
: x[(int)StateID::CZ];
}
void h(const VecX& x, VecZ& z) const {
assert(x.size() == X_N);
assert(z.size() == Z_N);
operator()(x.data(), z.data());
}
};
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace ypdv2armor_motion_model

View File

@@ -0,0 +1,379 @@
#include "target.hpp"
namespace wust_vision {
namespace auto_aim {
Target::Target() {
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
}
Target::Target(const Armor& a, TargetConfig::Ptr target_config) {
Eigen::DiagonalMatrix<double, ypdv2armor_motion_model::X_N> p0;
if (a.number == ArmorNumber::OUTPOST) {
p0.diagonal() << 1, 64, 1, 64, 1, 81, 0.4, 100, 1e-4, 0.1, 0.1;
armor_num_ = 3;
radius_pre_ = 0.2765;
} else if (a.number == ArmorNumber::BASE) {
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1e-4, 0, 0;
armor_num_ = 3;
radius_pre_ = 0.3205;
} else {
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1, 1, 1;
armor_num_ = 4;
radius_pre_ = 0.2;
}
target_config_ = target_config;
const auto yfv2 = MModel::Predict(0.005);
ctx_.armor_num = armor_num_;
ctx_.id = 0;
const auto yhv2 = MModel::Measure(ctx_);
const auto yu_qv2 = [this]() {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
return q;
};
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
return r;
};
esekf_ypd_ = MModel::RobotStateESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0);
esekf_ypd_.setResidualFunc([this](
const Eigen::Matrix<double, MModel::Z_N, 1>& z_pred,
const Eigen::Matrix<double, MModel::Z_N, 1>& z
) {
Eigen::Matrix<double, MModel::Z_N, 1> r = z - z_pred;
r[0] = angles::shortest_angular_distance(
z_pred[(int)MModel::MeasureID::YPD_Y],
z[(int)MModel::MeasureID::YPD_Y]
); // yaw
r[3] = angles::shortest_angular_distance(
z_pred[(int)MModel::MeasureID::ORI_YAW],
z[(int)MModel::MeasureID::ORI_YAW]
); // ori_yaw
return r;
});
esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get());
esekf_ypd_.setInjectFunc([this](
const Eigen::Matrix<double, MModel::X_N, 1>& delta,
Eigen::Matrix<double, MModel::X_N, 1>& nominal
) {
for (int i = 0; i < MModel::X_N; i++) {
if (i == (int)MModel::StateID::YAW)
continue;
nominal[i] += delta[i];
}
nominal[(int)MModel::StateID::YAW] = angles::normalize_angle(
nominal[(int)MModel::StateID::YAW] + delta[(int)MModel::StateID::YAW]
);
});
const double xa = a.target_pos.x();
const double ya = a.target_pos.y();
const double za = a.target_pos.z();
const double yaw = utils::orientationToYaw<Target>(a.target_ori);
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
const double r = radius_pre_;
const double xc = xa + r * cos(yaw);
const double yc = ya + r * sin(yaw);
const double zc = za;
target_state_.x << xc, 0, yc, 0, zc, 0, yaw, 0, r, 0, 0;
esekf_ypd_.setState(target_state_.x);
tracked_id_ = a.number;
type_ = a.type;
last_t_ = a.timestamp;
timestamp_ = a.timestamp;
is_inited = true;
}
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
Target::computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z
) const noexcept {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
const double delta_angle = angles::normalize_angle(z[3] - z[0]);
const double abs_delta = std::abs(delta_angle);
// sin插值函数小值慢、大值快
const auto sinInterp = [](double x, double x0, double x1, double y0, double y1) -> double {
double t = (x - x0) / (x1 - x0);
if (t < 0)
t = 0;
if (t > 1)
t = 1;
double s = std::sin(t * M_PI / 2.0);
return y0 + s * (y1 - y0);
};
// clang-format off
r <<target_config_->yp_r_param.get(), 0, 0, 0,
0, target_config_->yp_r_param.get() , 0, 0,
0, 0, sinInterp(abs_delta, 0.0, M_PI/2.0, target_config_->dis_r_front_param.get(), target_config_->dis_r_side_param.get())+z[2]*z[2]*target_config_->dis2_r_ratio_param.get(), 0,
0, 0, 0,log(std::abs(z[2]) + 1) *target_config_->yaw_r_log_ratio_param.get() + sinInterp(M_PI/2.0-abs_delta, 0.0, M_PI/2.0, target_config_->yaw_r_base_side_param.get(), target_config_->yaw_r_base_front_param.get());
// clang-format on
return r;
}
Eigen::Matrix<double, MModel::X_N, MModel::X_N> Target::computeProcessNoise(double dt
) const noexcept {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
Eigen::Vector3d q_xyz;
double q_yaw;
double q_l, q_h;
if (tracked_id_ == ArmorNumber::OUTPOST) {
q_xyz = target_config_->qxyz_output; // 前哨站加速度方差
q_yaw = target_config_->qyaw_output_param.get(); // 前哨站角加速度方差
q_l = target_config_->q_outpost_dz_param.get();
q_h = target_config_->q_outpost_dz_param.get();
} else {
q_xyz = target_config_->qxyz_common; // 加速度方差
q_yaw = target_config_->qyaw_common_param.get(); // 角加速度方差
q_l = target_config_->q_l_param.get();
q_h = target_config_->q_h_param.get();
}
const double t = dt;
const double q_x_x = pow(t, 4) / 4 * q_xyz.x(), q_x_vx = pow(t, 3) / 2 * q_xyz.x(),
q_vx_vx = pow(t, 2) * q_xyz.x();
const double q_y_y = pow(t, 4) / 4 * q_xyz.y(), q_y_vy = pow(t, 3) / 2 * q_xyz.y(),
q_vy_vy = pow(t, 2) * q_xyz.y();
const double q_z_z = pow(t, 4) / 4 * q_xyz.z(), q_z_vz = pow(t, 3) / 2 * q_xyz.z(),
q_vz_vz = pow(t, 2) * q_xyz.z();
const double q_yaw_yaw = pow(t, 4) / 4 * q_yaw, q_yaw_vyaw = pow(t, 3) / 2 * q_yaw,
q_vyaw_vyaw = pow(t, 2) * q_yaw;
const double q_r = target_config_->q_r_param.get();
// clang-format off
// xc v_xc yc v_yc zc v_zc yaw v_yaw r l h
q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0,
q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, q_l, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_h;
// clang-format on
return q;
}
MModel::Predict Target::getPredictFunc(double dt) const noexcept {
MModel::Predict predict_func;
if (tracked_id_ == ArmorNumber::OUTPOST) {
predict_func = MModel::Predict {
dt,
MModel::MotionModel::CONSTANT_ROTATION,
};
} else {
predict_func = MModel::Predict {
dt,
MModel::MotionModel::CONSTANT_VEL_ROT,
};
}
return predict_func;
}
void Target::predict(std::chrono::steady_clock::time_point t) noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predict(dt);
last_t_ = t;
}
void Target::predict(double dt) noexcept {
MModel::Predict predict_func = getPredictFunc(dt);
esekf_ypd_.setPredictFunc(predict_func);
const auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); };
esekf_ypd_.setUpdateQ(yu_qv2);
target_state_.x = esekf_ypd_.predict();
if (target_state_.pos().norm() < 0.5) {
is_tracking = false;
}
}
void Target::predictSimple(std::chrono::steady_clock::time_point t) noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predictSimple(dt);
last_t_ = t;
}
void Target::predictSimple(double dt) noexcept {
MModel::Predict predict_func = getPredictFunc(dt);
predict_func.f(target_state_.x, target_state_.x);
if (target_state_.pos().norm() < 0.5) {
is_tracking = false;
}
}
bool Target::update(const std::pair<int, Armor>& a) noexcept {
const auto armor = a.second;
const auto id = a.first;
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
return this->computeMeasurementCovariance(z);
};
esekf_ypd_.setUpdateR(yu_rv2);
measurement_ = getMeasure(armor);
if (id != 0)
jumped = true;
ctx_.id = id;
esekf_ypd_.setMeasureFunc(MModel::Measure { ctx_ });
target_state_.x = esekf_ypd_.update(measurement_);
timestamp_ = armor.timestamp;
last_t_ = timestamp_;
return true;
}
cv::Rect Target::expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
) const noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
);
if (!is_inited || dt > target_config_->lost_time_thres_param.get()) {
return cv::Rect(0, 0, 0, 0);
}
const float car_box_half =
std::max(target_state_.r(), target_state_.r() + target_state_.l()) + 0.15;
static std::vector<cv::Point3f> CAR_BOX;
CAR_BOX = { { 0, car_box_half, -car_box_half },
{ 0, -car_box_half, -car_box_half },
{ 0, -car_box_half, car_box_half },
{ 0, car_box_half, car_box_half } };
const Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse();
const Eigen::Vector4d
pos_odom(target_state_.cx(), target_state_.cy(), target_state_.cz(), 1.0);
const Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom;
if (pos_cam.z() <= 0.2) {
return cv::Rect(0, 0, 0, 0);
}
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z());
Eigen::Vector3d euler;
euler.z() = M_PI / 2.0;
euler.y() = 0;
euler.x() = std::atan2(pos_odom.y(), pos_odom.x());
const Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
const auto target_ori = utils::transformOrientation(ori, T_odom_to_camera);
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
const cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
std::vector<cv::Point2f> pts_2d;
cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d);
const cv::Rect rect = cv::boundingRect(pts_2d);
const cv::Rect img_rect(0, 0, image_size.width, image_size.height);
if ((rect & img_rect).area() <= 0) {
return cv::Rect(0, 0, 0, 0);
}
const int base_side = std::max(rect.width, rect.height);
const int max_side = std::max(image_size.width, image_size.height);
const double lost_dt = target_config_->lost_time_thres_param.get();
const double dt_clamped = std::max(0.0, std::min(dt, lost_dt));
int side = static_cast<int>(base_side + (max_side - base_side) * (dt_clamped / lost_dt));
if (dt >= lost_dt) {
side = max_side;
}
const int cx = rect.x + rect.width / 2;
const int cy = rect.y + rect.height / 2;
cv::Rect square(cx - side / 2, cy - side / 2, side, side);
square &= img_rect;
return square;
}
std::vector<std::pair<int, Armor>> Target::match(const std::vector<Armor>& armors) noexcept {
std::vector<std::pair<int, Armor>> result;
const int n_obs = static_cast<int>(armors.size());
const int armors_num = armor_num_;
const double GATE = target_config_->match_gate_param.get();
const double max_cost = 1e9;
std::vector<std::vector<double>> cost(n_obs, std::vector<double>(armors_num, max_cost + 1));
std::vector<MModel::VecZ> meas_list(n_obs);
for (int j = 0; j < n_obs; ++j) {
meas_list[j] = getMeasure(armors[j]);
}
for (int j = 0; j < n_obs; ++j) {
for (int id = 0; id < armors_num; ++id) {
MModel::Measure::MeasureCtx tmp_ctx(id, armors_num);
MModel::Measure measure(tmp_ctx);
MModel::VecZ z_pred;
measure.h(target_state_.x, z_pred);
MModel::VecZ nu = meas_list[j] - z_pred;
nu[(int)MModel::MeasureID::YPD_Y] =
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_Y]);
nu[(int)MModel::MeasureID::YPD_P] =
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_P]);
nu[(int)MModel::MeasureID::ORI_YAW] =
angles::normalize_angle(nu[(int)MModel::MeasureID::ORI_YAW]);
auto R = computeMeasurementCovariance(z_pred);
double d2 = nu.transpose() * R.ldlt().solve(nu);
// 门控
if (std::isfinite(d2) && d2 < GATE) {
cost[j][id] = d2;
}
}
}
std::vector<bool> used_obs(n_obs, false);
std::vector<bool> used_id(armors_num, false);
while (true) {
double best = max_cost;
int best_j = -1;
int best_id = -1;
for (int j = 0; j < n_obs; ++j) {
if (used_obs[j])
continue;
for (int id = 0; id < armors_num; ++id) {
if (used_id[id])
continue;
if (cost[j][id] < best) {
best = cost[j][id];
best_j = j;
best_id = id;
}
}
}
if (best_j < 0 || best_id < 0) {
break;
}
used_obs[best_j] = true;
used_id[best_id] = true;
result.push_back(std::make_pair(best_id, armors[best_j]));
}
return result;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,185 @@
#pragma once
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/common/utils/parameter.hpp"
#include <wust_vl/common/utils/timer.hpp>
namespace wust_vision {
namespace auto_aim {
namespace MModel = ypdv2armor_motion_model;
struct TargetConfig: wust_vl::common::utils::ParamGroup {
static constexpr const char* kKey = "armor_tracker";
const char* key() const override {
return kKey;
}
GEN_PARAM(int, esekf_iter_num);
GEN_PARAM(double, lost_time_thres);
GEN_PARAM(int, tracking_thres);
GEN_PARAM(double, max_yaw_diff_deg);
GEN_PARAM(double, max_dis_diff);
GEN_PARAM(double, match_gate);
GEN_PARAM(double, qyaw_common);
GEN_PARAM(double, qyaw_output);
GEN_PARAM(double, q_r);
GEN_PARAM(double, q_l);
GEN_PARAM(double, q_h);
GEN_PARAM(double, q_outpost_dz);
GEN_PARAM(double, yp_r);
GEN_PARAM(double, dis_r_front);
GEN_PARAM(double, dis_r_side);
GEN_PARAM(double, dis2_r_ratio);
GEN_PARAM(double, yaw_r_base_front);
GEN_PARAM(double, yaw_r_base_side);
GEN_PARAM(double, yaw_r_log_ratio);
GEN_PARAM(std::vector<double>, qxyz_common);
GEN_PARAM(std::vector<double>, qxyz_output);
Eigen::Vector3d qxyz_common = { 100, 100, 100 };
Eigen::Vector3d qxyz_output = { 10, 10, 10 };
using Ptr = std::shared_ptr<TargetConfig>;
TargetConfig() {
qxyz_output_param.onChange([this](auto o, auto n) {
qxyz_common = Eigen::Vector3d(n[0], n[1], n[2]);
});
qxyz_output_param.onChange([this](auto o, auto n) {
qxyz_output = Eigen::Vector3d(n[0], n[1], n[2]);
});
}
static Ptr create() {
return std::make_shared<TargetConfig>();
}
void loadSelf(const YAML::Node& node) override {
esekf_iter_num_param.load(node);
lost_time_thres_param.load(node);
tracking_thres_param.load(node);
max_yaw_diff_deg_param.load(node);
max_dis_diff_param.load(node);
match_gate_param.load(node);
qyaw_common_param.load(node);
qyaw_output_param.load(node);
qxyz_common_param.load(node);
qxyz_output_param.load(node);
q_r_param.load(node);
q_l_param.load(node);
q_h_param.load(node);
q_outpost_dz_param.load(node);
yp_r_param.load(node);
dis_r_front_param.load(node);
dis_r_side_param.load(node);
yaw_r_base_front_param.load(node);
yaw_r_base_side_param.load(node);
yaw_r_log_ratio_param.load(node);
}
};
class Target {
public:
Target();
Target(const Armor& armor, TargetConfig::Ptr target_config);
MModel::Measure::MeasureCtx ctx_;
ArmorNumber tracked_id_;
std::string type_;
MModel::VecZ measurement_ = Eigen::Matrix<double, MModel::Z_N, 1>::Zero();
MModel::State target_state_ = MModel::State();
double radius_pre_;
int armor_num_ = 4;
bool jumped = false;
bool is_inited = false;
bool is_tracking = false;
std::chrono::steady_clock::time_point last_t_;
std::chrono::steady_clock::time_point timestamp_;
MModel::RobotStateESEKF esekf_ypd_;
TargetConfig::Ptr target_config_;
[[nodiscard]] cv::Rect expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
) const noexcept;
void predict(std::chrono::steady_clock::time_point t) noexcept;
void predict(double dt) noexcept;
void predictSimple(std::chrono::steady_clock::time_point t) noexcept;
void predictSimple(double dt) noexcept;
[[nodiscard]] MModel::Predict getPredictFunc(double dt) const noexcept;
bool update(const std::pair<int, Armor>& armor) noexcept;
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z) const noexcept;
[[nodiscard]] Eigen::Matrix<double, MModel::X_N, MModel::X_N> computeProcessNoise(double dt
) const noexcept;
[[nodiscard]] std::optional<ArmorNumber> getArmorNumber() const noexcept {
if (!checkTargetAppear()) {
return std::nullopt;
}
return tracked_id_;
}
[[nodiscard]] std::vector<double> getArmorYaws() const noexcept {
std::vector<double> yaw_list;
yaw_list.reserve(armor_num_);
for (int i = 0; i < armor_num_; i++) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
yaw_list.push_back(measure.get_angle(target_state_.x.data()));
}
return yaw_list;
}
[[nodiscard]] std::vector<Eigen::Vector3d> getArmorPositions() const noexcept {
std::vector<Eigen::Vector3d> armor_positions;
armor_positions.reserve(armor_num_);
for (int i = 0; i < armor_num_; i++) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
armor_positions.push_back(xyza.head<3>());
}
return armor_positions;
}
[[nodiscard]] std::vector<Eigen::Vector4d> getArmorPosAndYaw() const noexcept {
std::vector<Eigen::Vector4d> pos_yaw;
pos_yaw.reserve(armor_num_);
for (int i = 0; i < armor_num_; ++i) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
pos_yaw.push_back(xyza);
}
return pos_yaw;
}
[[nodiscard]] double getMeanZ() const noexcept {
double mean = 0;
for (const auto& p: getArmorPositions()) {
mean += p.z();
}
return mean / armor_num_;
}
[[nodiscard]] double getArmor2CenterXYDis(int id) const noexcept {
const auto use_l_h = (armor_num_ == 4) && (id == 1 || id == 3);
const auto r = (use_l_h) ? target_state_.r() + target_state_.l() : target_state_.r();
return r;
}
[[nodiscard]] std::vector<std::pair<int, Armor>> match(const std::vector<Armor>& armors
) noexcept;
[[nodiscard]] inline bool checkTargetAppear() const noexcept {
const bool appear = is_tracking
&& wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
) < target_config_->lost_time_thres_param.get();
return appear;
}
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, 1> getMeasure(const Armor& a) noexcept {
const auto p = a.target_pos;
const double measured_yaw = utils::orientationToYaw<Target>(a.target_ori);
double ypd_y = std::atan2(p.y(), p.x());
static double last_ypd_y = 0;
ypd_y = last_ypd_y + angles::shortest_angular_distance(last_ypd_y, ypd_y);
last_ypd_y = ypd_y;
const double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y()));
const double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z());
return Eigen::Vector4d(ypd_y, ypd_p, ypd_d, measured_yaw);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,209 @@
#include "trackerv3.hpp"
namespace wust_vision {
namespace auto_aim {
struct Tracker::Impl {
public:
enum State {
LOST,
DETECTING,
TRACKING,
TEMP_LOST,
} tracker_state = LOST;
Impl(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
tracker_state = LOST;
target_config_ = TargetConfig::create();
auto_aim_config_parameter->registerGroup(*target_config_);
auto_aim_config_parameter->reloadFromOldPath();
target_ = Target();
}
Target track(const Armors& armors_msg) noexcept {
const double dt =
std::chrono::duration<double>(armors_msg.timestamp - last_time_).count();
last_time_ = armors_msg.timestamp;
lost_thres_ =
std::abs(static_cast<int>(target_config_->lost_time_thres_param.get() / dt));
Armors armors;
armors = armors_msg;
std::erase_if(armors.armors, [this](const Armor& a) {
double center_yaw =
std::atan2(target_.target_state_.cy(), target_.target_state_.cx());
bool state_check = tracker_state == TRACKING;
bool outpost_check = target_.tracked_id_ == ArmorNumber::OUTPOST && !a.is_ok;
bool pose_check =
(std::abs(angles::normalize_angle(
orientationToYaw(a.target_ori, center_yaw) - center_yaw
)) > (target_config_->max_yaw_diff_deg_param.get() * M_PI / 180.0)
|| std::abs((a.target_pos - target_.target_state_.pos()).norm())
> target_config_->max_dis_diff_param.get())
&& target_.is_inited
&& std::abs(wust_vl::common::utils::time_utils::durationMs(
target_.timestamp_,
wust_vl::common::utils::time_utils::now()
))
< 1000.0;
return state_check && outpost_check || pose_check;
});
std::sort(
armors.armors.begin(),
armors.armors.end(),
[](const Armor& a, const Armor& b) {
return a.distance_to_image_center < b.distance_to_image_center;
}
);
bool found;
if (tracker_state == LOST) {
found = initTarget(armors);
} else {
found = updateTarget(armors);
}
updateFsm(found);
return target_;
}
void updateFsm(bool found) noexcept {
switch (tracker_state) {
case DETECTING:
if (found) {
if (++detect_count_ > target_config_->tracking_thres_param.get()) {
detect_count_ = 0;
tracker_state = TRACKING;
}
} else {
detect_count_ = 0;
tracker_state = LOST;
}
break;
case TRACKING:
if (!found) {
tracker_state = TEMP_LOST;
lost_count_ = 1;
}
break;
case TEMP_LOST:
if (!found) {
if (++lost_count_ > lost_thres_) {
lost_count_ = 0;
tracker_state = LOST;
}
} else {
lost_count_ = 0;
tracker_state = TRACKING;
}
break;
default:
break;
}
target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST);
if (found)
++found_count_;
}
bool initTarget(const Armors& armors) noexcept {
if (armors.armors.empty()) {
return false;
}
bool found = false;
Armor init_target;
Armors others = armors;
others.armors.clear();
for (auto& a: armors.armors) {
if (!a.is_none_purple && !found) {
init_target = a;
found = true;
continue;
}
others.armors.push_back(a);
}
if (!found) {
return false;
}
target_ = Target(init_target, target_config_);
// updateTarget(others);
tracker_state = DETECTING;
return true;
}
bool updateTarget(const Armors& armors) noexcept {
if (armors.armors.empty())
return false;
target_.predict(armors.timestamp);
std::vector<Armor> candidates;
candidates.reserve(armors.armors.size());
for (const auto& a: armors.armors) {
if (isSameTarget(a.number, target_.tracked_id_) && !a.is_none_purple) {
candidates.emplace_back(a);
}
}
if (candidates.empty())
return false;
int updated = 0;
const auto matches = target_.match(candidates);
for (const auto& m: matches) {
if (m.second.is_none_purple) {
if (++is_none_purple_count_ > 100)
continue;
} else {
is_none_purple_count_ = 0;
}
if (target_.update(m))
++updated;
}
return updated > 0;
}
int lost_thres_;
int detect_count_ = 0;
int lost_count_ = 0;
int is_none_purple_count_ = 0;
int found_count_ = 0;
Target target_;
std::chrono::steady_clock::time_point last_time_;
TargetConfig::Ptr target_config_;
double orientationToYaw(const Eigen::Quaterniond& q, double from) noexcept {
double roll, pitch, yaw;
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
yaw = euler[0];
yaw = from + angles::shortest_angular_distance(from, yaw);
return yaw;
}
};
Tracker::Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
_impl = std::make_unique<Impl>(auto_aim_config_parameter);
}
Tracker::~Tracker() {
_impl.reset();
}
Target Tracker::track(const Armors& armors) noexcept {
return _impl->track(armors);
}
int Tracker::getFoundCount() const noexcept {
return _impl->found_count_;
}
void Tracker::setFoundCount(int count) noexcept {
_impl->found_count_ = count;
}
std::chrono::steady_clock::time_point Tracker::getLastTime() const noexcept {
return _impl->last_time_;
}
void Tracker::setLastTime(std::chrono::steady_clock::time_point t) noexcept {
_impl->last_time_ = t;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,23 @@
#pragma once
#include "target.hpp"
namespace wust_vision {
namespace auto_aim {
class Tracker {
public:
using Ptr = std::unique_ptr<Tracker>;
Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter);
static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
return std::make_unique<Tracker>(auto_aim_config_parameter);
}
~Tracker();
[[nodiscard]] Target track(const Armors& armors) noexcept;
int getFoundCount() const noexcept;
void setFoundCount(int count) noexcept;
void setLastTime(std::chrono::steady_clock::time_point t) noexcept;
std::chrono::steady_clock::time_point getLastTime() const noexcept;
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,307 @@
// Created by Labor 2023.8.25
// Maintained by Chengfu Zou, Labor
// Copyright (C) FYT Vision Group. All rights reserved.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "armor_where.hpp"
#include "wust_vl/algorithm/pnp_solver.hpp"
#include <opencv2/core/eigen.hpp>
namespace wust_vision {
namespace auto_aim {
struct ArmorWhere::Impl {
public:
Impl(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
camera_info_ = camera_info;
params_.load(config);
pnp_solver_ = std::make_unique<wust_vl::algorithm::PnPSolver>(cv::SOLVEPNP_IPPE);
pnp_solver_->setObjectPoints(
"small",
ArmorObject::buildObjectPoints<cv::Point3f>(SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT)
);
pnp_solver_->setObjectPoints(
"large",
ArmorObject::buildObjectPoints<cv::Point3f>(LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT)
);
}
struct Params {
enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode;
OptMode fromString(const std::string& mode) {
if (mode == "golden" || mode == "GOLDEN") {
return OptMode::GOLDEN;
} else if (mode == "none" || mode == "NONE") {
return OptMode::NONE;
} else {
return OptMode::NONE;
}
}
int golden_search_side_deg = 60;
double distance_fix_a2 = 0;
void load(const YAML::Node& node) {
opt_mode = fromString(node["yaw_opt"]["mode"].as<std::string>());
golden_search_side_deg = node["yaw_opt"]["golden_search_side_deg"].as<int>();
distance_fix_a2 = node["distance_fix_a2"].as<double>();
}
} params_;
std::vector<Armor> where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept {
std::vector<Armor> armors_msg;
const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0);
auto makeArmor =
[&](const ArmorObject& obj, const Eigen::Vector3d& t, const Eigen::Matrix3d& R) {
Armor msg;
msg.type = (obj.number == ArmorNumber::NO1 || obj.number == ArmorNumber::BASE)
? "large"
: "small";
msg.number = obj.number;
Eigen::Quaterniond q(R);
Eigen::Quaterniond add_roll {
Eigen::AngleAxisd(M_PI / 2, Eigen::Vector3d::UnitX())
};
Eigen::Quaterniond new_q = q * add_roll;
auto [yaw, pitch, dist] = utils::xyz2ypd_rad(t.x(), t.y(), t.z());
dist += params_.distance_fix_a2 * dist * dist;
auto [x, y, z] = utils::ypd2xyz_rad(yaw, pitch, dist);
msg.pos = { x, y, z };
msg.ori = new_q;
auto_aim::transformArmorData(msg, T_camera_to_odom);
msg.distance_to_image_center =
pnp_solver_->calculateDistanceToCenter(obj.center, camera_info_.first);
msg.is_ok = obj.is_ok;
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
msg.is_none_purple = true;
} else {
msg.is_none_purple = false;
}
return msg;
};
for (auto const& a: armors) {
cv::Mat rvec, tvec;
std::string type = (a.number == ArmorNumber::NO1 || a.number == ArmorNumber::BASE)
? "large"
: "small";
if (!pnp_solver_->solvePnP(
a.landmarks(),
rvec,
tvec,
type,
camera_info_.first,
camera_info_.second
))
{
WUST_WARN("PNP") << "PNP failed";
continue;
}
cv::Mat R_cv;
cv::Rodrigues(rvec, R_cv);
Eigen::Matrix3d R = utils::cvToEigen(R_cv);
Eigen::Vector3d t = utils::cvToEigen(tvec);
if (params_.opt_mode != Params::OptMode::NONE) {
Eigen::Matrix3d R0 = R;
R = solveBa_R(a, t, R0, R_imu_cam, type);
}
armors_msg.push_back(makeArmor(a, t, R));
}
return armors_msg;
}
std::vector<Eigen::Vector2d> reprojectionArmor(
double yaw,
const std::vector<cv::Point3f>& object_points,
const std::vector<cv::Point2f>& landmarks,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix();
cv::Mat rvec, R_cv;
cv::eigen2cv(R, R_cv);
cv::Rodrigues(R_cv, rvec);
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << t.x(), t.y(), t.z());
std::vector<cv::Point2f> pts_2d;
pts_2d.reserve(object_points.size());
cv::projectPoints(
object_points,
rvec,
tvec,
camera_info_.first,
camera_info_.second,
pts_2d
);
std::vector<Eigen::Vector2d> image_points;
image_points.reserve(pts_2d.size());
for (const auto& p: pts_2d) {
image_points.emplace_back(p.x, p.y);
}
return image_points;
}
double reprojectionErrorYaw(
double yaw,
const std::vector<cv::Point3f>& object_points,
const std::vector<cv::Point2f>& landmarks,
const std::vector<std::pair<int, int>>& sym_pairs,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
const auto image_points =
reprojectionArmor(yaw, object_points, landmarks, Rci, pitch, roll, t);
double cost = 0.0;
// for (size_t i = 0; i < image_points.size(); ++i) {
// Eigen::Vector2d obs(landmarks[i].x, landmarks[i].y);
// cost += (image_points[i] - obs).squaredNorm();
// }
for (auto& p: sym_pairs) {
const Eigen::Vector2d mid = 0.5 * (image_points[p.first] + image_points[p.second]);
const Eigen::Vector2d meas = 0.5
* (Eigen::Vector2d(landmarks[p.first].x, landmarks[p.first].y)
+ Eigen::Vector2d(landmarks[p.second].x, landmarks[p.second].y));
cost += (mid - meas).squaredNorm();
}
return cost;
}
double goldenYaw(
double init,
const std::vector<cv::Point3f>& obj,
const std::vector<cv::Point2f>& lm,
const std::vector<std::pair<int, int>>& sym_pairs,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
constexpr double phi = 1.618033988749894848; //(1.0 + std::sqrt(5.0)) * 0.5;
double l = init - params_.golden_search_side_deg * M_PI / 180.0;
double r = init + params_.golden_search_side_deg * M_PI / 180.0;
double y1 = r - (r - l) / phi;
double y2 = l + (r - l) / phi;
double f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
double f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
while (r - l > 0.0001) {
if (f1 < f2) {
r = y2;
y2 = y1;
f2 = f1;
y1 = r - (r - l) / phi;
f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
} else {
l = y1;
y1 = y2;
f1 = f2;
y2 = l + (r - l) / phi;
f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
}
}
return 0.5 * (l + r);
}
Eigen::Matrix3d solveBa_R(
const ArmorObject& armor,
const Eigen::Vector3d& t_camera_armor,
const Eigen::Matrix3d& R_camera_armor,
const Eigen::Matrix3d& R_imu_camera,
const std::string& type
) const noexcept {
const Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor;
const Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose();
//double roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2));
const double roll = 0;
// initial yaw
const double yaw_init = std::atan2(-R_imu_armor(0, 1), R_imu_armor(1, 1));
const double armor_pitch =
(armor.number == ArmorNumber::OUTPOST) ? -FIFTTEN_DEGREE_RAD : FIFTTEN_DEGREE_RAD;
const Eigen::Vector2d armor_size = (type == "large")
? Eigen::Vector2d { LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT }
: Eigen::Vector2d { SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT };
const auto objPts =
ArmorObject::buildObjectPoints<cv::Point3f>(armor_size.x(), armor_size.y());
const auto& lm = armor.landmarks();
const auto& sym_pairs = ArmorObject::buildSymPairs<int>();
double yaw = yaw_init;
if (params_.opt_mode == Params::OptMode::GOLDEN) {
yaw = goldenYaw(
yaw_init,
objPts,
lm,
sym_pairs,
R_camera_imu,
armor_pitch,
roll,
t_camera_armor
);
}
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(armor_pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix();
return R_result;
}
private:
std::pair<cv::Mat, cv::Mat> camera_info_;
std::unique_ptr<wust_vl::algorithm::PnPSolver> pnp_solver_;
};
ArmorWhere::ArmorWhere(
const YAML::Node& config,
const std::pair<cv::Mat, cv::Mat>& camera_info
) {
_impl = std::make_unique<Impl>(config, camera_info);
}
ArmorWhere::~ArmorWhere() {
_impl.reset();
}
std::vector<Armor> ArmorWhere::where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept {
return _impl->where(armors, T_camera_to_odom);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,26 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorWhere {
public:
using Ptr = std::unique_ptr<ArmorWhere>;
ArmorWhere(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info);
static Ptr
create(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
return std::make_unique<ArmorWhere>(config, camera_info);
}
~ArmorWhere();
std::vector<Armor> where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,404 @@
#include "auto_aim.hpp"
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
#include "tasks/auto_aim/armor_where/armor_where.hpp"
#include "tasks/auto_aim/debug.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/config.hpp"
#include "wust_vl/common/concurrency/queues.hpp"
namespace wust_vision {
namespace auto_aim {
struct AutoAim::Impl {
~Impl() {
run_flag_ = false;
if (armor_detector_) {
armor_detector_.reset();
}
armor_queue_->stop();
if (processing_thread_) {
processing_thread_->stop();
wust_vl::common::concurrency::ThreadManager::instance().unregisterThread(
processing_thread_->getName()
);
}
}
Impl(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
tf_config_ = tf_config;
camera_info_ = camera_info;
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
auto config = YAML::LoadFile(config_path);
auto_aim_config_parameter_->loadFromFile(config_path);
auto_exposure_cfg_ = AutoExposureCfg::create();
very_aimer_ = VeryAimer::create(auto_aim_config_parameter_);
auto_aim_config_parameter_->registerGroup(*auto_exposure_cfg_);
auto_aim_config_parameter_->registerGroup(*auto_aim_fsm_cl_.config_);
tracker_ = Tracker::create(auto_aim_config_parameter_);
auto_aim_config_parameter_->reloadFromOldPath();
wust_vl::common::utils::ParameterManager::instance().registerParameter(
*auto_aim_config_parameter_.get()
);
max_detect_armors_ = config["max_detect_armors"].as<int>(10);
armor_where_ = ArmorWhere::create(config["armor_where"], camera_info_);
const std::string armor_detect_backend =
config["armor_detect_backend"].as<std::string>("");
armor_detector_ = DetectorFactory::createArmorDetector(armor_detect_backend, true);
armor_detector_->setCallback(std::bind(
&AutoAim::Impl::ArmorDetectCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
WUST_MAIN(logger_) << "Using Armor Detector: " << armor_detect_backend;
armor_queue_ =
std::make_unique<wust_vl::common::concurrency::OrderedQueue<Armors>>(50, 500);
latency_averager_ =
std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void start() {
if (run_flag_) {
return;
}
run_flag_ = true;
processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
"AutoAimProcessingThread",
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
this->processingLoop(self);
}
);
wust_vl::common::concurrency::ThreadManager::instance().registerThread(
processing_thread_
);
}
void pushInput(CommonFrame& frame) {
img_recv_count_++;
auto bbox = target_.expanded(
T_camera_to_odom_,
camera_info_.first,
camera_info_.second,
frame.img_frame.src_img.size()
);
if (bbox.area() > 100) {
frame.expanded = bbox;
frame.offset = cv::Point2f(bbox.x, bbox.y);
}
expanded_ = frame.expanded;
const std::optional<ArmorNumber> target_number = target_.getArmorNumber();
if (armor_detector_) {
armor_detector_->pushInput(frame, target_number);
}
}
void ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) {
std::vector<ArmorObject> sorted_objs = objs;
if (sorted_objs.size() > max_detect_armors_) [[unlikely]] {
WUST_WARN(logger_) << "Detected " << sorted_objs.size()
<< " objects, too many, keeping top " << max_detect_armors_;
std::partial_sort(
sorted_objs.begin(),
sorted_objs.begin() + max_detect_armors_,
sorted_objs.end(),
[](const ArmorObject& a, const ArmorObject& b) {
return a.confidence > b.confidence;
}
);
sorted_objs.resize(max_detect_armors_);
}
for (auto& obj: sorted_objs) {
obj.addOffset(frame.offset);
}
Armors armors;
armors.timestamp = frame.img_frame.timestamp;
armors.id = frame.id;
Eigen::Vector3d v = Eigen::Vector3d::Zero();
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
auto ctx = std::any_cast<VisionCtx>(frame.any_ctx);
std::pair<double, double> gimbal_py;
if (ctx.motion_buffer) {
const auto delay =
std::chrono::microseconds(static_cast<int64_t>(ctx.communication_delay_μs));
const auto t_query = armors.timestamp + delay;
auto apply_motion = [&](const auto& att) {
v << att.data.vx, att.data.vy, att.data.vz;
R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ())
* Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY())
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
gimbal_py = std::make_pair(att.data.pitch, att.data.yaw);
};
if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) {
apply_motion(*past_att);
} else if (auto last_att = ctx.motion_buffer->get_last()) {
apply_motion(*last_att);
}
}
autoExposureControl(frame.img_frame.src_img, ctx.camera);
T_camera_to_odom_ = utils::computeCameraToOdomTransform(
R_gimbal2odom,
tf_config_->R_camera2gimbal,
tf_config_->t_camera2gimbal
);
armors.armors = armor_where_->where(sorted_objs, T_camera_to_odom_);
armors.v = v;
for (auto& armor: armors.armors) {
armor.timestamp = armors.timestamp;
}
armor_queue_->enqueue(armors);
++detect_finish_count_;
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto& dbg = auto_aim_debug_;
dbg.img_frame = frame.img_frame;
dbg.armors = armors;
dbg.T_camera_to_odom = T_camera_to_odom_;
dbg.detect_color = frame.detect_color;
dbg.armor_objs = sorted_objs;
dbg.expanded = frame.expanded;
dbg.gimbal_py = gimbal_py;
}
}
void armorsCallback(const Armors& armors) {
if (armors.timestamp <= tracker_->getLastTime()) {
WUST_WARN(logger_) << "Received out-of-order armor data, discarded.";
return;
}
Target target = tracker_->track(armors);
auto_aim_fsm_cl_.update(std::abs(target.target_state_.vyaw()), target.jumped);
const auto now = std::chrono::steady_clock::now();
{
std::lock_guard<std::mutex> lock(target_mutex_);
target_ = target;
}
const auto latency_ms =
wust_vl::common::utils::time_utils::durationMs(armors.timestamp, now);
latency_averager_->add(latency_ms);
auto& dbg = auto_aim_debug_;
dbg.latency_ms = latency_averager_->average();
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg.target = target;
dbg.fsm = auto_aim_fsm_cl_.fsm_state_;
}
}
Target getTarget() {
Target target;
{
std::lock_guard<std::mutex> lock(target_mutex_);
target = target_;
}
return target;
}
GimbalCmd solve(double bullet_speed) {
GimbalCmd gimbal_cmd;
Target target;
{
std::lock_guard<std::mutex> lock(target_mutex_);
target = target_;
}
AimTarget aim_target;
const bool appear = target.checkTargetAppear();
if (appear && target.target_state_.pos().norm() > 0.1) {
try {
gimbal_cmd =
very_aimer_->veryAim(target, bullet_speed, auto_aim_fsm_cl_.fsm_state_);
aim_target = gimbal_cmd.aim_target;
} catch (...) {
WUST_ERROR(logger_) << "VeryAim error";
}
}
if (gimbal_cmd.fire_advice) {
fire_count_++;
}
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto_aim_debug_.gimbal_cmd = gimbal_cmd;
auto_aim_debug_.aim_target = aim_target;
}
timer_cout_++;
return gimbal_cmd;
}
void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
while (!self->isAlive()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
while (self->isAlive() && run_flag_) {
if (!self->waitPoint())
break;
self->heartbeat();
printStats();
Armors armors;
// bool skip;
// if (armor_queue_->dequeue_wait(armors, skip)) {
// armorsCallback(armors);
// tracker_finish_count_++;
// if (skip) {
// WUST_DEBUG(logger_) << "OrderQueue skip";
// }
// }
if (armor_queue_->try_dequeue(armors)) {
armorsCallback(armors);
tracker_finish_count_++;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
}
void doDebug() {
debug_mode_ = true;
AutoAimDebug dbg;
{
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg = auto_aim_debug_;
}
drawDebugOverlayShm(dbg, camera_info_, false);
debuglog(dbg);
}
void printStats() {
utils::XSecOnce(
[&] {
double found_ratio = 0.0;
if (img_recv_count_ > 0) {
found_ratio = static_cast<double>(tracker_->getFoundCount())
/ static_cast<double>(img_recv_count_);
}
WUST_INFO(logger_)
<< "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_
<< ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_
<< ", Lat: " << auto_aim_debug_.latency_ms << "ms"
<< ", Fire: " << fire_count_ << ", Found: " << tracker_->getFoundCount()
<< ", Found_ratio: " << found_ratio;
img_recv_count_ = 0;
detect_finish_count_ = 0;
fire_count_ = 0;
tracker_finish_count_ = 0;
timer_cout_ = 0;
tracker_->setFoundCount(0);
},
1.0
);
}
void
autoExposureControl(const cv::Mat& frame, std::shared_ptr<wust_vl::video::Camera> camera) {
const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0;
utils::XSecOnce(
[&] {
if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) {
return;
}
if (auto* hik = dynamic_cast<wust_vl::video::HikCamera*>(camera->getDevice())) {
cv::Mat i_use = frame(expanded_);
if (expanded_.area() < 100 || i_use.empty()) {
i_use = frame;
}
const double brightness = utils::computeBrightness(i_use);
const double diff =
brightness - auto_exposure_cfg_->target_brightness_param.get();
const double exposure_min = auto_exposure_cfg_->exposure_min_param.get();
const double exposure_max = auto_exposure_cfg_->exposure_max_param.get();
double exposure_time = hik->getExposureTime();
const double last_exposure_time = exposure_time;
if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get()
&& exposure_time > 0.0) {
exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get();
} else {
exposure_time -= auto_exposure_cfg_->decay_step_param.get();
}
if (exposure_time < exposure_min)
exposure_time = exposure_min;
if (exposure_time > exposure_max)
exposure_time = exposure_max;
if (std::abs(exposure_time - last_exposure_time) > 10) {
hik->setExposureTime(exposure_time);
}
}
},
dt
);
}
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
Tracker::Ptr tracker_;
ArmorDetectorBase::Ptr armor_detector_;
std::string logger_ = "auto_aim";
std::unique_ptr<wust_vl::common::concurrency::OrderedQueue<Armors>> armor_queue_;
wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
VeryAimer::Ptr very_aimer_;
ArmorWhere::Ptr armor_where_;
AutoAimFsmController auto_aim_fsm_cl_;
AutoExposureCfg::Ptr auto_exposure_cfg_;
cv::Rect expanded_;
int max_detect_armors_;
bool run_flag_ = false;
int detect_finish_count_ = 0;
int img_recv_count_ = 0;
int tracker_finish_count_ = 0;
int fire_count_ = 0;
int timer_cout_ = 0;
Target target_;
bool debug_mode_ = false;
AutoAimDebug auto_aim_debug_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
TFConfig::Ptr tf_config_;
std::pair<cv::Mat, cv::Mat> camera_info_;
Eigen::Matrix4d T_camera_to_odom_;
std::mutex target_mutex_;
std::mutex dbg_mutex_;
};
AutoAim::AutoAim(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
):
_impl(std::make_unique<Impl>(config_path, tf_config, camera_info, debug)) {}
AutoAim::~AutoAim() {
_impl.reset();
}
void AutoAim::start() {
_impl->start();
}
void AutoAim::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
GimbalCmd AutoAim::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
wust_vl::common::concurrency::MonitoredThread::Ptr AutoAim::getThread() {
return _impl->processing_thread_;
}
Target AutoAim::getTarget() {
return _impl->getTarget();
}
void AutoAim::doDebug() {
_impl->doDebug();
}
wust_vl::common::utils::Parameter::Ptr AutoAim::getParameter() {
return _impl->auto_aim_config_parameter_;
}
VeryAimer::Ptr AutoAim::getVeryAimer() {
return _impl->very_aimer_;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,46 @@
#pragma once
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/imodule.hpp"
#include "tasks/type_common.hpp"
#include "wust_vl/video/camera.hpp"
#include <memory>
namespace wust_vision {
namespace auto_aim {
class AutoAim: public IModule {
public:
using Ptr = std::shared_ptr<AutoAim>;
AutoAim(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
);
static Ptr create(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
return std::make_shared<AutoAim>(config_path, tf_config, camera_info, debug);
}
~AutoAim();
void start() override;
void doDebug() override;
void pushInput(CommonFrame& frame) override;
Target getTarget();
GimbalCmd solve(double bullet_speed) override;
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override;
wust_vl::common::utils::Parameter::Ptr getParameter();
VeryAimer::Ptr getVeryAimer();
struct Impl;
std::unique_ptr<Impl> _impl;
};
inline AutoAim::Ptr toAutoAim(IModule::Ptr module) {
return std::dynamic_pointer_cast<AutoAim>(module);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,134 @@
#pragma once
#include "wust_vl/common/utils/parameter.hpp"
namespace wust_vision {
namespace auto_aim {
enum class AutoAimFsm {
AIM_WHOLE_CAR_ARMOR,
AIM_WHOLE_CAR_CENTER,
AIM_SINGLE_ARMOR,
AIM_WHOLE_CAR_PAIR
};
inline std::string auto_aim_fsm_to_string(AutoAimFsm state) {
switch (state) {
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR:
return "AIM_WHOLE_CAR_ARMOR";
case AutoAimFsm::AIM_WHOLE_CAR_CENTER:
return "AIM_WHOLE_CAR_CENTER";
case AutoAimFsm::AIM_SINGLE_ARMOR:
return "AIM_SINGLE_ARMOR";
case AutoAimFsm::AIM_WHOLE_CAR_PAIR:
return "AIM_WHOLE_CAR_PAIR";
default:
return "UNKNOWN";
}
}
class AutoAimFsmController {
public:
AutoAimFsmController() {
config_ = std::make_shared<AutoAimFsmConfig>();
}
AutoAimFsm fsm_state_ { AutoAimFsm::AIM_SINGLE_ARMOR };
struct AutoAimFsmConfig: wust_vl::common::utils::ParamGroup {
public:
static constexpr const char* Logger = "Config: auto_aim::auto_aim_fsm";
static constexpr const char* kKey = "auto_aim_fsm";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<AutoAimFsmConfig>;
AutoAimFsmConfig() {}
GEN_PARAM(int, transfer_thresh);
GEN_PARAM(double, single_whole_up);
GEN_PARAM(double, single_whole_down);
GEN_PARAM(double, whole_pair_up);
GEN_PARAM(double, whole_pair_down);
GEN_PARAM(double, pair_center_up);
GEN_PARAM(double, pair_center_down);
void loadSelf(const YAML::Node& node) override {
transfer_thresh_param.load(node);
single_whole_up_param.load(node);
single_whole_down_param.load(node);
whole_pair_up_param.load(node);
whole_pair_down_param.load(node);
pair_center_up_param.load(node);
pair_center_down_param.load(node);
}
};
AutoAimFsmConfig::Ptr config_;
int overflow_count_ = 0;
void update(double v_yaw, bool target_jumped) {
// 无跳变:直接退回单装甲,并清状态
if (!target_jumped) {
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
return;
}
const double av = std::abs(v_yaw);
switch (fsm_state_) {
case AutoAimFsm::AIM_SINGLE_ARMOR: {
overflow_count_ =
(av > config_->single_whole_up_param.get()) ? overflow_count_ + 1 : 0;
if (overflow_count_ > config_->transfer_thresh_param.get()) {
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR: {
if (av > config_->whole_pair_up_param.get())
++overflow_count_;
else if (av < config_->single_whole_down_param.get())
--overflow_count_;
else
overflow_count_ = 0;
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_PAIR
: AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_PAIR: {
if (av > config_->pair_center_up_param.get())
++overflow_count_;
else if (av < config_->whole_pair_down_param.get())
--overflow_count_;
else
overflow_count_ = 0;
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_CENTER
: AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_CENTER: {
overflow_count_ =
(av < config_->pair_center_down_param.get()) ? overflow_count_ + 1 : 0;
if (overflow_count_ > config_->transfer_thresh_param.get()) {
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_PAIR;
overflow_count_ = 0;
}
break;
}
default:
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
break;
}
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,684 @@
#include "debug.hpp"
namespace wust_vision::auto_aim {
void drawDebugArmorContent(
cv::Mat& debug_img,
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
) {
if (debug_img.empty()) {
std::cout << "debug_img is empty" << std::endl;
return;
}
const auto now = std::chrono::steady_clock::now();
const auto& armors = dbg.armors;
const auto& gimbal_cmd = dbg.gimbal_cmd;
const auto& target = dbg.target;
auto aim_target = dbg.aim_target;
const auto& armor_objs = dbg.armor_objs;
const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows);
const cv::Rect roi = dbg.expanded & img_rect;
cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2);
static const int next_indices[] = { 2, 0, 3, 1 };
for (size_t i = 0; i < armor_objs.size(); i++) {
const auto pts = armor_objs[i].toPts();
for (size_t j = 0; j < 4; ++j) {
const cv::Scalar color =
armor_objs[i].is_ok ? cv::Scalar(50, 255, 50) : cv::Scalar(50, 255, 255);
cv::line(debug_img, pts[j], pts[next_indices[j]], color, 2);
}
auto armorName = [](auto_aim::ArmorNumber num) {
switch (num) {
case auto_aim::ArmorNumber::SENTRY:
return "SENTRY";
case auto_aim::ArmorNumber::BASE:
return "BASE";
case auto_aim::ArmorNumber::OUTPOST:
return "OUTPOST";
case auto_aim::ArmorNumber::NO1:
return "NO1";
case auto_aim::ArmorNumber::NO2:
return "NO2";
case auto_aim::ArmorNumber::NO3:
return "NO3";
case auto_aim::ArmorNumber::NO4:
return "NO4";
case auto_aim::ArmorNumber::NO5:
return "NO5";
default:
return "UNKNOWN";
}
};
const std::string armor_str = armorName(armor_objs[i].number);
cv::putText(
debug_img,
armor_str,
pts[1] + cv::Point2f(0, 50),
cv::FONT_HERSHEY_SIMPLEX,
0.7,
cv::Scalar(0, 200, 200),
2
);
}
const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms);
cv::putText(
debug_img,
latency_str,
cv::Point(10, 30),
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 255),
2
);
{
// static std::deque<std::pair<Eigen::Vector3d, double>> traj3d;
// double _now =
// std::chrono::duration<double>(std::chrono::steady_clock::now().time_since_epoch()).count();
// traj3d.emplace_back(aim_target.pos, _now);
// while (!traj3d.empty() && _now - traj3d.front().second > 1.0)
// traj3d.pop_front();
aim_target.tf(dbg.T_camera_to_odom.inverse());
const auto pts = aim_target.toPts(camera_info.first, camera_info.second);
if (!pts.empty()) {
// if (traj3d.size() > 1) {
// std::vector<std::pair<cv::Point, double>> img_pts;
// for (auto& p: traj3d) {
// auto p3d_odom = p.first;
// Eigen::Vector4d p_odom(p3d_odom.x(), p3d_odom.y(), p3d_odom.z(), 1);
// Eigen::Vector4d p_camera = dbg.T_camera_to_odom.inverse() * p_odom;
// std::vector<cv::Point3f> obj;
// obj.emplace_back(p_camera.x(), p_camera.y(), p_camera.z());
// std::vector<cv::Point2f> proj;
// cv::projectPoints(
// obj,
// cv::Vec3d(0, 0, 0),
// cv::Vec3d(0, 0, 0),
// camera_info.first,
// camera_info.second,
// proj
// );
// if (!proj.empty()) {
// const auto& pt = proj[0];
// if (std::isfinite(pt.x) && std::isfinite(pt.y)) {
// img_pts.emplace_back(cv::Point(int(pt.x), int(pt.y)), p.second);
// }
// }
// }
// if (img_pts.size() >= 2) {
// double now = std::chrono::duration<double>(
// std::chrono::steady_clock::now().time_since_epoch()
// )
// .count();
// const double max_age = 1.0;
// for (size_t i = 1; i < img_pts.size(); ++i) {
// double age = now - img_pts[i].second;
// double t = std::clamp(age / max_age, 0.0, 1.0);
// int r = int(255 * (1.0 - t));
// int b = int(255 * t);
// cv::Scalar color(b, 0, r);
// cv::line(
// debug_img,
// img_pts[i - 1].first,
// img_pts[i].first,
// color,
// 2,
// cv::LINE_AA
// );
// }
// }
// }
cv::Point2f center(0.f, 0.f);
for (auto pt: pts)
center += pt;
center *= 1.0f / pts.size();
cv::Scalar color(255, 255, 255);
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2);
for (int i = 4; i < 8; i++)
cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2);
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[i + 4], color, 2);
if (gimbal_cmd.fire_advice) {
int cross_len = 60;
cv::line(
debug_img,
center + cv::Point2f(-cross_len, -cross_len),
center + cv::Point2f(+cross_len, +cross_len),
cv::Scalar(0, 0, 255),
5
);
cv::line(
debug_img,
center + cv::Point2f(-cross_len, +cross_len),
center + cv::Point2f(+cross_len, -cross_len),
cv::Scalar(0, 0, 255),
5
);
}
const double scale = 10.0;
const double v_yaw = gimbal_cmd.v_yaw;
const double v_pitch = gimbal_cmd.v_pitch;
const double dx = -scale * v_yaw;
const double dy = scale * v_pitch;
const cv::Point2f start_pt = center;
const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy);
const cv::Scalar color_x =
dbg.detect_color ? cv::Scalar(255, 50, 50) : cv::Scalar(50, 50, 255);
cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2);
}
}
std::vector<cv::Point2f> all_corners;
auto visualizeTargetProjection = [&](auto_aim::Target armor_target) -> auto_aim::Armors {
auto_aim::Armors armor_data;
armor_data.timestamp = armor_target.timestamp_;
if (armor_target.is_tracking) {
Eigen::Vector3d pos = armor_target.target_state_.pos();
if (pos.norm() > 0.5) {
armor_data.armors.clear();
const size_t a_n = armor_target.armor_num_;
armor_data.armors.reserve(a_n);
const auto now = wust_vl::common::utils::time_utils::now();
armor_target.predictSimple(now);
const std::vector<Eigen::Vector4d> armors_posandyaw =
armor_target.getArmorPosAndYaw();
for (size_t i = 0; i < a_n; ++i) {
const Eigen::Vector3d pos = { armors_posandyaw[i][0],
armors_posandyaw[i][1],
armors_posandyaw[i][2] };
Eigen::Vector3d euler;
euler.z() = M_PI / 2.0;
euler.y() = (armor_target.tracked_id_ == auto_aim::ArmorNumber::OUTPOST)
? -0.2618
: 0.2618;
euler.x() = armors_posandyaw[i][3];
const Eigen::Quaterniond ori =
utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
armor_data.armors.emplace_back(auto_aim::Armor {
.type = armor_target.type_,
.pos = pos,
.ori = ori,
.is_ok = true,
.id = (int)(i),
});
}
}
}
return armor_data;
};
auto armor_data = visualizeTargetProjection(dbg.target);
transformArmorData(armor_data, dbg.T_camera_to_odom.inverse());
for (size_t i = 0; i < armor_data.armors.size(); ++i) {
const auto& pts = armor_data.armors[i].toPtsDebug(camera_info.first, camera_info.second);
const auto& pos = armor_data.armors[i].pos;
const auto& ori = armor_data.armors[i].ori;
const auto& id = armor_data.armors[i].id;
cv::Scalar color;
if (dbg.detect_color) {
color = cv::Scalar(255, 0, 0);
} else {
color = cv::Scalar(0, 0, 255);
}
// 绘制前表面
for (size_t j = 0; j < 4; ++j) {
cv::line(debug_img, pts[j], pts[(j + 1) % 4], color, 2);
}
// 绘制后表面
for (size_t j = 4; j < 8; ++j) {
cv::line(debug_img, pts[j], pts[4 + (j + 1) % 4], color, 2);
}
// 绘制侧边
for (size_t j = 0; j < 4; ++j) {
cv::line(debug_img, pts[j], pts[j + 4], color, 2);
}
all_corners.insert(all_corners.end(), pts.begin(), pts.end());
const Eigen::Vector3d euler = ori.toRotationMatrix().eulerAngles(2, 1, 0);
const double yaw = euler[0];
const double distance =
std::sqrt(pos.x() * pos.x() + pos.y() * pos.y() + pos.z() * pos.z());
const std::vector<std::string> info_lines = {
fmt::format("Dis: {:.1f}cm", distance * 100),
fmt::format("X: {:.2f}", pos.x()),
fmt::format("Y: {:.2f}", pos.y()),
fmt::format("Z: {:.2f}", pos.z()),
fmt::format("Yaw: {:.2f}", yaw * 180.0 / M_PI),
fmt::format("ID: {:d}", id)
};
const cv::Point2f text_org = pts[0] + cv::Point2f(0, 200);
for (int k = 0; k < info_lines.size(); ++k) {
cv::putText(
debug_img,
info_lines[k],
text_org + cv::Point2f(0, -10 - 20 * k),
cv::FONT_HERSHEY_SIMPLEX,
0.6,
cv::Scalar(50, 255, 255),
1
);
}
}
if (!all_corners.empty()) {
cv::Point2f avg(0.f, 0.f);
for (const auto& pt: all_corners)
avg += pt;
avg *= 1.0f / all_corners.size();
cv::circle(debug_img, avg, 10, cv::Scalar(50, 255, 50), -1);
const double scale = 50.0;
const double dy = scale * target.target_state_.vyaw();
const cv::Point2f start_pt = avg;
const cv::Point2f end_pt = start_pt + cv::Point2f(0, dy);
cv::arrowedLine(
debug_img,
start_pt,
end_pt,
cv::Scalar(50, 255, 50),
3,
cv::LINE_AA,
0,
0.1
);
cv::putText(
debug_img,
fmt::format("V_yaw: {:.2f}", target.target_state_.vyaw()),
avg + cv::Point2f(0, -20),
cv::FONT_HERSHEY_SIMPLEX,
1.0,
cv::Scalar(50, 255, 50),
2
);
}
std::string state_str;
state_str = auto_aim_fsm_to_string(dbg.fsm);
int baseline = 0;
cv::Size text_size = cv::getTextSize(state_str, cv::FONT_HERSHEY_SIMPLEX, 2.5, 2, &baseline);
// 保证在图像内
const int x =
std::clamp(debug_img.cols - text_size.width - 10, 0, debug_img.cols - text_size.width);
const int y = std::clamp(text_size.height + 10, text_size.height, debug_img.rows - 1);
cv::putText(
debug_img,
state_str,
{ x, y },
cv::FONT_HERSHEY_SIMPLEX,
2.5,
cv::Scalar(0, 0, 255),
2
);
const std::string id_str =
fmt::format("Attack: {}", armorNumberToString(dbg.target.tracked_id_));
const cv::Size id_size = cv::getTextSize(id_str, cv::FONT_HERSHEY_SIMPLEX, 1.6, 2, &baseline);
// 保证在图像内
const int id_x = std::clamp(debug_img.cols - 300, 0, debug_img.cols - id_size.width - 10);
const int id_y = std::clamp(150, id_size.height, debug_img.rows - 1);
cv::putText(
debug_img,
id_str,
{ id_x, id_y },
cv::FONT_HERSHEY_SIMPLEX,
1.6,
cv::Scalar(255, 0, 255),
2
);
if (gimbal_cmd.fire_advice) {
std::string fire_str = "Fire!";
cv::putText(
debug_img,
fire_str,
{ debug_img.cols / 2 - 100, 200 },
cv::FONT_HERSHEY_SIMPLEX,
2.85,
cv::Scalar(0, 0, 255),
2
);
}
const std::string gimbal_str = fmt::format(
"Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}",
gimbal_cmd.pitch,
gimbal_cmd.yaw,
gimbal_cmd.enable_pitch_diff,
gimbal_cmd.enable_yaw_diff,
gimbal_cmd.v_yaw,
gimbal_cmd.v_pitch
);
cv::putText(
debug_img,
gimbal_str,
{ 10, debug_img.rows - 30 },
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 0),
2
);
double scale = 100.0;
double armor_len = 0.135;
std::vector<Eigen::Vector2d> pts;
pts.reserve(armors.armors.size() + armor_data.armors.size());
auto collect_xy = [&](auto& list, bool use_target) {
for (auto& a: list)
pts.emplace_back(
use_target ? a.target_pos.x() : a.pos.x(),
use_target ? a.target_pos.y() : a.pos.y()
);
};
collect_xy(armors.armors, true);
collect_xy(armor_data.armors, false);
double max_abs_x = 1e-6, max_abs_y = 1e-6;
for (auto& p: pts) {
max_abs_x = std::max(max_abs_x, std::abs(p.x()));
max_abs_y = std::max(max_abs_y, std::abs(p.y()));
}
const double margin = 200.0;
const double cx = debug_img.cols * 0.5;
const double cy = debug_img.rows * 0.5;
scale = std::min({ (cx - margin) / max_abs_x,
(debug_img.cols - cx - margin) / max_abs_x,
(cy - margin) / max_abs_y,
(debug_img.rows - cy - margin) / max_abs_y,
550.0 });
const cv::Point2d origin(cx, cy);
auto to_img = [&](const Eigen::Vector3d& p) {
return cv::Point2d(origin.x + p.x() * scale, origin.y - p.y() * scale);
};
auto draw2dArmor = [&](const Eigen::Vector3d& pos, double yaw, const cv::Scalar& color) {
cv::Point2d C = to_img(pos);
cv::circle(debug_img, C, 3, color, -1, cv::LINE_AA);
double nx = -sin(yaw), ny = cos(yaw);
double half_len_px = armor_len * 0.5 * scale;
cv::Point2d P1(C.x + nx * half_len_px, C.y - ny * half_len_px);
cv::Point2d P2(C.x - nx * half_len_px, C.y + ny * half_len_px);
cv::line(debug_img, P1, P2, color, 2, cv::LINE_AA);
};
Eigen::Vector3d center(0, 0, 0);
if (!armor_data.armors.empty()) {
for (auto& a: armor_data.armors)
center += a.pos;
center /= armor_data.armors.size();
}
const cv::Point2d Cc = to_img(center);
if (!armor_data.armors.empty())
cv::circle(debug_img, Cc, 5, cv::Scalar(255, 0, 0), -1, cv::LINE_AA);
for (auto& a: armors.armors)
draw2dArmor(a.target_pos, a.yaw, cv::Scalar(0, 255, 255));
std::vector<cv::Point2d> data_pts;
for (auto& a: armor_data.armors) {
double yaw = a.ori.toRotationMatrix().eulerAngles(2, 1, 0)[0];
draw2dArmor(a.pos, yaw, cv::Scalar(255, 255, 255));
data_pts.push_back(to_img(a.pos));
}
for (auto& pt: data_pts)
cv::line(debug_img, Cc, pt, cv::Scalar(180, 180, 255), 1, cv::LINE_AA);
for (auto& a: armors.armors)
cv::line(debug_img, Cc, to_img(a.target_pos), cv::Scalar(0, 150, 255), 1, cv::LINE_AA);
cv::circle(
debug_img,
cv::Point2i(debug_img.cols / 2, debug_img.rows / 2),
5,
cv::Scalar(255, 255, 255),
2
);
}
void writeTargetLogToJson(const auto_aim::Target& armor_target) {
nlohmann::json j;
// -------- armor_target 部分 --------
nlohmann::json jt;
jt["type"] = armor_target.type_;
jt["tracking"] = armor_target.is_tracking;
jt["id"] = static_cast<int>(armor_target.tracked_id_);
jt["armors_num"] = armor_target.armor_num_;
const auto now = std::chrono::steady_clock::now();
const auto age_ms_t =
std::chrono::duration_cast<std::chrono::milliseconds>(now - armor_target.timestamp_)
.count();
jt["timestamp_age_ms"] = age_ms_t;
jt["position"] = { { "x", armor_target.target_state_.cx() },
{ "y", armor_target.target_state_.cy() },
{ "z", armor_target.target_state_.cz() } };
jt["velocity"] = { { "x", armor_target.target_state_.vcx() },
{ "y", armor_target.target_state_.vcy() },
{ "z", armor_target.target_state_.vcz() } };
jt["r"] = armor_target.target_state_.r();
jt["l"] = armor_target.target_state_.l();
jt["h"] = armor_target.target_state_.h();
jt["yaw"] = armor_target.target_state_.yaw();
jt["v_yaw"] = armor_target.target_state_.vyaw();
j["armor_target"] = jt;
// -------- 写文件 --------
std::ofstream file("/dev/shm/target_log.json");
if (file.is_open()) {
file << j.dump(2);
}
}
struct DebugLogs {
#define DEBUG_LOG_LIST(X) \
X(double, 100, time) \
X(double, 100, raw_yaw) \
X(double, 100, raw_pitch) \
X(double, 100, yaw) \
X(double, 100, pitch) \
X(double, 100, armor_dis) \
X(double, 100, armor_x) \
X(double, 100, armor_y) \
X(double, 100, armor_z) \
X(double, 100, armor_yaw) \
X(double, 100, ypd_y) \
X(double, 100, ypd_p) \
X(double, 100, gimbal_yaw) \
X(double, 100, gimbal_pitch) \
X(double, 100, target_v_yaw) \
X(double, 100, control_v_yaw) \
X(double, 100, control_v_pitch) \
X(double, 100, yaw_diff) \
X(double, 100, fire) \
X(double, 100, rune_dis) \
X(double, 100, fly_time) \
X(double, 100, control_a_yaw) \
X(double, 100, control_a_pitch)
#define GEN_LOG(TYPE, SIZE, NAME) LogsStream<TYPE, SIZE> NAME##_log { #NAME };
#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME)
DEBUG_LOG_LIST(X)
#undef X
void clear() {
#define X(TYPE, SIZE, NAME) NAME##_log.clear();
DEBUG_LOG_LIST(X)
#undef X
}
};
void debuglog(const AutoAimDebug& dbg_armor) {
static bool first_log = true;
static std::chrono::steady_clock::time_point start_time;
static auto_aim::Armor last_armor_;
static double last_armor_yaw_ = 0.0;
static double last_ypd_y_ = 0.0;
static double last_ypd_p_ = 0.0;
static double last_distance_ = 0.0;
static DebugLogs log;
static GimbalCmd last_cmd_;
static double rune_dis = 0.0;
if (first_log) {
start_time = std::chrono::steady_clock::now();
first_log = false;
}
const auto now = std::chrono::steady_clock::now();
const auto_aim::Armors& armors = dbg_armor.armors;
const double t = std::chrono::duration<double>(now - start_time).count();
const auto_aim::Target& target = dbg_armor.target;
writeTargetLogToJson(target);
double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0;
if (!armors.armors.empty()) {
std::vector<auto_aim::Armor> ok_armors;
for (const auto& armor: armors.armors) {
if (armor.number != auto_aim::ArmorNumber::OUTPOST)
ok_armors.push_back(armor);
}
if (!ok_armors.empty()) {
const auto_aim::Armor& min_armor = *std::min_element(
ok_armors.begin(),
ok_armors.end(),
[](const auto_aim::Armor& a, const auto_aim::Armor& b) {
return a.distance_to_image_center < b.distance_to_image_center;
}
);
last_armor_ = min_armor;
armor_distance = std::hypot(
min_armor.target_pos.x(),
min_armor.target_pos.y(),
min_armor.target_pos.z()
);
auto orientationToYaw = [](const Eigen::Quaterniond& q) noexcept -> double {
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
double yaw = euler[0];
yaw = last_armor_yaw_ + angles::shortest_angular_distance(last_armor_yaw_, yaw);
last_armor_yaw_ = yaw;
return yaw;
};
armor_yaw = orientationToYaw(min_armor.target_ori);
ypd_y = std::atan2(min_armor.target_pos.y(), min_armor.target_pos.x());
ypd_y = last_ypd_y_ + angles::shortest_angular_distance(last_ypd_y_, ypd_y);
last_ypd_y_ = ypd_y;
ypd_p = std::atan2(
min_armor.target_pos.z(),
std::hypot(min_armor.target_pos.x(), min_armor.target_pos.y())
);
last_ypd_p_ = ypd_p;
last_distance_ = armor_distance;
}
}
GimbalCmd i_use;
if (dbg_armor.gimbal_cmd.appear) {
i_use = dbg_armor.gimbal_cmd;
} else {
i_use = last_cmd_;
}
last_cmd_ = i_use;
nlohmann::json j;
log.time_log.handleOnce(t, j);
log.raw_yaw_log.handleOnce(i_use.target_yaw, j);
log.raw_pitch_log.handleOnce(i_use.target_pitch, j);
log.yaw_log.handleOnce(i_use.yaw, j);
log.pitch_log.handleOnce(i_use.pitch, j);
log.armor_yaw_log.handleOnce(armor_yaw * 180.0 / M_PI, j);
log.armor_x_log.handleOnce(last_armor_.target_pos.x(), j);
log.armor_y_log.handleOnce(last_armor_.target_pos.y(), j);
log.armor_z_log.handleOnce(last_armor_.target_pos.z(), j);
log.ypd_y_log.handleOnce(last_ypd_y_ * 180.0 / M_PI, j);
log.ypd_p_log.handleOnce(last_ypd_p_ * 180.0 / M_PI, j);
log.armor_dis_log.handleOnce(last_distance_, j);
log.gimbal_pitch_log.handleOnce(dbg_armor.gimbal_py.first * 180.0 / M_PI, j);
log.gimbal_yaw_log.handleOnce(dbg_armor.gimbal_py.second * 180.0 / M_PI, j);
log.target_v_yaw_log.handleOnce(target.target_state_.vyaw(), j);
log.control_v_pitch_log.handleOnce(i_use.v_pitch, j);
log.control_v_yaw_log.handleOnce(i_use.v_yaw, j);
log.fire_log.handleOnce(i_use.fire_advice, j);
log.rune_dis_log.handleOnce(rune_dis, j);
log.fly_time_log.handleOnce(i_use.fly_time, j);
log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j);
log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j);
log.yaw_diff_log.handleOnce(
std::abs(dbg_armor.gimbal_py.second * 180.0 / M_PI - dbg_armor.gimbal_cmd.yaw),
j
);
std::ofstream file("/dev/shm/cmd_log.json");
if (file.is_open()) {
file << j.dump();
}
}
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,38 @@
#pragma once
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/auto_aim_fsm.hpp"
#include "tasks/auto_aim/type.hpp"
#include "tasks/utils/debug_utils.hpp"
namespace wust_vision::auto_aim {
struct AutoAimDebug {
wust_vl::video::ImageFrame img_frame;
auto_aim::Armors armors;
auto_aim::Target target;
GimbalCmd gimbal_cmd;
auto_aim::AutoAimFsm fsm;
AimTarget aim_target;
double latency_ms;
Eigen::Matrix4d T_camera_to_odom;
std::vector<auto_aim::ArmorObject> armor_objs;
int detect_color = 0;
cv::Rect expanded;
std::pair<double, double> gimbal_py;
};
void drawDebugArmorContent(
cv::Mat& debug_img,
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
);
void writeTargetLogToJson(const auto_aim::Target& armor_target);
inline void drawDebugOverlayShm(
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info,
bool auto_fps
) {
static ShmWriter shm { "/debug_frame" };
drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugArmorContent, shm);
}
void debuglog(const AutoAimDebug& dbg_armor);
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,399 @@
#include "type.hpp"
#include "tasks/utils/config.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include <numeric>
namespace wust_vision {
namespace auto_aim {
Light::Light(const std::vector<cv::Point>& contour): cv::RotatedRect(cv::minAreaRect(contour)) {
this->center = std::accumulate(
contour.begin(),
contour.end(),
cv::Point2f(0, 0),
[n = static_cast<float>(contour.size())](const cv::Point2f& a, const cv::Point& b) {
return a + cv::Point2f(b.x, b.y) / n;
}
);
cv::Point2f p[4];
this->points(p);
std::sort(p, p + 4, [](const cv::Point2f& a, const cv::Point2f& b) { return a.y < b.y; });
top = (p[0] + p[1]) / 2;
bottom = (p[2] + p[3]) / 2;
length = cv::norm(top - bottom);
width = cv::norm(p[0] - p[1]);
axis = (top - bottom) / cv::norm(top - bottom);
tilt_angle =
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
}
void Light::addOffset(const cv::Point2f& offset) noexcept {
this->center += offset;
top += offset;
bottom += offset;
}
void Light::transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
top = utils::transformPoint2D(transform_matrix, top);
bottom = utils::transformPoint2D(transform_matrix, bottom);
length = cv::norm(top - bottom);
cv::Point2f p[4];
this->points(p);
width = cv::norm(
utils::transformPoint2D(transform_matrix, p[0])
- utils::transformPoint2D(transform_matrix, p[1])
);
const cv::Point2f p0 = center;
const cv::Point2f p1 = center + axis;
const cv::Point2f p0_t = utils::transformPoint2D(transform_matrix, p0);
const cv::Point2f p1_t = utils::transformPoint2D(transform_matrix, p1);
axis = p1_t - p0_t;
axis /= cv::norm(axis);
tilt_angle =
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
center = utils::transformPoint2D(transform_matrix, center);
}
int formArmorColor(const ArmorColor& color) noexcept {
switch (color) {
case ArmorColor::RED:
return 0;
case ArmorColor::BLUE:
return 1;
case ArmorColor::NONE:
return 2;
case ArmorColor::PURPLE:
return 3;
}
return -1;
}
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept {
switch (number) {
case ArmorNumber::SENTRY:
return os << "SENTRY";
case ArmorNumber::NO1:
return os << "NO1";
case ArmorNumber::NO2:
return os << "NO2";
case ArmorNumber::NO3:
return os << "NO3";
case ArmorNumber::NO4:
return os << "NO4";
case ArmorNumber::NO5:
return os << "NO5";
case ArmorNumber::OUTPOST:
return os << "OUTPOST";
case ArmorNumber::BASE:
return os << "BASE";
case ArmorNumber::UNKNOWN:
return os << "UNKNOWN";
default:
return os << "InvalidArmorNumber(" << static_cast<int>(number) << ")";
}
}
int formArmorNumber(const ArmorNumber& number) noexcept {
switch (number) {
case ArmorNumber::SENTRY:
return 0;
case ArmorNumber::NO1:
return 1;
case ArmorNumber::NO2:
return 2;
case ArmorNumber::NO3:
return 3;
case ArmorNumber::NO4:
return 4;
case ArmorNumber::NO5:
return 5;
case ArmorNumber::OUTPOST:
return 6;
case ArmorNumber::BASE:
return 7;
case ArmorNumber::UNKNOWN:
return 8;
}
return -1;
}
ArmorNumber armorNumberFromString(const std::string& s) noexcept {
if (s == "SENTRY")
return ArmorNumber::SENTRY;
if (s == "BASE")
return ArmorNumber::BASE;
if (s == "OUTPOST")
return ArmorNumber::OUTPOST;
if (s == "NO1")
return ArmorNumber::NO1;
if (s == "NO2")
return ArmorNumber::NO2;
if (s == "NO3")
return ArmorNumber::NO3;
if (s == "NO4")
return ArmorNumber::NO4;
if (s == "NO5")
return ArmorNumber::NO5;
return ArmorNumber::UNKNOWN;
}
std::string armorNumberToString(const ArmorNumber& num) noexcept {
switch (num) {
case ArmorNumber::SENTRY:
return "SENTRY";
case ArmorNumber::BASE:
return "BASE";
case ArmorNumber::OUTPOST:
return "OUTPOST";
case ArmorNumber::NO1:
return "NO1";
case ArmorNumber::NO2:
return "NO2";
case ArmorNumber::NO3:
return "NO3";
case ArmorNumber::NO4:
return "NO4";
case ArmorNumber::NO5:
return "NO5";
default:
return "UNKNOWN";
}
}
namespace {
std::unordered_map<std::string, int> armor_map;
std::unordered_map<int, std::vector<ArmorNumber>> tracker_to_armors;
bool loaded = false;
void loadArmorMapOnce() {
if (loaded)
return;
try {
YAML::Node config = YAML::LoadFile(AUTO_AIM_CONFIG)["armor_map"];
for (auto it = config.begin(); it != config.end(); ++it) {
const std::string key = it->first.as<std::string>();
const int tracker_id = it->second.as<int>();
ArmorNumber armor_num = armorNumberFromString(key);
armor_map[key] = tracker_id;
tracker_to_armors[tracker_id].emplace_back(armor_num);
}
loaded = true;
} catch (const std::exception& e) {
std::cerr << "[ArmorMap] Failed to load armor_map.yaml: " << e.what() << std::endl;
}
}
} // namespace
int retypetotracker(const ArmorNumber& a) noexcept {
loadArmorMapOnce();
const std::string key = armorNumberToString(a);
const auto it = armor_map.find(key);
if (it != armor_map.end())
return it->second;
std::cerr << "[retypetotracker] Invalid ArmorNumber: " << static_cast<int>(a) << std::endl;
return -1;
}
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept {
return retypetotracker(a) == retypetotracker(b);
}
std::string armorTypeToString(const ArmorType& type) noexcept {
switch (type) {
case ArmorType::SMALL:
return "small";
case ArmorType::LARGE:
return "large";
default:
return "invalid";
}
}
std::vector<cv::Point2f> ArmorObject::toPts() const noexcept {
if (is_ok) {
return { lights[0].top, lights[0].bottom, lights[1].bottom, lights[1].top };
} else {
return { pts[0], pts[1], pts[2], pts[3] };
}
}
bool ArmorObject::checkOkptsRight(double max_error) const noexcept {
double error = 0.0;
for (int i = 0; i < 4; i++) {
error += cv::norm(pts[i] - toPts()[i]);
}
return error < max_error;
}
std::array<cv::Point2f, 4> ArmorObject::sortCorners(const std::vector<cv::Point2f>& pts
) const noexcept {
std::array<cv::Point2f, 4> ordered;
// 先按 x 坐标分成左右两组
std::vector<cv::Point2f> left, right;
std::vector<cv::Point2f> sorted = pts;
std::sort(sorted.begin(), sorted.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.x < b.x;
});
left.push_back(sorted[0]);
left.push_back(sorted[1]);
right.push_back(sorted[2]);
right.push_back(sorted[3]);
// 左边两个点,按 y 分为上/下
std::sort(left.begin(), left.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.y < b.y;
});
ordered[1] = left[0]; // 左上
ordered[0] = left[1]; // 左下
// 右边两个点,按 y 分为上/下
std::sort(right.begin(), right.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.y < b.y;
});
ordered[2] = right[0]; // 右上
ordered[3] = right[1]; // 右下
return ordered; // 顺序: 左下, 左上, 右上, 右下
}
std::vector<cv::Point2f> ArmorObject::landmarks() const noexcept {
if constexpr (N_LANDMARKS == 4) {
if (is_ok) {
return { lights[0].bottom, lights[0].top, lights[1].top, lights[1].bottom };
} else {
const auto ordered = sortCorners(pts);
return { ordered[0], ordered[1], ordered[2], ordered[3] };
}
} else {
if (is_ok) {
return { lights[0].bottom, lights[0].center, lights[0].top,
lights[1].top, lights[1].center, lights[1].bottom };
} else {
const auto ordered = sortCorners(pts);
return { ordered[0], (ordered[0] + ordered[1]) / 2.0, ordered[1],
ordered[2], (ordered[2] + ordered[3]) / 2.0, ordered[3] };
}
}
}
ArmorObject::ArmorObject(const Light& l1, const Light& l2) {
pts.resize(4);
if (l1.center.x < l2.center.x) {
lights.push_back(l1);
lights.push_back(l2);
pts[0] = l1.top;
pts[1] = l1.bottom;
pts[2] = l2.bottom;
pts[3] = l2.top;
} else {
lights.push_back(l2);
lights.push_back(l1);
pts[0] = l2.top;
pts[1] = l2.bottom;
pts[2] = l1.bottom;
pts[3] = l1.top;
}
is_ok = true;
}
std::vector<cv::Point2f> Armor::toPtsDebug(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion
) const noexcept {
std::vector<cv::Point2f> image_points;
const std::vector<cv::Point3f>* model_points;
static std::vector<cv::Point3f> SMALL_ARMOR_3D_POINTS_BLOCK = {
{ 0, 0.025, -0.066 }, // 左上前
{ 0, -0.025, -0.066 }, // 左下前
{ 0, -0.025, 0.066 }, // 右下前
{ 0, 0.025, 0.066 }, // 右上前
{ 0.015, 0.025, -0.066 }, // 左上后
{ 0.015, -0.025, -0.066 }, // 左下后
{ 0.015, -0.025, 0.066 }, // 右下后
{ 0.015, 0.025, 0.066 }, // 右上后
};
static std::vector<cv::Point3f> BIG_ARMOR_3D_POINTS_BLOCK = {
{ 0, 0.025, -0.1125 }, { 0, -0.025, -0.1125 }, { 0, -0.025, 0.1125 },
{ 0, 0.025, 0.1125 }, { 0.015, 0.025, -0.1125 }, { 0.015, -0.025, -0.1125 },
{ 0.015, -0.025, 0.1125 }, { 0.015, 0.025, 0.1125 },
};
if (type == "large") {
model_points = &BIG_ARMOR_3D_POINTS_BLOCK;
} else if (type == "small") {
model_points = &SMALL_ARMOR_3D_POINTS_BLOCK;
}
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
const cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
// 旋转矩阵 -> 旋转向量
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
// 平移向量
const cv::Mat tvec =
(cv::Mat_<double>(3, 1) << target_pos.x(), target_pos.y(), target_pos.z());
// 反投影
cv::projectPoints(
*model_points,
rvec,
tvec,
camera_intrinsic,
camera_distortion,
image_points
);
return image_points;
}
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept {
for (auto& armor: armors.armors) {
transformArmorData(armor, T_camera_to_odom);
}
}
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept {
try {
// 位置
const Eigen::Vector3d pos_camera = armor.pos;
armor.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom);
// 姿态
const Eigen::Quaterniond
q_camera(armor.ori.w(), armor.ori.x(), armor.ori.y(), armor.ori.z());
const Eigen::Quaterniond q_odom =
utils::transformOrientation(q_camera, T_camera_to_odom);
armor.target_ori = q_odom;
// 提取 yaw
const Eigen::Vector3d euler = q_odom.toRotationMatrix().eulerAngles(2, 1, 0); // ZYX
armor.yaw = euler[0]; // yaw
} catch (const std::exception& e) {
WUST_ERROR("tf") << "Error in camera-to-odom transform: " << e.what();
}
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,172 @@
#pragma once
#include "tasks/type_common.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_aim {
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
struct Light: public cv::RotatedRect {
Light() = default;
explicit Light(const std::vector<cv::Point>& contour);
void addOffset(const cv::Point2f& offset) noexcept;
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept;
cv::Point2f top, bottom;
int color = 0;
cv::Point2f axis;
double length = 0;
double width = 0;
float tilt_angle = 0;
};
enum class ArmorColor : int { BLUE = 0, RED, NONE, PURPLE };
int formArmorColor(const ArmorColor& color) noexcept;
enum class ArmorNumber : int { SENTRY = 0, NO1, NO2, NO3, NO4, NO5, OUTPOST, BASE, UNKNOWN };
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept;
int formArmorNumber(const ArmorNumber& number) noexcept;
std::string armorNumberToString(const ArmorNumber& num) noexcept;
ArmorNumber armorNumberFromString(const std::string& s) noexcept;
int retypetotracker(const ArmorNumber& a) noexcept;
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept;
enum class ArmorsNum { NORMAL_4 = 4, OUTPOST_3 = 3 };
enum class ArmorType { SMALL, LARGE, INVALID };
std::string armorTypeToString(const ArmorType& type) noexcept;
struct ArmorObject {
ArmorColor color;
ArmorNumber number;
std::vector<cv::Point2f> pts;
cv::Rect box;
cv::Mat number_img;
double confidence;
cv::Mat whole_binary_img;
cv::Mat whole_rgb_img;
cv::Mat whole_gray_img;
std::vector<Light> lights;
cv::Point2f local_offset;
cv::Point2f center;
bool is_ok = false;
ArmorType type;
static constexpr const int N_LANDMARKS = 6;
static constexpr const int N_LANDMARKS_2 = N_LANDMARKS * 2;
template<typename PointType>
static std::vector<PointType> buildObjectPoints(const double& w, const double& h) noexcept {
if constexpr (N_LANDMARKS == 4) {
return {
PointType(0, w / 2, -h / 2), // 右下
PointType(0, w / 2, h / 2), // 右上
PointType(0, -w / 2, h / 2), // 左上
PointType(0, -w / 2, -h / 2) // 左下
};
} else {
return {
PointType(0, w / 2, -h / 2), // 右下
PointType(0, w / 2, 0.0), // 右中
PointType(0, w / 2, h / 2), // 右上
PointType(0, -w / 2, h / 2), // 左上
PointType(0, -w / 2, 0.0), // 左中
PointType(0, -w / 2, -h / 2) // 左下
};
}
}
template<typename IDType>
static std::vector<std::pair<IDType, IDType>> buildSymPairs() noexcept {
if constexpr (N_LANDMARKS == 4) {
static const std::vector<std::pair<IDType, IDType>> pairs = {
{ 0, 3 },
{ 1, 2 },
// { 0, 2 },
// { 1, 3 }
};
return pairs;
} else {
static const std::vector<std::pair<IDType, IDType>> pairs = {
{ 0, 5 },
{ 1, 4 },
{ 2, 3 },
// { 0, 3 },
// { 2, 5 }
};
return pairs;
}
}
std::vector<cv::Point2f> toPts() const noexcept;
bool checkOkptsRight(double max_error) const noexcept;
std::array<cv::Point2f, 4> sortCorners(const std::vector<cv::Point2f>& pts) const noexcept;
// Landmarks start from bottom left in clockwise order
std::vector<cv::Point2f> landmarks() const noexcept;
void addOffset(const cv::Point2f& offset) noexcept {
for (auto& pt: pts) {
pt += offset;
}
center += offset;
box.x += offset.x;
box.y += offset.y;
for (auto& l: lights) {
l.addOffset(offset);
}
}
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
for (auto& l: lights) {
l.transform(transform_matrix);
}
center = utils::transformPoint2D(transform_matrix, center);
box = utils::transformRect(transform_matrix, box);
for (auto& pt: pts) {
pt = utils::transformPoint2D(transform_matrix, pt);
}
}
ArmorObject(const Light& l1, const Light& l2);
ArmorObject() = default;
};
struct Armor {
public:
ArmorNumber number;
std::string type;
Eigen::Vector3d pos;
Eigen::Quaterniond ori;
Eigen::Vector3d target_pos;
Eigen::Quaterniond target_ori;
float distance_to_image_center;
float yaw;
std::chrono::steady_clock::time_point timestamp;
bool is_ok = false;
bool is_none_purple = false;
int id = -1;
std::vector<cv::Point2f> toPtsDebug(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion
) const noexcept;
};
struct Armors {
public:
std::vector<Armor> armors;
std::chrono::steady_clock::time_point timestamp;
int id;
Eigen::Vector3d v;
};
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept;
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept;
} // namespace auto_aim
} // namespace wust_vision